Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:05

0001 ### data_combined.py
0002 import torch
0003 from torch.utils.data import Dataset
0004 import numpy as np
0005 import joblib
0006 import ROOT
0007 
0008 from data_dphi import TrackCaloDataset as dphiDataset
0009 from data_energy import TrackCaloDataset as energyDataset
0010 from model_dphi import TrackCaloRegressor as dphiRegressor
0011 from model_energy import TrackCaloRegressor as energyRegressor
0012 
class FusionDataset(Dataset):
    """Dataset pairing the outputs of two pre-trained regressors with a target.

    On construction the "dphi" and "energy" sub-models are loaded, put in
    evaluation mode, and run once (under ``torch.no_grad``) over every sample
    of their respective datasets built from ``list_file``.  The two scalar
    predictions form one 2-element feature row; the target is the label
    returned by the dphi dataset for the same index.

    Attributes:
        X: ``float32`` tensor of shape ``(N, 2)`` holding
            ``[dphi_prediction, energy_prediction]`` per sample.
        Y: ``float32`` tensor of shape ``(N, 1)`` holding the target values.
    """

    # Default artefact locations (the paths the original code hard-coded).
    DEFAULT_DPHI_WEIGHTS = (
        "/mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4/"
        "model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt"
    )
    DEFAULT_ENERGY_WEIGHTS = (
        "/mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/"
        "model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt"
    )
    DEFAULT_ENERGY_SCALER = (
        "/mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/"
        "model_weight/scaler_pt_0.0_10.0_INTT_CaloIwoE.pkl"
    )

    def __init__(self,
                 list_file,
                 pt_min=0.0,
                 pt_max=10.0,
                 device="cuda" if torch.cuda.is_available() else "cpu",
                 *,
                 dphi_weights_path=None,
                 energy_weights_path=None,
                 energy_scaler_path=None):
        """Run both sub-models over ``list_file`` and cache their outputs.

        Args:
            list_file: Forwarded unchanged to both underlying datasets.
            pt_min: Lower selection bound, forwarded to both datasets.
            pt_max: Upper selection bound, forwarded to both datasets.
            device: Torch device used for the sub-models and their inputs.
                The default is chosen once, when the class is defined.
            dphi_weights_path: State-dict file for the dphi regressor;
                ``None`` selects ``DEFAULT_DPHI_WEIGHTS``.
            energy_weights_path: State-dict file for the energy regressor;
                ``None`` selects ``DEFAULT_ENERGY_WEIGHTS``.
            energy_scaler_path: joblib file holding the scaler handed to the
                energy dataset; ``None`` selects ``DEFAULT_ENERGY_SCALER``.

        Raises:
            ValueError: If the two underlying datasets differ in length, in
                which case samples could not be paired by index.
        """
        super().__init__()
        self.device = device
        self.pt_min = pt_min
        self.pt_max = pt_max

        if dphi_weights_path is None:
            dphi_weights_path = self.DEFAULT_DPHI_WEIGHTS
        if energy_weights_path is None:
            energy_weights_path = self.DEFAULT_ENERGY_WEIGHTS
        if energy_scaler_path is None:
            energy_scaler_path = self.DEFAULT_ENERGY_SCALER

        # Load both sub-models in evaluation mode; they are only used for
        # inference here and are not kept on the instance.
        model_dphi = self._load_eval_model(dphiRegressor, dphi_weights_path, device)
        model_energy = self._load_eval_model(energyRegressor, energy_weights_path, device)

        # Scaler required by the energy dataset.
        # NOTE: joblib.load unpickles the file -- only point it at trusted files.
        scaler_energy = joblib.load(energy_scaler_path)

        # Build the two source datasets from the same file list and selection.
        ds_dphi = dphiDataset(list_file, pt_min=pt_min, pt_max=pt_max)
        ds_energy = energyDataset(list_file, scaler=scaler_energy,
                                  pt_min=pt_min, pt_max=pt_max)

        # Samples are paired by index, so the lengths must agree.  A real
        # exception is used because ``assert`` is stripped under ``python -O``.
        n_samples = len(ds_dphi)
        if n_samples != len(ds_energy):
            raise ValueError(
                "dphi and energy datasets have different lengths: "
                f"{n_samples} != {len(ds_energy)}"
            )

        features = []
        targets = []

        # Inference is done one sample at a time (batch of size 1), exactly
        # as before, so the cached predictions are numerically unchanged.
        with torch.no_grad():
            for i in range(n_samples):
                x1, y = ds_dphi[i]
                x2, _ = ds_energy[i]  # the energy dataset's label is ignored
                truth = y.item()

                # Print the raw features and target of the first two samples
                # as a quick sanity check.
                if i < 2:
                    print(f"[Sample {i}] x1={x1.tolist()}, x2={x2.tolist()}, truth={truth:.4f}")

                # Sub-model predictions (each reduced to a Python float).
                pt_pred1 = model_dphi(x1.unsqueeze(0).to(device)).cpu().item()
                pt_pred2 = model_energy(x2.unsqueeze(0).to(device)).cpu().item()

                features.append([pt_pred1, pt_pred2])
                targets.append(truth)

        self.X, self.Y = self._to_tensors(features, targets)

    @staticmethod
    def _load_eval_model(model_cls, weights_path, device):
        """Instantiate ``model_cls`` on ``device``, load its weights, call ``eval()``."""
        model = model_cls().to(device)
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.eval()
        return model

    @staticmethod
    def _to_tensors(features, targets):
        """Convert the collected rows to ``(N, 2)`` and ``(N, 1)`` float32 tensors.

        The explicit reshape keeps the column dimension even when no sample
        was collected (``N == 0``).
        """
        X = torch.tensor(np.array(features).reshape(-1, 2), dtype=torch.float32)
        Y = torch.tensor(np.array(targets).reshape(-1, 1), dtype=torch.float32)
        return X, Y

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at ``idx``."""
        return self.X[idx], self.Y[idx]