Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

import os

import torch
import ROOT
import numpy as np
import joblib

from model import TrackCaloRegressor
from data import TrackCaloDataset
0007 
# ==== Load the scaler ====
scaler = joblib.load("model_weight/scaler.pkl")

# ==== Load the test data ====
dataset = TrackCaloDataset("../test500k.list", scaler=scaler)
loader = torch.utils.data.DataLoader(dataset, batch_size=128)
# Prefer the GPU when one is available, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ==== Load the model ====
model = TrackCaloRegressor()
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only host.
model.load_state_dict(
    torch.load("model_weight/best_model.pt", map_location=device)
)
model.to(device)
model.eval()  # evaluation mode: the model is only used for inference below
0021 
# ==== Define the pt intervals and book one histogram per interval ====
pt_bins = [(0, 3), (3, 6), (6, 10), (10, 15)]
# Keyed by the (low, high) edge pair; each histogram has 200 bins on [-1, 1].
hists = {
    (lo, hi): ROOT.TH1D(
        f"h_relerr_{lo}_{hi}",
        f"RelErr for {lo} < pt < {hi};(pred - truth)/truth;Counts",
        200,
        -1,
        1,
    )
    for (lo, hi) in pt_bins
}
0029 
# ==== Run inference and fill the histogram of the matching pt interval ====
with torch.no_grad():
    for features, target in loader:
        prediction = model(features.to(device))
        target = target.to(device)

        truth = target.cpu().numpy().flatten()
        predicted = prediction.cpu().numpy().flatten()
        # The small offset keeps the denominator away from exactly zero.
        rel_err = (predicted - truth) / (truth + 1e-6)

        for truth_i, rel_err_i in zip(truth, rel_err):
            # Intervals are half-open [low, high); take the first one that
            # contains the truth value. Values outside every interval are
            # not histogrammed.
            key = next(
                (edges for edges in pt_bins if edges[0] <= truth_i < edges[1]),
                None,
            )
            if key is not None:
                hists[key].Fill(rel_err_i)
0046 
# ==== Fit & save ====
# TFile does not create missing directories, so make sure the output folder
# exists before opening the file.
os.makedirs("outputFile", exist_ok=True)
out_file = ROOT.TFile("outputFile/pt_relative_error_bybin.root", "RECREATE")
# A file that could not be opened is flagged as a "zombie"; stop here instead
# of reporting success without having written anything.
if out_file.IsZombie():
    raise OSError(
        "Cannot open outputFile/pt_relative_error_bybin.root for writing"
    )
for hist in hists.values():
    # An empty histogram gives the Gaussian fit nothing to constrain; skip the
    # fit in that case but still store the (empty) histogram.
    if hist.GetEntries() > 0:
        # Fit a Gaussian restricted to the range [-0.1, 0.1].
        hist.Fit("gaus", "", "", -0.1, 0.1)
    hist.Write()
out_file.Close()

print("✅ 分 pt 区间的相对误差直方图已保存到 pt_relative_error_bybin.root")