Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import torch
0002 import ROOT
0003 import numpy as np
0004 import joblib
0005 from model import TrackCaloRegressor
0006 from data import TrackCaloDataset
0007 
# ==== Test-time setup: device, feature scaler, test data ====
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the persisted scaler (presumably the one fitted at training time —
# confirm) and hand it to the dataset. NOTE: joblib.load unpickles the file,
# so only point it at trusted artifacts.
scaler = joblib.load("scaler.pkl")

# Dataset built from the ../test50k.list input list, served in batches of 128.
dataset = TrackCaloDataset("../test50k.list", scaler=scaler)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=128)
0015 
# ==== Load the trained model ====
model = TrackCaloRegressor()
# map_location remaps the checkpoint's tensors onto `device`. Without it,
# torch.load restores tensors onto the device they were saved from. A
# checkpoint written on a GPU would then fail to load on the CPU fallback
# path chosen above.
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.to(device)
# Switch layers such as dropout / batch-norm (if the model has any) to
# inference behaviour.
model.eval()
0021 
# ==== Book the relative-error histogram ====
# 400 uniform bins covering [-2, 2] for (pred - truth) / truth.
h1 = ROOT.TH1D(
    "h_relative_error",
    "Relative Error: (pred - truth)/truth;Rel Error;Counts",
    400,
    -2,
    2,
)
0024 
# ==== Inference: fill the per-track relative pt error ====
with torch.no_grad():
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        # The model output is exponentiated to recover pt.
        # NOTE(review): this assumes the model predicts log(pt) and that `y`
        # holds pt itself (not log(pt)). Confirm against TrackCaloDataset.
        pred = torch.exp(model(x))
        # Flatten both sides before subtracting. If the model returns (B, 1)
        # while the target is (B,), `pred - y` would silently broadcast to a
        # (B, B) matrix and fill B*B meaningless entries.
        pred = pred.reshape(-1)
        truth = y.reshape(-1)
        if pred.numel() != truth.numel():
            raise ValueError(
                f"prediction/target size mismatch: "
                f"{pred.numel()} vs {truth.numel()}"
            )
        # NOTE(review): a zero truth pt yields inf/nan here; consider masking
        # such entries explicitly if they can occur.
        rel_err = ((pred - truth) / truth).cpu().numpy()
        # tolist() hands plain Python floats to TH1D.Fill.
        for val in rel_err.tolist():
            h1.Fill(val)
0035 
# ==== Gaussian fit of the core region, then write the results ====
# Fit ROOT's built-in "gaus" only over [-0.05, 0.10] (the passed range).
fit_lo, fit_hi = -0.05, 0.10
h1.Fit("gaus", "", "", fit_lo, fit_hi)
# PyROOT returns a falsy null proxy (not None) when no function is attached,
# so the truthiness test below is deliberate.
fit_func = h1.GetFunction("gaus")

out_file = ROOT.TFile("pt_relative_error_logreco.root", "RECREATE")
h1.Write()
# Also store the fitted function as its own key in the output file.
if fit_func:
    fit_func.Write()
out_file.Close()
0045 
# Report completion. The message must name the file actually written by
# the save step (pt_relative_error_logreco.root); it previously named
# pt_relative_error.root.
print("✅ log(pt) 模型测试完成,已写入 pt_relative_error_logreco.root")