Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import numpy as np
0002 import torch
0003 from torch.utils.data import Dataset
0004 # from sklearn.preprocessing import StandardScaler
0005 import uproot
0006 from collections import Counter
0007 from numpy.linalg import norm
0008 from sklearn.preprocessing import StandardScaler
0009 
0010 import ROOT
0011 
def angle_between(v1, v2):
    """
    Return the unsigned angle, in radians within [0, pi], between v1 and v2.

    The cosine is clipped to [-1, 1] before ``arccos`` so floating-point
    round-off on (anti)parallel vectors cannot leave the arccos domain.
    Zero-length inputs produce NaN.
    """
    denom = norm(v1) * norm(v2)
    ratio = np.dot(v1, v2) / denom
    return np.arccos(np.clip(ratio, -1.0, 1.0))
0016 
def point_line_distance(point, line_point1, line_point2):
    """
    Return the perpendicular distance from ``point`` to the infinite line
    through ``line_point1`` and ``line_point2``.

    Supports 2D and 3D points. For 2D inputs the scalar (z-component) cross
    product is computed explicitly, because ``np.cross`` on 2-element vectors
    is deprecated since NumPy 2.0. 3D inputs use ``np.cross``.

    If the two line points coincide the line is undefined; in that case the
    distance from ``point`` to that single point is returned instead of NaN.
    """
    point = np.asarray(point, dtype=float)
    line_point1 = np.asarray(line_point1, dtype=float)
    line_point2 = np.asarray(line_point2, dtype=float)
    line_vec = line_point2 - line_point1
    rel = point - line_point1
    line_len = norm(line_vec)
    if line_len == 0.0:
        # Degenerate "line": fall back to point-to-point distance.
        return norm(rel)
    if line_vec.shape[-1] == 2:
        # |a x b| for 2D vectors is |a_x * b_y - a_y * b_x|.
        cross_z = rel[0] * line_vec[1] - rel[1] * line_vec[0]
        return abs(cross_z) / line_len
    return norm(np.cross(rel, line_vec)) / line_len
