Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:05

0001 import torch
0002 import ROOT
0003 import numpy as np
0004 
0005 from model_combined import FusionRegressor
0006 from data_combined import FusionDataset
0007 
def test_fusion(
    list_file,
    model_path="model_weight/best_model_combined.pt",
    output_path="outputFile/pt_relative_error_combined.root",
    pt_bins=((0.0, 10.0),),
    batch_size=512,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Evaluate a trained FusionRegressor and histogram its relative pt error.

    For every ``(pt_min, pt_max)`` pair in ``pt_bins`` a ``FusionDataset``
    restricted to that range is run through the model. The relative error
    ``(pred - truth) / truth`` is filled into two histograms:

    * ``h_relerr_{pt_min}_{pt_max}``: 1D, 200 bins on [-1, 1];
    * ``h2_pt_vs_relerr_{pt_min}_{pt_max}``: truth pt (10 bins per GeV over
      the range) vs. relative error (250 bins on [-2, 2]).

    All histograms go into a single ROOT file.

    Args:
        list_file: Input list forwarded unchanged to ``FusionDataset``.
        model_path: Saved ``state_dict`` for ``FusionRegressor(hidden_dim=256)``.
        output_path: ROOT file to create. It is overwritten if it exists, and
            missing parent directories are created.
        pt_bins: Iterable of ``(pt_min, pt_max)`` pairs. Each pair is passed to
            ``FusionDataset`` as its pt selection and sets the 2D x-axis range.
        batch_size: DataLoader batch size used for inference.
        device: Torch device for inference; defaults to CUDA when available.

    Raises:
        OSError: If ``output_path`` cannot be opened for writing.
    """
    import os  # function-local so this block stays self-contained

    print(f"Using device: {device}")

    # The model does not depend on the pt range, so load it once. Doing it
    # before opening the output with RECREATE also means a bad checkpoint
    # path does not wipe out a previous results file.
    model = _load_fusion_model(model_path, device)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_file = ROOT.TFile(output_path, "RECREATE")
    # TFile does not raise on failure: it yields a "zombie" object, and every
    # later Write() fails, so check explicitly.
    if out_file.IsZombie():
        raise OSError(f"Cannot open ROOT output file: {output_path}")

    try:
        for pt_min, pt_max in pt_bins:
            print(f"\n=== Testing FusionRegressor for pt in [{pt_min}, {pt_max}) GeV ===")
            dataset = FusionDataset(list_file, pt_min=pt_min, pt_max=pt_max)
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

            # Histograms attach to ROOT's current directory when created. Make
            # it the output file in case something (possibly the dataset --
            # not visible here) changed gDirectory in the meantime.
            out_file.cd()
            hist1d, hist2d = _book_histograms(pt_min, pt_max)

            n_skipped = _fill_histograms(model, loader, device, hist1d, hist2d)
            if n_skipped:
                print(
                    f"Warning: skipped {n_skipped} entries with a non-finite "
                    f"relative error (truth pt == 0 or NaN)"
                )

            out_file.cd()
            hist1d.Write()
            hist2d.Write()
    finally:
        # Close even if inference fails so the file is not left half-written.
        out_file.Close()

    print(f"✅ Test complete. Results saved to {output_path}")


def _load_fusion_model(model_path, device):
    """Build FusionRegressor(hidden_dim=256) on device, load its weights and switch to eval mode."""
    model = FusionRegressor(hidden_dim=256).to(device)
    # NOTE: torch.load unpickles the checkpoint; only load files you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


def _book_histograms(pt_min, pt_max):
    """Create the 1D relative-error and 2D truth-pt-vs-relative-error histograms for one pt range."""
    # 10 bins per GeV. round() instead of int(), because float error such as
    # (5.0 - 4.9) * 10 == 0.99999... would otherwise truncate to 0 bins.
    nbins_pt = max(1, int(round((pt_max - pt_min) * 10)))
    hist2d = ROOT.TH2D(
        f"h2_pt_vs_relerr_{pt_min}_{pt_max}",
        "pt vs RelErr Combined;truth pt;(pred - truth)/truth",
        nbins_pt, pt_min, pt_max,
        250, -2.0, 2.0,
    )
    hist1d = ROOT.TH1D(
        f"h_relerr_{pt_min}_{pt_max}",
        "RelErr Combined;(pred - truth)/truth;Counts",
        200, -1.0, 1.0,
    )
    return hist1d, hist2d


def _fill_histograms(model, loader, device, hist1d, hist2d):
    """Run inference over loader and fill hist1d/hist2d with the relative error.

    Returns the number of entries skipped because their relative error was
    not finite (truth == 0, or a NaN target/prediction).
    """
    n_skipped = 0
    with torch.no_grad():
        for xb, yb in loader:
            pred = model(xb.to(device))

            # Targets are only needed on the host, so there is no need to
            # send them to the device and back.
            y_np = yb.cpu().numpy().reshape(-1)
            pred_np = pred.cpu().numpy().reshape(-1)

            # Optional (disabled) post-processing for reco pt >= 8.8:
            #     cf = 0.02 + 0.08 * (reco_pt - 8.8)
            #     pred_np[i] = reco_pt * (1.0 + cf)

            # truth == 0 gives inf/NaN (plus a numpy RuntimeWarning). Those
            # entries have no meaningful relative error, so drop and count them.
            with np.errstate(divide="ignore", invalid="ignore"):
                rel_err = (pred_np - y_np) / y_np
            finite = np.isfinite(rel_err)
            n_skipped += int(np.count_nonzero(~finite))

            for truth_pt, rei in zip(y_np[finite], rel_err[finite]):
                hist1d.Fill(float(rei))
                hist2d.Fill(float(truth_pt), float(rei))
    return n_skipped
0072 
0073 
if __name__ == "__main__":
    # Script entry point: evaluate the combined model on ../test500k.list
    # over a single 0-10 GeV pt range, writing one ROOT output file.
    run_config = dict(
        list_file="../test500k.list",
        model_path="model_weight/best_model_combined.pt",
        output_path="outputFile/pt_relative_error_combined.root",
        pt_bins=[(0.0, 10.0)],
        batch_size=512,
    )
    test_fusion(**run_config)