Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:05

0001 import numpy as np
0002 import torch
0003 from torch.utils.data import Dataset
0004 # from sklearn.preprocessing import StandardScaler
0005 import uproot
0006 from collections import Counter
0007 from numpy.linalg import norm
0008 from sklearn.preprocessing import StandardScaler
0009 
def point_line_distance(point, line_point1, line_point2):
    """Return the distance from ``point`` to the infinite line through two points.

    Works for points of any (matching) dimension, e.g. 2-D or 3-D.  If the two
    line points coincide the line is degenerate, and the distance to
    ``line_point1`` is returned instead of dividing by zero.

    Args:
        point: Coordinates of the query point.
        line_point1: First point defining the line.
        line_point2: Second point defining the line.

    Returns:
        The (non-negative) perpendicular distance as a numpy float.
    """
    point = np.asarray(point, dtype=float)
    origin = np.asarray(line_point1, dtype=float)
    direction = np.asarray(line_point2, dtype=float) - origin
    offset = point - origin

    length_sq = np.dot(direction, direction)
    if length_sq == 0.0:
        # Degenerate "line": fall back to point-to-point distance.
        return norm(offset)

    # Remove the component of the offset along the line; what is left is
    # perpendicular to it.  Unlike np.cross this is valid in any dimension.
    perpendicular = offset - (np.dot(offset, direction) / length_sq) * direction
    return norm(perpendicular)
