File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import ROOT
0003 import numpy as np
0004 import joblib
0005 from model import TrackCaloRegressor
0006 from data import TrackCaloDataset
0007
0008
# --- Test data ---------------------------------------------------------------
# Reload the scaler stored next to the model weights and hand it to the dataset
# (presumably the feature scaling fitted during training -- TODO confirm).
# NOTE(review): joblib.load unpickles arbitrary objects; only load trusted files.
scaler = joblib.load("model_weight/scaler.pkl")

dataset = TrackCaloDataset("../test500k.list", scaler=scaler)
# Default DataLoader settings otherwise: no shuffling, batches follow dataset order.
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=128)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
0015
0016
# --- Model -------------------------------------------------------------------
# Rebuild the network and load the trained weights. map_location remaps tensors
# that were saved from a GPU onto the device selected above; without it a
# CUDA-saved checkpoint cannot be loaded on a CPU-only machine, which defeats
# the CPU fallback in the device selection.
model = TrackCaloRegressor()
model.load_state_dict(
    torch.load("model_weight/best_model.pt", map_location=device)
)
model.to(device)
# Evaluation mode: changes the behaviour of dropout / batch-norm layers, if any.
model.eval()
0021
0022
# --- Histogram booking -------------------------------------------------------
# One relative-error histogram per truth-pt interval, keyed by (ptmin, ptmax):
# 200 bins spanning [-1, 1] of (pred - truth)/truth.
pt_bins = [(0, 3), (3, 6), (6, 10), (10, 15)]
hists = {
    (lo, hi): ROOT.TH1D(
        f"h_relerr_{lo}_{hi}",
        f"RelErr for {lo} < pt < {hi};(pred - truth)/truth;Counts",
        200,
        -1,
        1,
    )
    for lo, hi in pt_bins
}
0029
0030
# --- Inference: fill the relative-error histograms ---------------------------
with torch.no_grad():
    for x, y in loader:
        # Only the inputs must live on the model's device. The targets are
        # consumed on the CPU, so copying them to the GPU and back is wasted work.
        pred = model(x.to(device))

        y_np = y.cpu().numpy().flatten()
        pred_np = pred.cpu().numpy().flatten()
        # Relative error (pred - truth)/truth; the 1e-6 guards against a zero
        # truth value causing a division by zero.
        rel_err = (pred_np - y_np) / (y_np + 1e-6)

        # Assign each entry to the first interval with ptmin <= truth < ptmax,
        # using one boolean mask per interval instead of a per-entry Python
        # search. `unassigned` keeps first-match semantics even if the
        # intervals were ever made to overlap; entries (including NaN truth)
        # outside every interval are dropped.
        unassigned = np.ones(y_np.shape, dtype=bool)
        for (ptmin, ptmax) in pt_bins:
            in_bin = unassigned & (y_np >= ptmin) & (y_np < ptmax)
            unassigned &= ~in_bin
            fill = hists[(ptmin, ptmax)].Fill
            # tolist() yields Python floats holding the same double values.
            for value in rel_err[in_bin].tolist():
                fill(value)
0046
0047
# --- Fit and save ------------------------------------------------------------
import os

output_path = "outputFile/pt_relative_error_bybin.root"
# TFile does not create missing directories. Without this, RECREATE fails
# (zombie file), every Write() below is lost, and the script still reports success.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

out_file = ROOT.TFile(output_path, "RECREATE")
if out_file.IsZombie():
    raise OSError(f"could not open {output_path} for writing")

# Gaussian fit restricted to the core of the distribution.
fit_lo, fit_hi = -0.1, 0.1
try:
    out_file.cd()  # make sure Write() targets this file
    for (ptmin, ptmax), hist in hists.items():
        # Skip the fit when the window holds no entries; ROOT would only
        # report an error and attach no function in that case.
        if hist.Integral(hist.FindBin(fit_lo), hist.FindBin(fit_hi)) > 0:
            hist.Fit("gaus", "", "", fit_lo, fit_hi)
        hist.Write()
finally:
    out_file.Close()

print("✅ 分 pt 区间的相对误差直方图已保存到 pt_relative_error_bybin.root")