File indexing completed on 2025-12-16 09:18:05
0001
0002 import torch
0003 from torch.utils.data import Dataset
0004 import numpy as np
0005 import joblib
0006 import ROOT
0007
0008 from data_dphi import TrackCaloDataset as dphiDataset
0009 from data_energy import TrackCaloDataset as energyDataset
0010 from model_dphi import TrackCaloRegressor as dphiRegressor
0011 from model_energy import TrackCaloRegressor as energyRegressor
0012
class FusionDataset(Dataset):
    """Dataset of stacked pT predictions from two pre-trained regressors.

    On construction, the dphi-based and the energy-based ``TrackCaloRegressor``
    models are loaded from saved state dicts. They are run in eval mode and
    without gradients over every sample of their matching
    ``TrackCaloDataset``. Each fusion sample is then::

        X[i] = [pt_pred_dphi, pt_pred_energy]   # float32, shape (2,)
        Y[i] = [truth]                          # float32, shape (1,)

    A downstream fusion model can use these to learn how to combine the two
    estimates. All inference happens eagerly in ``__init__``;
    ``__getitem__`` only indexes the precomputed CPU tensors.
    """

    # Default artifact locations (the paths the original code hard-coded).
    # NOTE: the dphi weights come from ``version4``, while the energy weights
    # and the energy scaler come from ``version2``.
    DEFAULT_DPHI_WEIGHTS = (
        "/mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4/"
        "model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt"
    )
    DEFAULT_ENERGY_WEIGHTS = (
        "/mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/"
        "model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt"
    )
    DEFAULT_ENERGY_SCALER = (
        "/mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/"
        "model_weight/scaler_pt_0.0_10.0_INTT_CaloIwoE.pkl"
    )

    def __init__(self,
                 list_file,
                 pt_min=0.0,
                 pt_max=10.0,
                 device="cuda" if torch.cuda.is_available() else "cpu",
                 *,
                 dphi_weights=None,
                 energy_weights=None,
                 energy_scaler=None,
                 n_debug_samples=2):
        """Build the fusion features by running both regressors.

        Args:
            list_file: Input list, forwarded unchanged to both underlying
                ``TrackCaloDataset`` constructors.
            pt_min: Lower pT bound, forwarded to both datasets.
            pt_max: Upper pT bound, forwarded to both datasets.
            device: Device the regressors run on. The default is evaluated
                once, when the class body is executed.
            dphi_weights: Path to the dphi regressor state dict. ``None``
                means ``DEFAULT_DPHI_WEIGHTS``.
            energy_weights: Path to the energy regressor state dict. ``None``
                means ``DEFAULT_ENERGY_WEIGHTS``.
            energy_scaler: Path to the joblib file holding the scaler that is
                passed to the energy dataset. ``None`` means
                ``DEFAULT_ENERGY_SCALER``.
            n_debug_samples: Number of leading samples whose inputs are
                printed as a sanity check. Use 0 to disable the printout.

        Raises:
            ValueError: If the two underlying datasets differ in length, so
                their samples cannot be paired index-by-index.
        """
        self.device = device
        self.pt_min = pt_min
        self.pt_max = pt_max

        if dphi_weights is None:
            dphi_weights = self.DEFAULT_DPHI_WEIGHTS
        if energy_weights is None:
            energy_weights = self.DEFAULT_ENERGY_WEIGHTS
        if energy_scaler is None:
            energy_scaler = self.DEFAULT_ENERGY_SCALER

        model_dphi = self._load_regressor(dphiRegressor, dphi_weights, device)
        model_energy = self._load_regressor(energyRegressor, energy_weights,
                                            device)

        # joblib.load unpickles the file; only point it at trusted artifacts.
        scaler_energy = joblib.load(energy_scaler)

        ds_dphi = dphiDataset(list_file, pt_min=pt_min, pt_max=pt_max)
        ds_energy = energyDataset(list_file, scaler=scaler_energy,
                                  pt_min=pt_min, pt_max=pt_max)

        # Samples are paired purely by index. Equal lengths are required but,
        # NOTE(review), this does not prove the two datasets yield the same
        # events in the same order; that relies on both readers iterating
        # list_file identically.
        n_samples = len(ds_dphi)
        if len(ds_energy) != n_samples:
            raise ValueError(
                f"dphi dataset has {n_samples} samples but energy dataset "
                f"has {len(ds_energy)}; samples cannot be paired by index")

        X_fusion = []
        Y_fusion = []

        with torch.no_grad():
            for i in range(n_samples):
                x1, y = ds_dphi[i]
                x2, _ = ds_energy[i]
                # The truth target is taken from the dphi dataset only.
                truth = y.item()

                if i < n_debug_samples:
                    print(f"[Sample {i}] x1={x1.tolist()}, x2={x2.tolist()}, truth={truth:.4f}")

                # Add a batch dimension of 1 for the forward pass; .item()
                # expects a single-element output.
                pt_pred1 = model_dphi(x1.unsqueeze(0).to(device)).cpu().item()
                pt_pred2 = model_energy(x2.unsqueeze(0).to(device)).cpu().item()

                X_fusion.append([pt_pred1, pt_pred2])
                Y_fusion.append(truth)

        # reshape keeps X as (N, 2) and Y as (N, 1) even when N == 0.
        # Without it, an empty list would give a 1-D tensor of shape (0,).
        self.X = torch.tensor(X_fusion, dtype=torch.float32).reshape(-1, 2)
        self.Y = torch.tensor(Y_fusion, dtype=torch.float32).reshape(-1, 1)

    @staticmethod
    def _load_regressor(model_cls, weights_path, device):
        """Instantiate ``model_cls`` on ``device``, load weights, set eval mode."""
        model = model_cls().to(device)
        # torch.load unpickles the checkpoint; only load trusted files.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.eval()
        return model

    def __len__(self):
        """Return the number of fusion samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(X[idx], Y[idx])``: the prediction pair and the truth target."""
        return self.X[idx], self.Y[idx]