0019 
class TrackCaloDataset(Dataset):
    """PyTorch dataset of tracker + inner-calo-cluster features with truth-pT targets.

    Features are built per event from the ROOT files listed in a text file,
    restricted to a truth-pT window, and standardized with a scikit-learn
    ``StandardScaler``.
    """

    def __init__(self, list_file, tree_name="tree", scaler=None, pt_min=0.0, pt_max=10.0):
        """Load, select and standardize the features.

        Args:
            list_file: Path to a text file with one ROOT file path per line.
            tree_name: Name of the TTree to read from each ROOT file.
            scaler: An already-fitted scaler exposing ``transform``.  If
                ``None``, a new ``StandardScaler`` is created and fitted here.
            pt_min: Inclusive lower bound on the truth pT.
            pt_max: Exclusive upper bound on the truth pT.

        Raises:
            ValueError: If no event survives the selection and the pT window.
        """
        data_X, data_Y = self.extract_features_from_rootlist(list_file, tree_name)

        # Keep only events inside [pt_min, pt_max).
        data_X = np.asarray(data_X)
        data_Y = np.asarray(data_Y)
        mask = (data_Y >= pt_min) & (data_Y < pt_max)
        data_X = data_X[mask]
        data_Y = data_Y[mask]
        if len(data_Y) == 0:
            raise ValueError(
                f"No events with {pt_min} <= truth pT < {pt_max} found in {list_file}"
            )

        # Standardize: fit a new scaler, or reuse the one supplied (e.g. from training).
        if scaler is None:
            scaler = StandardScaler()
            data_X = scaler.fit_transform(data_X)
        else:
            data_X = scaler.transform(data_X)

        self.X = torch.tensor(data_X, dtype=torch.float32)
        self.Y = torch.tensor(data_Y, dtype=torch.float32).view(-1, 1)
        # Kept so the fitted scaler can be exported or reused on other splits.
        self.scaler = scaler

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    @staticmethod
    def _event_features(data, i, stats):
        """Build the feature vector and truth pT for event ``i``, or reject it.

        Args:
            data: Mapping from branch name to a per-event sequence of arrays,
                as returned by ``tree.arrays(..., library="np")``.
            i: Event index.
            stats: ``collections.Counter`` of rejection reasons; the counter for
                the first failed cut is incremented.

        Returns:
            ``(feat, truth_pt)`` if the event passes every cut, else ``None``.
            ``feat`` is ``[r34, z34, r56, z56, r_innr, z_innr, edep_innr]``.
        """
        trk_layer = np.asarray(data["trk_layer"][i])
        trk_xyz = np.column_stack([data["trk_X"][i], data["trk_Y"][i], data["trk_Z"][i]])

        # Exactly one hit in INTT layers 3/4 and exactly one in layers 5/6.
        in_34 = np.isin(trk_layer, (3, 4))
        in_56 = np.isin(trk_layer, (5, 6))
        n_34 = np.count_nonzero(in_34)
        n_56 = np.count_nonzero(in_56)
        if n_34 != 1:
            stats["fail_34"] += 1
        if n_56 != 1:
            stats["fail_56"] += 1
        if n_34 != 1 or n_56 != 1:
            return None
        p34 = trk_xyz[in_34][0]
        p56 = trk_xyz[in_56][0]

        # At least one hit in each of layers 0, 1 and 2.
        if not np.all(np.isin((0, 1, 2), trk_layer)):
            stats["fail_012"] += 1
            return None

        # Leading calo cluster must be the only one with caloClus_system == 0.
        calo_system = np.asarray(data["caloClus_system"][i])
        if len(calo_system) == 0 or not (
            calo_system[0] == 0 and np.sum(calo_system == 0) == 1
        ):
            stats["fail_calo"] += 1
            return None

        # Exactly one inner calo cluster.
        innr_x = data["caloClus_innr_X"][i]
        if len(innr_x) != 1:
            stats["fail_calo"] += 1
            return None
        innr_y = data["caloClus_innr_Y"][i]
        innr_z = data["caloClus_innr_Z"][i]
        innr_e = data["caloClus_innr_edep"][i]

        # Transverse-plane turning angle between the 34->56 segment and the
        # 56->calo segment, wrapped to [-pi, pi) radians.
        calo_xy = np.array([data["caloClus_X"][i][0], data["caloClus_Y"][i][0]])
        seg_trk = p56[:2] - p34[:2]
        seg_calo = calo_xy - p56[:2]
        dphi = np.arctan2(seg_calo[1], seg_calo[0]) - np.arctan2(seg_trk[1], seg_trk[0])
        dphi = (dphi + np.pi) % (2 * np.pi) - np.pi
        # NOTE(review): keeps only dphi > 0.01, which also rejects every negative
        # dphi (one bending direction).  Confirm this single-sign selection is intended.
        if dphi <= 0.01:
            stats["fail_dphi"] += 1
            return None

        # Truth target: exactly one primary particle.  Checked before anything is
        # recorded so features and targets always stay aligned.
        truth_pt = data["PrimaryG4P_Pt"][i]
        if len(truth_pt) != 1:
            stats["fail_truth"] += 1
            return None

        trk_feat = np.array([
            np.hypot(p34[0], p34[1]), p34[2],   # INTT 3/4 hit -> r, z
            np.hypot(p56[0], p56[1]), p56[2],   # INTT 5/6 hit -> r, z
        ])
        innr_feat = np.array([np.hypot(innr_x[0], innr_y[0]), innr_z[0], innr_e[0]])
        return np.concatenate([trk_feat, innr_feat]), truth_pt[0]

    @staticmethod
    def extract_features_from_rootlist(list_file, tree_name="tree"):
        """Build feature and target arrays from every ROOT file listed in ``list_file``.

        Any exception while reading or processing a file is printed and the
        rest of that file is skipped.  Selection statistics are printed at the end.

        Args:
            list_file: Text file with one ROOT file path per line (blank lines ignored).
            tree_name: Name of the TTree in each file.

        Returns:
            ``(X, Y)`` numpy arrays of equal length: one 7-element feature row
            per selected event and the matching ``PrimaryG4P_Pt`` value.
        """
        branches_to_load = [
            "trk_system", "trk_layer", "trk_X", "trk_Y", "trk_Z",
            "caloClus_system", "caloClus_X", "caloClus_Y", "caloClus_Z", "caloClus_edep",
            "caloClus_innr_X", "caloClus_innr_Y", "caloClus_innr_Z", "caloClus_innr_edep",
            "PrimaryG4P_Pt",
        ]

        stats = Counter()
        X_data = []
        Y_data = []

        with open(list_file, "r") as f:
            root_files = [line.strip() for line in f if line.strip()]

        for root_file in root_files:
            try:
                # Close each file once its branches are loaded into memory.
                with uproot.open(root_file) as file:
                    data = file[tree_name].arrays(branches_to_load, library="np")

                for i in range(len(data["trk_system"])):  # one entry per event
                    selected = TrackCaloDataset._event_features(data, i, stats)
                    if selected is None:
                        continue
                    feat, truth_pt = selected
                    X_data.append(feat)
                    Y_data.append(truth_pt)
            except Exception as e:
                # Best effort: report the bad file and move on to the next one.
                print(f"Error reading {root_file}: {e}")

        print(f"Total usable entries: {len(X_data)}")
        print(f"[Stats] Events failed cond_012: {stats['fail_012']}")
        print(f"[Stats] Events failed cond_3or4: {stats['fail_34']}")
        print(f"[Stats] Events failed cond_5or6: {stats['fail_56']}")
        print(f"[Stats] Events failed calo: {stats['fail_calo']}")
        print(f"[Stats] Events failed dphi: {stats['fail_dphi']}")
        print(f"[Stats] Events failed truth: {stats['fail_truth']}")

        return np.array(X_data), np.array(Y_data)