Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-17 09:18:16

0001 import torch
0002 import ROOT
0003 import numpy as np
0004 from model_pointnet import PointNetRegressor  # 你的 PointNet 模型
0005 from data_pointnet import TrackPointNetDataset  # 你的 Dataset
0006 
# Evaluate the trained PointNet pt regressor(s) on the test sample and store
# relative-error histograms, (pred - truth) / truth, in a single ROOT file.

device = "cuda" if torch.cuda.is_available() else "cpu"
pt_bins = [(0, 10)]  # (pt_min, pt_max) intervals in GeV; extend to evaluate several models

out_file = ROOT.TFile("outputFile/pt_relative_error_INTT_CaloI_pointnet.root", "RECREATE")
# Fail before running inference if the output file could not be created
# (e.g. the outputFile/ directory does not exist); otherwise results are lost.
if not out_file or out_file.IsZombie():
    raise OSError("cannot create outputFile/pt_relative_error_INTT_CaloI_pointnet.root")

try:
    # Summary histogram filled from every pt interval. Its title and x-range
    # are fixed to 0 < pt < 10, matching the default pt_bins above.
    hist_010 = ROOT.TH2D("h2_pt_vs_relerr_all", "pt vs RelErr for 0 < pt < 10;truth pt;(pred - truth)/truth",
                         100, 0, 10, 250, -2., 2.)

    for pt_min, pt_max in pt_bins:
        print(f"\n=== Testing PointNet for pt in [{pt_min}, {pt_max}) GeV ===")

        dataset = TrackPointNetDataset("../test50k.list", pt_min=pt_min, pt_max=pt_max)
        loader = torch.utils.data.DataLoader(dataset, batch_size=256)

        # Load the PointNet checkpoint trained for this interval. map_location lets
        # a checkpoint saved from a GPU be loaded when only the CPU is available.
        model = PointNetRegressor()
        model.load_state_dict(
            torch.load(f"model_weight/best_pointnet_pt_{pt_min:.1f}_{pt_max:.1f}.pt",
                       map_location=device)
        )
        model.to(device)
        model.eval()

        # 1D relative-error distribution for this interval.
        hist = ROOT.TH1D(f"h_relerr_{pt_min}_{pt_max}",
                         f"RelErr for {pt_min} < pt < {pt_max};(pred - truth)/truth;Counts",
                         200, -1, 1)

        # Truth pt (x axis, 10 bins per GeV) vs relative error (y axis).
        # TH2D requires an integer bin count, so round in case the edges are floats.
        nbins_pt = max(1, int(round((pt_max - pt_min) * 10)))
        hist2d = ROOT.TH2D(f"h2_pt_vs_relerr_{pt_min}_{pt_max}",
                           f"pt vs RelErr for {pt_min} < pt < {pt_max};truth pt;(pred - truth)/truth",
                           nbins_pt, pt_min, pt_max, 250, -2., 2.)

        # Inference + histogram filling.
        with torch.no_grad():
            for x, y in loader:
                pred = model(x.to(device))

                # Targets are only needed on the host, so they are not sent to the device.
                y_np = y.cpu().numpy().flatten()
                pred_np = pred.cpu().numpy().flatten()

                # The relative error is undefined for a zero truth pt; drop such
                # entries instead of filling inf/NaN into the histograms.
                valid = y_np != 0
                y_np = y_np[valid]
                pred_np = pred_np[valid]
                rel_err = (pred_np - y_np) / y_np

                for yi, rei in zip(y_np, rel_err):
                    hist.Fill(rei)
                    hist2d.Fill(yi, rei)
                    hist_010.Fill(yi, rei)

        # Gaussian fit of the core of the relative-error distribution.
        hist.Fit("gaus", "", "", -0.1, 0.1)
        hist.Write()
        hist2d.Write()

    hist_010.Write()
finally:
    # Always close the output file, even if inference fails part-way.
    out_file.Close()

print("✅ 每个 pt 区间的 PointNet 相对误差已保存为 pt_relative_error_INTT_CaloI_pointnet.root")