Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 
0002 import numpy as np
0003 import torch
0004 from torch.utils.data import Dataset
0005 import uproot
0006 from numpy.linalg import norm
0007 
def point_line_distance(point, line_point1, line_point2):
    """Return the distance from ``point`` to the infinite line through two points.

    Works for 2-D and 3-D coordinates. 2-D inputs are handled with an explicit
    scalar cross product, because ``np.cross`` on 2-element vectors is
    deprecated since NumPy 2.0 and warns on every call.

    Args:
        point: Coordinates of the query point.
        line_point1: First point defining the line.
        line_point2: Second point defining the line.

    Returns:
        The perpendicular distance. If the two line points coincide, the line
        degenerates to a single point and the point-to-point distance is
        returned instead of dividing by zero.
    """
    point = np.asarray(point, dtype=float)
    origin = np.asarray(line_point1, dtype=float)
    direction = np.asarray(line_point2, dtype=float) - origin
    offset = point - origin

    length = norm(direction)
    if length == 0.0:
        # Degenerate "line": both defining points are identical.
        return norm(offset)

    if direction.shape[-1] == 2:
        # z-component of the 3-D cross product of the two in-plane vectors.
        cross_z = offset[0] * direction[1] - offset[1] * direction[0]
        return abs(cross_z) / length
    return norm(np.cross(offset, direction)) / length
0017 
class TrackPointNetDataset(Dataset):
    """PyTorch dataset of per-event (x, y) point sets read from ROOT files.

    Each accepted event yields a ``(3, 2)`` array of (x, y) points, in order:

    1. the single tracker hit on layer 3 or 4,
    2. the single tracker hit on layer 5 or 6,
    3. the first calorimeter cluster, which must be the only cluster with
       ``caloClus_system == 0`` (EMCal, per the original selection comment).

    Tracker hits on layers 0, 1 and 2 must each be present, but they are not
    currently added to the point set. The target is the event's single
    ``PrimaryG4P_Pt`` value.

    Args:
        list_file: Text file listing one ROOT file path per non-blank line.
        tree_name: Name of the tree to read from each ROOT file.
        max_hits: Stored as ``self.max_hits`` but currently unused, because the
            padding step is disabled. Kept for API compatibility.
        pt_min: Inclusive lower bound on the truth pT.
        pt_max: Exclusive upper bound on the truth pT.
    """

    # Branches read from every file. Some are not used by the current
    # selection, but the list is kept unchanged so the same files are accepted.
    _BRANCHES = (
        "trk_system", "trk_layer", "trk_X", "trk_Y", "trk_Z",
        "caloClus_system", "caloClus_X", "caloClus_Y", "caloClus_Z", "caloClus_edep",
        "caloClus_innr_X", "caloClus_innr_Y", "caloClus_innr_Z", "caloClus_innr_edep",
        "PrimaryG4P_Pt",
    )

    def __init__(self, list_file, tree_name="tree", max_hits=6, pt_min=0.0, pt_max=10.0):
        self.samples = []
        self.targets = []
        self.max_hits = max_hits

        # Rejection counters, keyed by the reason returned from _select_event.
        failures = {"012": 0, "34": 0, "calo": 0, "truth": 0, "pt": 0}

        with open(list_file, "r") as f:
            root_files = [line.strip() for line in f if line.strip()]

        for root_file in root_files:
            # Only the read step is best-effort: unreadable files are reported
            # and skipped, while errors in the selection logic propagate.
            try:
                with uproot.open(root_file) as file:
                    data = file[tree_name].arrays(list(self._BRANCHES), library="np")
            except Exception as e:
                print(f"Error reading {root_file}: {e}")
                continue

            for i in range(len(data["trk_system"])):
                reason, points, pt_value = self._select_event(data, i, pt_min, pt_max)
                if reason is not None:
                    failures[reason] += 1
                    continue
                self.samples.append(points)
                self.targets.append(pt_value)

        print(f"✅ Total usable entries: {len(self.samples)}")
        print(f"[Stats] Events failed 012: {failures['012']}")
        print(f"[Stats] Events failed 34/56: {failures['34']}")
        print(f"[Stats] Events failed calo: {failures['calo']}")
        print(f"[Stats] Events failed truth: {failures['truth']}")
        print(f"[Stats] Events outside pt range: {failures['pt']}")

        # Keep a (N, 3, 2) shape even when nothing was selected.
        if self.samples:
            self.samples = np.array(self.samples)
        else:
            self.samples = np.empty((0, 3, 2))
        self.targets = np.array(self.targets)

    @staticmethod
    def _select_event(data, i, pt_min, pt_max):
        """Apply the event selection to entry ``i`` of ``data``.

        Args:
            data: Mapping from branch name to a per-event sequence of arrays,
                as produced by ``tree.arrays(..., library="np")``.
            i: Event index.
            pt_min: Inclusive lower bound on the truth pT.
            pt_max: Exclusive upper bound on the truth pT.

        Returns:
            ``(reason, track_points, pt_value)``. On success ``reason`` is
            ``None``, ``track_points`` has shape ``(3, 2)`` and ``pt_value`` is
            the truth pT. On rejection ``reason`` is one of ``"34"``,
            ``"012"``, ``"calo"``, ``"truth"`` or ``"pt"``, and the other two
            entries are ``None``.
        """
        # Each hit is a (layer, x, y, z) tuple.
        trk_hits = list(zip(data["trk_layer"][i], data["trk_X"][i],
                            data["trk_Y"][i], data["trk_Z"][i]))

        # Require exactly one hit on layers {3, 4} and exactly one on {5, 6}.
        clu_34 = [h for h in trk_hits if h[0] in (3, 4)]
        clu_56 = [h for h in trk_hits if h[0] in (5, 6)]
        if len(clu_34) != 1 or len(clu_56) != 1:
            return "34", None, None
        # Keep only (x, y): tuple slice [1:3] drops layer and z.
        p34 = np.array(clu_34[0][1:3])
        p56 = np.array(clu_56[0][1:3])

        # Layers 0, 1 and 2 must each have at least one hit. The closest-hit
        # selection for these layers is currently disabled, so their hits are
        # not added to the point set.
        for layer_id in (0, 1, 2):
            if not any(h[0] == layer_id for h in trk_hits):
                return "012", None, None

        # The first cluster must be system 0 and be the only system-0 cluster.
        calo_system = data["caloClus_system"][i]
        if len(calo_system) == 0 or not (calo_system[0] == 0 and np.sum(calo_system == 0) == 1):
            return "calo", None, None
        calo_point = np.array([data["caloClus_X"][i][0], data["caloClus_Y"][i][0]])

        track_points = np.array([p34, p56, calo_point])  # shape (3, 2)

        pt_truth = data["PrimaryG4P_Pt"][i]
        if len(pt_truth) != 1:
            return "truth", None, None

        pt_value = pt_truth[0]
        if not (pt_min <= pt_value < pt_max):
            return "pt", None, None

        return None, track_points, pt_value

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = torch.tensor(self.samples[idx], dtype=torch.float32)  # [3, 2]
        y = torch.tensor(self.targets[idx], dtype=torch.float32)  # scalar
        return x, y