File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import ROOT
0003 import numpy as np
0004 import joblib
0005 from model import TrackCaloRegressor
0006 from data import TrackCaloDataset
0007 from test_func import visualize_model
0008
import os

# Run inference on the GPU when one is visible, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# (pt_min, pt_max) ranges in GeV; each range selects its own trained model
# weights and its own filtered test sample.
pt_bins = [(0, 10)]

out_path = "outputFile/pt_relative_error_INTT_CaloI.root"
# TFile does not raise when the parent directory is missing: it returns a
# "zombie" object and every later Write() fails. Create the directory up front
# and fail loudly if the file still cannot be opened.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
out_file = ROOT.TFile(out_path, "RECREATE")
if out_file.IsZombie():
    raise OSError(f"Could not open ROOT output file: {out_path}")

# Combined truth-pt vs relative-error histogram, filled from every pt bin:
# 100 x-bins over 0-10 (truth pt) and 250 y-bins over [-2, 2].
name_010 = "h2_pt_vs_relerr_all"
title_010 = "pt vs RelErr for 0 < pt < 10;truth pt;(pred - truth)/truth"
hist_010 = ROOT.TH2D(name_010, title_010, 100, 0, 10, 250, -2., 2.)
0018
# Entries dropped because their truth pt is exactly 0 (relative error undefined).
n_zero_truth = 0

for pt_min, pt_max in pt_bins:
    print(f"\n=== Testing model for pt in [{pt_min}, {pt_max}) GeV ===")

    # Test sample restricted to this pt range; no feature scaler is applied.
    dataset = TrackCaloDataset("../test500k.list", scaler=None, pt_min=pt_min, pt_max=pt_max)
    loader = torch.utils.data.DataLoader(dataset, batch_size=256)

    # map_location remaps tensors saved from a CUDA run onto the current
    # device, so the weights also load on a CPU-only machine.
    model = TrackCaloRegressor()
    model.load_state_dict(
        torch.load(
            f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt",
            map_location=device,
        )
    )
    model.to(device)
    model.eval()

    # 1D relative-error distribution for this bin: 200 bins over [-1, 1].
    name = f"h_relerr_{pt_min}_{pt_max}"
    title = f"RelErr for {pt_min} < pt < {pt_max};(pred - truth)/truth;Counts"
    hist = ROOT.TH1D(name, title, 200, -1, 1)

    # Truth pt (x) vs relative error (y). nbins_2dy is the x-axis bin count,
    # 10 bins per unit of pt; round() keeps float ranges such as 0.4-0.7 from
    # truncating to one bin too few.
    name2d = f"h2_pt_vs_relerr_{pt_min}_{pt_max}"
    title2d = f"pt vs RelErr for {pt_min} < pt < {pt_max};truth pt;(pred - truth)/truth"
    nbins_2dy = round((pt_max - pt_min) * 10)
    hist2d = ROOT.TH2D(name2d, title2d, int(nbins_2dy), float(pt_min), float(pt_max), 250, -2., 2.)

    with torch.no_grad():
        for x, y in loader:
            # Only the model input needs to be on the device; truth stays on CPU.
            pred = model(x.to(device))

            y_np = y.cpu().numpy().flatten()
            pred_np = pred.cpu().numpy().flatten()

            # A truth value of 0 would give inf/nan relative errors that land in
            # arbitrary under/overflow bins; drop those entries and count them.
            nonzero = y_np != 0
            n_zero_truth += int((~nonzero).sum())
            y_np = y_np[nonzero]
            rel_err = (pred_np[nonzero] - y_np) / y_np

            for yi, rei in zip(y_np, rel_err):
                hist.Fill(rei)
                hist2d.Fill(yi, rei)
                hist_010.Fill(yi, rei)

    # Write this bin's histograms into the output file.
    out_file.cd()
    hist.Write()
    hist2d.Write()

if n_zero_truth:
    print(f"Warning: skipped {n_zero_truth} entries with truth pt == 0")
0068
# Write the combined histogram once, after all pt bins have been processed.
out_file.cd()
hist_010.Write()

# Capture the real output path before closing, so the completion message
# names the file that was actually written.
saved_path = out_file.GetName()
out_file.Close()
# Message: "relative errors for each pt bin have been saved to <path>".
print(f"✅ 每个 pt 区间的相对误差已保存为 {saved_path}")