File indexing completed on 2025-12-17 09:18:16
0001 import torch
0002 import ROOT
0003 import numpy as np
0004 from model_pointnet import PointNetRegressor
0005 from data_pointnet import TrackPointNetDataset
0006
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# (pt_min, pt_max) truth-pt intervals in GeV that are evaluated one by one.
pt_bins = [(0, 10)]

# The output file is opened before any histogram is booked; the later
# Write() calls store the histograms in it.
out_file = ROOT.TFile("outputFile/pt_relative_error_INTT_CaloI_pointnet.root", "RECREATE")
if out_file.IsZombie():
    # Fail fast: without this check the whole evaluation would run and the
    # results would silently be written nowhere (e.g. missing outputFile/ dir).
    raise OSError(
        "could not open outputFile/pt_relative_error_INTT_CaloI_pointnet.root for writing"
    )

# Truth pt (x) versus relative error (y), accumulated over every pt interval.
hist_010 = ROOT.TH2D(
    "h2_pt_vs_relerr_all",
    "pt vs RelErr for 0 < pt < 10;truth pt;(pred - truth)/truth",
    100, 0, 10,
    250, -2., 2.,
)
0015
# Evaluate one trained PointNet per truth-pt interval and histogram the
# relative error (pred - truth) / truth of its predictions.
for pt_min, pt_max in pt_bins:
    print(f"\n=== Testing PointNet for pt in [{pt_min}, {pt_max}) GeV ===")

    # Test sample restricted to the current pt interval, read in fixed order.
    dataset = TrackPointNetDataset("../test50k.list", pt_min=pt_min, pt_max=pt_max)
    loader = torch.utils.data.DataLoader(dataset, batch_size=256)

    model = PointNetRegressor()
    weights_path = f"model_weight/best_pointnet_pt_{pt_min:.1f}_{pt_max:.1f}.pt"
    # map_location lets a checkpoint that was saved from a GPU be restored on
    # a CPU-only machine, matching the cuda/cpu fallback used for `device`.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()

    # 1D distribution of the relative error for this interval.
    hist = ROOT.TH1D(
        f"h_relerr_{pt_min}_{pt_max}",
        f"RelErr for {pt_min} < pt < {pt_max};(pred - truth)/truth;Counts",
        200, -1, 1,
    )

    # Ten x-bins per GeV; TH2D needs an integer bin count even when the
    # interval edges are floats, and at least one bin.
    nbins_2dy = max(1, int(round((pt_max - pt_min) * 10)))
    hist2d = ROOT.TH2D(
        f"h2_pt_vs_relerr_{pt_min}_{pt_max}",
        f"pt vs RelErr for {pt_min} < pt < {pt_max};truth pt;(pred - truth)/truth",
        nbins_2dy, pt_min, pt_max,
        250, -2., 2.,
    )

    with torch.no_grad():
        for x, y in loader:
            pred = model(x.to(device))

            # The truth values are only needed on the host side.
            y_np = y.cpu().numpy().flatten()
            pred_np = pred.cpu().numpy().flatten()

            # A truth pt of exactly 0 has no defined relative error; drop
            # those entries instead of filling inf/NaN into the histograms.
            nonzero = y_np != 0
            y_np = y_np[nonzero]
            pred_np = pred_np[nonzero]
            rel_err = (pred_np - y_np) / y_np

            # tolist() hands plain Python floats to ROOT.
            for yi, rei in zip(y_np.tolist(), rel_err.tolist()):
                hist.Fill(rei)
                hist2d.Fill(yi, rei)
                hist_010.Fill(yi, rei)

    # Gaussian fit restricted to the core of the distribution, [-0.1, 0.1].
    hist.Fit("gaus", "", "", -0.1, 0.1)
    hist.Write()
    hist2d.Write()
0059
# Store the histogram accumulated over all pt intervals, then close the
# output file.
hist_010.Write()
out_file.Close()

# User-facing completion message (Chinese): "The PointNet relative error of
# each pt interval has been saved as pt_relative_error_INTT_CaloI_pointnet.root".
summary_message = "✅ 每个 pt 区间的 PointNet 相对误差已保存为 pt_relative_error_INTT_CaloI_pointnet.root"
print(summary_message)