Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import torch
0002 import ROOT
0003 import numpy as np
0004 import joblib
0005 from model import TrackCaloRegressor
0006 from data import TrackCaloDataset
0007 from test_func import visualize_model
0008 
import os  # used only to create the output directory below

# Run inference on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# pt ranges (GeV) to evaluate; each range loads its own trained model file.
# pt_bins = [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]
pt_bins = [(0, 10)]

# ROOT does not create missing directories: without this, TFile would come
# back as a zombie and none of the histograms would be saved.
out_path = "outputFile/pt_relative_error_INTT_CaloI.root"
out_dir = os.path.dirname(out_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
out_file = ROOT.TFile(out_path, "RECREATE")
if out_file.IsZombie():
    raise OSError(f"could not create ROOT output file {out_path!r}")

# Combined truth-pt vs relative-error histogram, filled across all pt bins.
# Its x range covers the union of pt_bins at 10 bins per GeV (the same
# density as the per-bin 2D histograms); for pt_bins = [(0, 10)] this is
# 100 bins on [0, 10]. The defaults reproduce that range if pt_bins is empty.
pt_lo = min((lo for lo, _ in pt_bins), default=0)
pt_hi = max((hi for _, hi in pt_bins), default=10)
name_010 = "h2_pt_vs_relerr_all"
title_010 = f"pt vs RelErr for {pt_lo} < pt < {pt_hi};truth pt;(pred - truth)/truth"
hist_010 = ROOT.TH2D(
    name_010,
    title_010,
    max(1, round((pt_hi - pt_lo) * 10)),
    float(pt_lo),
    float(pt_hi),
    250,
    -2.,
    2.,
)
0018 
# For each pt range: load the test data and the matching trained model, run
# inference, and histogram the relative error (pred - truth) / truth of the
# regressed quantity against the truth value.
for pt_min, pt_max in pt_bins:
    print(f"\n=== Testing model for pt in [{pt_min}, {pt_max}) GeV ===")

    # Load the test data restricted to this pt range (the pt range must be
    # passed to the dataset). No feature scaler is applied (scaler=None); the
    # commented-out lines show how a saved scaler would be loaded instead.
    # scaler = joblib.load(f"model_weight/scaler_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pkl")
    # scaler = joblib.load(f"model_weight/scaler_pt_0.0_10.0_INTT_CaloI.pkl")
    dataset = TrackCaloDataset("../test500k.list", scaler=None, pt_min=pt_min, pt_max=pt_max)
    loader = torch.utils.data.DataLoader(dataset, batch_size=256)

    # Load the model weights for this pt range. map_location lets a
    # checkpoint saved from a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the installed torch has it).
    model = TrackCaloRegressor()
    model.load_state_dict(
        torch.load(
            f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt",
            map_location=device,
        )
    )
    # model.load_state_dict(torch.load(f"model_weight/best_model_pt_0.0_10.0_INTT_CaloI.pt"))
    model.to(device)
    model.eval()

    # visualize_model(model, device, x_min=0.01, x_max=0.5, num_points=500)

    # Make the output file the current ROOT directory before creating the
    # histograms so they are attached to it.
    # NOTE(review): this matters only if TrackCaloDataset opens ROOT files
    # and changes gDirectory; confirm against data.py.
    out_file.cd()

    # 1D relative-error distribution for this pt range.
    name = f"h_relerr_{pt_min}_{pt_max}"
    title = f"RelErr for {pt_min} < pt < {pt_max};(pred - truth)/truth;Counts"
    hist = ROOT.TH1D(name, title, 200, -1, 1)

    # 2D histogram: truth pt on x (10 bins per GeV), relative error on y.
    name2d = f"h2_pt_vs_relerr_{pt_min}_{pt_max}"
    title2d = f"pt vs RelErr for {pt_min} < pt < {pt_max};truth pt;(pred - truth)/truth"
    # round() rather than int(): with float edges such as (0.5, 1.7) the
    # product is 11.999..., which int() would truncate to 11.
    nbins_pt = max(1, round((pt_max - pt_min) * 10))
    hist2d = ROOT.TH2D(name2d, title2d, nbins_pt, float(pt_min), float(pt_max), 250, -2., 2.)

    # Run inference and fill the histograms.
    n_skipped = 0
    with torch.no_grad():
        for x, y in loader:
            pred = model(x.to(device))

            # The targets are only needed as numpy arrays, so they never
            # need to visit the device.
            y_np = y.cpu().numpy().flatten()
            pred_np = pred.cpu().numpy().flatten()
            with np.errstate(divide="ignore", invalid="ignore"):
                rel_err = (pred_np - y_np) / y_np

            # A zero truth value (or a non-finite prediction) gives an
            # undefined relative error; skip those entries instead of
            # dumping inf/nan into the overflow bins.
            finite = np.isfinite(rel_err)
            n_skipped += int(np.count_nonzero(~finite))

            for yi, rei in zip(y_np[finite], rel_err[finite]):
                hist.Fill(rei)
                hist2d.Fill(yi, rei)
                hist_010.Fill(yi, rei)

    if n_skipped:
        print(f"    skipped {n_skipped} entries with non-finite relative error (e.g. zero truth pt)")

    # hist.Fit("gaus", "", "", -0.1, 0.1)
    out_file.cd()
    hist.Write()
    hist2d.Write()
0068 
# Write the combined (all-pt-bins) histogram and close the output file.
out_file.cd()
hist_010.Write()

# Report the file that was actually written. The name is read from the TFile
# before closing it, so the message cannot drift from the real output path.
written_path = out_file.GetName()
out_file.Close()
print(f"✅ 每个 pt 区间的相对误差已保存为 {written_path}")