0026 
class TrackCaloDataset(Dataset):
    """
    PyTorch dataset of (feature, truth pT) pairs built from sPHENIX ROOT files.

    Each sample's feature vector is ``[1/dphi, 0]``. Here ``dphi`` is the
    azimuthal difference, wrapped to [-pi, pi), between two segments:

    - the track segment from the layer-3/4 hit to the layer-5/6 hit;
    - the segment from the layer-5/6 hit to the first calorimeter cluster
      (``caloClus_system == 0``).

    The target is the truth ``PrimaryG4P_Pt`` of the event.
    """

    def __init__(self, list_file, tree_name="tree", scaler=None, pt_min=0.0, pt_max=10.0):
        """
        Build the dataset from a text file listing ROOT file paths.

        Args:
            list_file: path to a text file with one ROOT file path per line.
            tree_name: name of the TTree to read in every file.
            scaler: currently ignored -- feature standardisation is disabled.
                Accepted so existing callers that pass it keep working.
            pt_min, pt_max: keep only samples with pt_min <= truth pT < pt_max.
        """
        data_X, data_Y = self.extract_features_from_rootlist(list_file, tree_name)

        data_X = np.asarray(data_X)
        data_Y = np.asarray(data_Y)
        # Truth-pT window; X and Y rows are aligned by construction.
        mask = (data_Y >= pt_min) & (data_Y < pt_max)
        data_X = data_X[mask]
        data_Y = data_Y[mask]

        self.X = torch.tensor(data_X, dtype=torch.float32)
        self.Y = torch.tensor(data_Y, dtype=torch.float32).view(-1, 1)
        # Kept for backward compatibility; not populated by this class.
        self.valid_indices = []

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for sample ``idx``; target has shape (1,)."""
        return self.X[idx], self.Y[idx]

    @staticmethod
    def _select_event(data, i):
        """
        Apply the event selection to entry ``i`` of the branch dict ``data``.

        Returns ``(None, feature, truth_pt)`` when the event passes every cut,
        otherwise ``(reason, None, None)``. ``reason`` names the first failed
        cut and is used as a key in the rejection statistics.
        """
        trk_hits = list(zip(data["trk_layer"][i], data["trk_X"][i], data["trk_Y"][i]))

        # Exactly one hit in layers 3/4 and exactly one in layers 5/6.
        clu_34 = [h for h in trk_hits if h[0] in (3, 4)]
        if len(clu_34) != 1:
            return "cond_3or4", None, None
        clu_56 = [h for h in trk_hits if h[0] in (5, 6)]
        if len(clu_56) != 1:
            return "cond_5or6", None, None
        p34 = np.array(clu_34[0][1:3])  # (x, y)
        p56 = np.array(clu_56[0][1:3])  # (x, y)

        # At least one hit in each of layers 0, 1 and 2.
        for layer_id in (0, 1, 2):
            if not any(h[0] == layer_id for h in trk_hits):
                return "cond_012", None, None

        # The first calo cluster must have system id 0, and be the only such cluster.
        calo_system = data["caloClus_system"][i]
        if len(calo_system) == 0 or not (calo_system[0] == 0 and np.sum(calo_system == 0) == 1):
            return "calo", None, None
        # Exactly one "innr" calo cluster is required.
        if len(data["caloClus_innr_X"][i]) != 1:
            return "calo", None, None

        # No position correction is applied to the calo cluster here.
        pcalo = np.array([data["caloClus_X"][i][0], data["caloClus_Y"][i][0]])

        vec1 = p56 - p34    # track segment 3/4 -> 5/6
        vec2 = pcalo - p56  # 5/6 hit -> calo cluster
        dphi = np.arctan2(vec2[1], vec2[0]) - np.arctan2(vec1[1], vec1[0])
        dphi = (dphi + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)
        # Reject small / negative dphi; this also keeps 1/dphi finite and bounded.
        if dphi <= 0.01:
            return "dphi", None, None

        # Target: exactly one truth primary pT is required.
        truth_pt = data["PrimaryG4P_Pt"][i]
        if len(truth_pt) != 1:
            return "truth", None, None

        # 1/dphi is the model input; the second component is a constant 0.
        feat = np.array([1 / dphi, 0])
        return None, feat, truth_pt[0]

    @staticmethod
    def extract_features_from_rootlist(list_file, tree_name="tree"):
        """
        Read every ROOT file listed in ``list_file`` and build (X, Y) arrays.

        Returns:
            X: array of shape (n, 2), one feature row per selected event.
            Y: array of shape (n,) with the matching truth pT.

        A feature row and its target are appended together, so X and Y always
        have the same length. Files that raise while being read are reported
        and skipped. Per-cut rejection counts are printed at the end.
        """
        # Some branches are not used by the current features; they are still
        # requested, as before.
        branches_to_load = [
            "trk_system", "trk_layer", "trk_X", "trk_Y", "trk_Z",
            "caloClus_system", "caloClus_X", "caloClus_Y", "caloClus_Z", "caloClus_edep",
            "caloClus_innr_X", "caloClus_innr_Y", "caloClus_innr_Z", "caloClus_innr_edep",
            "PrimaryG4P_Pt",
        ]

        X_data = []
        Y_data = []
        fails = Counter()

        with open(list_file, "r") as f:
            root_files = [line.strip() for line in f if line.strip()]

        for root_file in root_files:
            try:
                # library="np" materialises the arrays, so the file can be closed
                # before the event loop.
                with uproot.open(root_file) as file:
                    data = file[tree_name].arrays(branches_to_load, library="np")

                for i in range(len(data["trk_system"])):  # one entry per event
                    reason, feat, pt = TrackCaloDataset._select_event(data, i)
                    if reason is not None:
                        fails[reason] += 1
                        continue
                    X_data.append(feat)
                    Y_data.append(pt)
            except Exception as e:
                # Best-effort: report the bad file and continue with the rest.
                print(f"Error reading {root_file}: {e}")
                continue

        print(f"Total usable entries: {len(X_data)}")
        print(f"[Stats] Events failed cond_012: {fails['cond_012']}")
        print(f"[Stats] Events failed cond_3or4: {fails['cond_3or4']}")
        print(f"[Stats] Events failed cond_5or6: {fails['cond_5or6']}")
        print(f"[Stats] Events failed calo: {fails['calo']}")
        print(f"[Stats] Events failed dphi cut: {fails['dphi']}")
        print(f"[Stats] Events failed truth: {fails['truth']}")

        # reshape keeps a (0, 2) shape when no event survives the selection.
        return np.array(X_data).reshape(-1, 2), np.array(Y_data)