Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

# Evaluate TrackCaloRegressor per truth-pt bin and store the relative error
# (pred - truth) / truth in ROOT histograms (one output .root file).
import torch
import ROOT
import numpy as np
import joblib
from model import TrackCaloRegressor
from data import TrackCaloDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Truth-pt intervals (GeV) evaluated one at a time; each interval has its own
# scaler and model checkpoint under model_weight/.
# pt_bins = [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]
pt_bins = [(0, 10)]

out_file = ROOT.TFile("outputFile/pt_relative_error_INTT_CaloI_ontrain.root", "RECREATE")
# TFile does not raise when it cannot create the file (e.g. outputFile/ is
# missing); it returns a "zombie" object. Fail fast instead of running the
# whole inference and then losing every histogram on Write().
if out_file.IsZombie():
    raise OSError(f"could not create output file {out_file.GetName()!r}")

# Combined truth-pt vs relative-error histogram, filled across all pt bins.
# Its x axis spans the union of pt_bins with 10 bins per GeV, which gives
# 100 bins on [0, 10) for pt_bins = [(0, 10)].
pt_lo_all = min(lo for lo, _ in pt_bins)
pt_hi_all = max(hi for _, hi in pt_bins)
name_010 = "h2_pt_vs_relerr_all"
title_010 = f"pt vs RelErr for {pt_lo_all} < pt < {pt_hi_all};truth pt;(pred - truth)/truth"
hist_010 = ROOT.TH2D(
    name_010,
    title_010,
    round((pt_hi_all - pt_lo_all) * 10),
    float(pt_lo_all),
    float(pt_hi_all),
    250,
    -2.,
    2.,
)

for pt_min, pt_max in pt_bins:
    print(f"\n=== Testing model for pt in [{pt_min}, {pt_max}) GeV ===")

    # Load this bin's scaler and dataset. The pt range must be passed to the
    # dataset as well (presumably it selects tracks in [pt_min, pt_max) and
    # must match the range the scaler/model were built for — confirm in data.py).
    # NOTE: joblib.load unpickles; only load trusted files.
    scaler = joblib.load(f"model_weight/scaler_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pkl")
    # scaler = joblib.load(f"model_weight/scaler_pt_0.0_10.0_INTT_CaloI.pkl")
    dataset = TrackCaloDataset("../train500k.list", scaler=scaler, pt_min=pt_min, pt_max=pt_max)
    loader = torch.utils.data.DataLoader(dataset, batch_size=256)

    # Load the trained model. map_location maps the stored tensors onto the
    # selected device, so a checkpoint written from a GPU run can still be
    # evaluated when `device` fell back to "cpu".
    model = TrackCaloRegressor()
    state_dict = torch.load(
        f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    # model.load_state_dict(torch.load(f"model_weight/best_model_pt_0.0_10.0_INTT_CaloI.pt"))
    model.to(device)
    model.eval()

    # 1D relative-error histogram for this pt bin.
    name = f"h_relerr_{pt_min}_{pt_max}"
    title = f"RelErr for {pt_min} < pt < {pt_max};(pred - truth)/truth;Counts"
    hist = ROOT.TH1D(name, title, 200, -1, 1)

    # 2D truth-pt vs relative-error histogram, 10 x-bins per unit of pt.
    name2d = f"h2_pt_vs_relerr_{pt_min}_{pt_max}"
    title2d = f"pt vs RelErr for {pt_min} < pt < {pt_max};truth pt;(pred - truth)/truth"
    nbins_x = round((pt_max - pt_min) * 10)
    hist2d = ROOT.TH2D(name2d, title2d, nbins_x, float(pt_min), float(pt_max), 250, -2., 2.)

    # Run inference and fill the histograms.
    with torch.no_grad():
        for x, y in loader:
            pred = model(x.to(device))

            # Targets are only used on the host for filling, so they are not
            # copied to the device.
            y_np = y.cpu().numpy().flatten()
            pred_np = pred.cpu().numpy().flatten()
            # NOTE(review): a truth value of exactly 0 makes this inf/nan —
            # confirm the dataset never yields pt == 0 for pt_min == 0.
            rel_err = (pred_np - y_np) / y_np

            for yi, rei in zip(y_np, rel_err):
                hist.Fill(rei)
                hist2d.Fill(yi, rei)
                hist_010.Fill(yi, rei)

    # Gaussian fit restricted to the core region [-0.1, 0.1].
    hist.Fit("gaus", "", "", -0.1, 0.1)
    hist.Write()
    hist2d.Write()

# Write the combined histogram, close the output file, and report the path
# the file was actually opened with (taken from the TFile itself so the
# message cannot drift from the real output name).
hist_010.Write()

out_name = out_file.GetName()
out_file.Close()
print(f"✅ 每个 pt 区间的相对误差已保存为 {out_name}")