File indexing completed on 2025-12-16 09:18:05
0001 import torch
0002 import ROOT
0003 import numpy as np
0004
0005 from model_combined import FusionRegressor
0006 from data_combined import FusionDataset
0007
def test_fusion(
    list_file,
    model_path="model_weight/best_model_combined.pt",
    output_path="outputFile/pt_relative_error_combined.root",
    pt_bins=((0.0, 10.0),),
    batch_size=512,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Evaluate a trained FusionRegressor and histogram its relative pt error.

    For every ``(pt_min, pt_max)`` pair in *pt_bins*, the dataset is loaded
    with that pt selection and the model is run over it under
    ``torch.no_grad()``. The per-event relative error
    ``(pred - truth) / truth`` is filled into two histograms, which are
    written to *output_path*:

    * ``h_relerr_<pt_min>_<pt_max>``: 1D relative-error distribution.
    * ``h2_pt_vs_relerr_<pt_min>_<pt_max>``: truth pt vs relative error.

    Events whose truth value is exactly zero are skipped, because the
    relative error is undefined there. A count of skipped events is printed.

    Args:
        list_file: Input list passed unchanged to ``FusionDataset``.
        model_path: Path to a saved ``state_dict`` for
            ``FusionRegressor(hidden_dim=256)``.
        output_path: ROOT file to create. It is overwritten if it exists,
            and its parent directory is created if missing.
        pt_bins: Iterable of ``(pt_min, pt_max)`` pairs. Each pair gets its
            own dataset selection and its own histograms.
        batch_size: DataLoader batch size.
        device: Torch device string used for inference.

    Raises:
        OSError: If the output ROOT file cannot be opened for writing.
    """
    import os  # local import: only needed to create the output directory

    print(f"Using device: {device}")

    # The model does not depend on the pt range, so build it and load its
    # weights once instead of once per bin.
    model = FusionRegressor(hidden_dim=256).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_file = ROOT.TFile(output_path, "RECREATE")
    # TFile reports open failures by returning a "zombie" object rather than
    # raising. Without this check, every later Write() would fail silently.
    if out_file.IsZombie():
        raise OSError(f"Could not open ROOT output file: {output_path}")

    try:
        for pt_min, pt_max in pt_bins:
            print(f"\n=== Testing FusionRegressor for pt in [{pt_min}, {pt_max}) GeV ===")

            dataset = FusionDataset(list_file, pt_min=pt_min, pt_max=pt_max)
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

            # Book the histograms inside the output file explicitly. If the
            # dataset opens ROOT files, the current directory may have moved.
            out_file.cd()

            # 10 bins per unit of pt. Rounding, rather than truncating, avoids
            # losing a bin to float error (e.g. (2.3 - 2.0) * 10 == 2.999...).
            # Use at least one bin so a degenerate range is still valid.
            nbins_pt = max(1, int(round((pt_max - pt_min) * 10)))
            hist2d = ROOT.TH2D(
                f"h2_pt_vs_relerr_{pt_min}_{pt_max}",
                "pt vs RelErr Combined;truth pt;(pred - truth)/truth",
                nbins_pt, pt_min, pt_max,
                250, -2.0, 2.0,
            )
            hist1d = ROOT.TH1D(
                f"h_relerr_{pt_min}_{pt_max}",
                "RelErr Combined;(pred - truth)/truth;Counts",
                200, -1.0, 1.0,
            )

            n_skipped = 0
            with torch.no_grad():
                for xb, yb in loader:
                    xb, yb = xb.to(device), yb.to(device)
                    pred = model(xb)

                    y_np = yb.cpu().numpy().ravel()
                    pred_np = pred.cpu().numpy().ravel()

                    # Relative error is undefined for zero truth. Skip those
                    # events instead of filling +/-inf or NaN.
                    valid = y_np != 0
                    n_skipped += int(np.count_nonzero(~valid))
                    truth = y_np[valid]
                    rel_err = (pred_np[valid] - truth) / truth

                    # tolist() converts to plain Python floats for ROOT's Fill.
                    for truth_pt, rei in zip(truth.tolist(), rel_err.tolist()):
                        hist1d.Fill(rei)
                        hist2d.Fill(truth_pt, rei)

            if n_skipped:
                print(f"Skipped {n_skipped} events with truth == 0 (relative error undefined)")

            out_file.cd()
            hist1d.Write()
            hist2d.Write()
    finally:
        # Always flush and close the file, even if evaluation fails part-way.
        out_file.Close()

    print(f"✅ Test complete. Results saved to {output_path}")
0072
0073
if __name__ == "__main__":
    # Script entry point: run the evaluation once with fixed input/output
    # paths and a single 0-10 pt range.
    run_config = {
        "list_file": "../test500k.list",
        "model_path": "model_weight/best_model_combined.pt",
        "output_path": "outputFile/pt_relative_error_combined.root",
        "pt_bins": [(0.0, 10.0)],
        "batch_size": 512,
    }
    test_fusion(**run_config)