File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import ROOT
0003 import numpy as np
0004 import joblib
0005 from model import TrackCaloRegressor
0006 from data import TrackCaloDataset
0007
# Run the models on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Truth-pt intervals [pt_min, pt_max) in GeV; each interval has its own trained
# model and feature scaler on disk.
pt_bins = [(0, 10)]

# Output ROOT file; "RECREATE" overwrites any existing file with the same name.
output_path = "outputFile/pt_relative_error_INTT_CaloI_ontrain.root"
out_file = ROOT.TFile(output_path, "RECREATE")
# ROOT signals a failed open (e.g. a missing "outputFile/" directory) only with an
# error message and a zombie TFile, after which every Write() silently does nothing.
# Stop here instead of running the whole evaluation for no output.
if out_file.IsZombie():
    raise OSError(f"could not create {output_path!r}; does the output directory exist?")

# Combined truth-pt vs relative-error histogram spanning the full range covered by
# pt_bins, at 10 pt bins per GeV (the same granularity as the per-interval 2D
# histograms). Derived from pt_bins so it stays correct if the intervals change.
pt_lo_all = min(lo for lo, _ in pt_bins)
pt_hi_all = max(hi for _, hi in pt_bins)
name_010 = "h2_pt_vs_relerr_all"
title_010 = f"pt vs RelErr for {pt_lo_all} < pt < {pt_hi_all};truth pt;(pred - truth)/truth"
hist_010 = ROOT.TH2D(
    name_010,
    title_010,
    max(1, round((pt_hi_all - pt_lo_all) * 10)),
    float(pt_lo_all),
    float(pt_hi_all),
    250,
    -2.0,
    2.0,
)
def _fill_1d(hist, values):
    """Fill the 1D ROOT histogram *hist* with every entry of *values* (unit weights).

    Uses a single ``FillN`` call instead of one Python-level ``Fill`` per entry.
    """
    values = np.ascontiguousarray(values, dtype=np.float64)  # FillN expects const double*
    if values.size:
        hist.FillN(values.size, values, np.ones_like(values))


def _fill_2d(hist, xs, ys):
    """Fill the 2D ROOT histogram *hist* with the pairs (xs[i], ys[i]) (unit weights)."""
    xs = np.ascontiguousarray(xs, dtype=np.float64)
    ys = np.ascontiguousarray(ys, dtype=np.float64)
    if xs.size:
        hist.FillN(xs.size, xs, ys, np.ones_like(xs))


for pt_min, pt_max in pt_bins:
    print(f"\n=== Testing model for pt in [{pt_min}, {pt_max}) GeV ===")

    # Feature scaler saved at training time for this pt interval.
    scaler = joblib.load(f"model_weight/scaler_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pkl")

    # NOTE(review): the sample list appears to be the training set (cf. "_ontrain" in the
    # output file name), so these numbers measure in-sample performance.
    dataset = TrackCaloDataset("../train500k.list", scaler=scaler, pt_min=pt_min, pt_max=pt_max)
    loader = torch.utils.data.DataLoader(dataset, batch_size=256)

    model = TrackCaloRegressor()
    # map_location lets weights that were saved from a GPU run load on a CPU-only host.
    state_dict = torch.load(
        f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # New ROOT histograms attach to the current directory. Make that the output file,
    # in case loading the dataset changed it.
    out_file.cd()

    # 1D distribution of the relative pt error for this interval.
    name = f"h_relerr_{pt_min}_{pt_max}"
    title = f"RelErr for {pt_min} < pt < {pt_max};(pred - truth)/truth;Counts"
    hist = ROOT.TH1D(name, title, 200, -1, 1)

    # Truth pt (x axis, 10 bins per GeV) vs relative error (y axis) for this interval.
    # round(), not int(): float bounds such as 1.0..1.3 give 2.999... and would lose a bin.
    name2d = f"h2_pt_vs_relerr_{pt_min}_{pt_max}"
    title2d = f"pt vs RelErr for {pt_min} < pt < {pt_max};truth pt;(pred - truth)/truth"
    n_pt_bins = max(1, round((pt_max - pt_min) * 10))
    hist2d = ROOT.TH2D(name2d, title2d, n_pt_bins, float(pt_min), float(pt_max), 250, -2., 2.)

    n_skipped = 0
    with torch.no_grad():
        for x, y in loader:
            # Only the inputs go to the device; the truth stays on the CPU for histogramming.
            pred = model(x.to(device))

            y_np = y.cpu().numpy().flatten()
            pred_np = pred.cpu().numpy().flatten()
            with np.errstate(divide="ignore", invalid="ignore"):
                rel_err = (pred_np - y_np) / y_np

            # Drop entries whose relative error is inf/NaN: a zero truth pt or a
            # non-finite prediction. Such values would corrupt the histogram mean/RMS
            # and the Gaussian fit below.
            finite = np.isfinite(rel_err)
            n_skipped += int(np.count_nonzero(~finite))
            y_ok = y_np[finite]
            rel_ok = rel_err[finite]

            _fill_1d(hist, rel_ok)
            _fill_2d(hist2d, y_ok, rel_ok)
            _fill_2d(hist_010, y_ok, rel_ok)

    if n_skipped:
        print(f"Skipped {n_skipped} entries with a non-finite relative error")

    # Gaussian fit of the core of the distribution; an empty histogram cannot be fitted.
    if hist.GetEntries() > 0:
        hist.Fit("gaus", "", "", -0.1, 0.1)
    else:
        print(f"No entries in {name}; skipping Gaussian fit")

    # TObject::Write targets the current directory; make sure it is the output file.
    out_file.cd()
    hist.Write()
    hist2d.Write()
# TObject::Write targets ROOT's current directory. Point it at the output file
# explicitly in case something changed it while the per-interval loop ran.
out_file.cd()
hist_010.Write()

# Capture the real path before closing, so the message names the file actually written
# (it previously named a non-existent "pt_relative_error_bybin.root").
output_name = out_file.GetName()
out_file.Close()
# Message: "Relative errors for each pt bin have been saved to <file>".
print(f"✅ 每个 pt 区间的相对误差已保存为 {output_name}")