File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import ROOT
0003 import numpy as np
0004 import joblib
0005 from model import TrackCaloRegressor
0006 from data import TrackCaloDataset
0007
0008
# Scaler handed to the dataset; presumably the one fitted on the training
# inputs (TODO confirm), so test features are transformed the same way.
# NOTE(review): joblib.load unpickles the file -- only load a trusted scaler.pkl.
scaler = joblib.load("scaler.pkl")

dataset = TrackCaloDataset("../test50k.list", scaler=scaler)
loader = torch.utils.data.DataLoader(dataset, batch_size=128)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TrackCaloRegressor()
# map_location remaps every stored tensor onto the device selected above.
# Without it, a checkpoint saved from GPU tensors cannot be loaded on a
# CPU-only machine, defeating the "cpu" fallback.
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.to(device)
model.eval()  # inference mode (disables dropout / uses running batch-norm stats)
0021
0022
# Distribution of the relative pt error, 400 bins over [-2, 2].
h1 = ROOT.TH1D(
    "h_relative_error",
    "Relative Error: (pred - truth)/truth;Rel Error;Counts",
    400,
    -2,
    2,
)

with torch.no_grad():
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        # The model regresses log(pt); exponentiate to recover pt.
        pred_log = model(x)
        pred = torch.exp(pred_log)
        # Flatten both tensors BEFORE combining them. If the model emits
        # (B, 1) while the target is (B,), (pred - y) would broadcast to
        # (B, B) and fill B*B meaningless entries per batch. When the shapes
        # already agree this changes nothing.
        pred = pred.reshape(-1)
        y = y.reshape(-1)
        rel_err = ((pred - y) / y).cpu().numpy()
        for val in rel_err:
            h1.Fill(val)
0035
0036
# Gaussian fit restricted to [-0.05, 0.10], a small window of the [-2, 2]
# histogram range -- presumably to fit only the core and keep the tails
# out (TODO confirm the intended window).
h1.Fit("gaus", "", "", -0.05, 0.10)
# May be null (falsy in PyROOT) if no "gaus" function ended up attached to h1.
fit_func = h1.GetFunction("gaus")

# Single source of truth for the output path, so the final message cannot
# drift from the file that is actually written.
out_path = "pt_relative_error_logreco.root"
out_file = ROOT.TFile(out_path, "RECREATE")
if not out_file or out_file.IsZombie():
    raise OSError(f"Could not create ROOT output file {out_path!r}")
h1.Write()
if fit_func:
    fit_func.Write()
out_file.Close()

print(f"✅ log(pt) 模型测试完成,已写入 {out_path}")