File indexing completed on 2025-12-16 09:18:06
0001
0002 import numpy as np
0003 import torch
0004 from torch.utils.data import Dataset
0005 import uproot
0006 from numpy.linalg import norm
0007
def point_line_distance(point, line_point1, line_point2):
    """Return the perpendicular distance from *point* to the infinite line
    defined by *line_point1* and *line_point2*.

    Accepts 2-D or 3-D coordinates (anything ``np.cross``/``norm`` handle).
    """
    p = np.asarray(point)
    a = np.asarray(line_point1)
    b = np.asarray(line_point2)
    direction = b - a
    # |(p - a) x d| is the parallelogram area spanned by the two vectors;
    # dividing by the base length |d| yields the height, i.e. the distance.
    return norm(np.cross(p - a, direction)) / norm(direction)
0017
class TrackPointNetDataset(Dataset):
    """Dataset of fixed-length 2-D track-point sequences with truth pT targets.

    Each usable event yields ``max_hits`` (x, y) points in this order:
    the best hit on each of tracker layers 0, 1, 2 (chosen as closest to
    the reference line through the layer-3/4 and layer-5/6 hits), the
    layer-3/4 hit, the layer-5/6 hit, and the calorimeter cluster
    position. The regression target is the truth transverse momentum.
    """

    def __init__(self, list_file, tree_name="tree", max_hits=6, pt_min=0.0, pt_max=10.0):
        """
        Args:
            list_file: text file listing one ROOT file path per line.
            tree_name: name of the TTree to read in each ROOT file.
            max_hits: points per sample (3 inner hits + 2 outer hits
                + 1 calo cluster = 6).
            pt_min, pt_max: keep events with pt_min <= pt_truth < pt_max.
        """
        self.samples = []
        self.targets = []
        self.max_hits = max_hits

        # Counters for events rejected by each selection step below.
        fail_012 = 0    # missing hit on one of tracker layers 0/1/2
        fail_34 = 0     # not exactly one hit in layers 3/4 or layers 5/6
        fail_calo = 0   # not exactly one calo cluster with system == 0
        fail_truth = 0  # not exactly one truth primary particle

        branches_to_load = [
            "trk_system", "trk_layer", "trk_X", "trk_Y", "trk_Z",
            "caloClus_system", "caloClus_X", "caloClus_Y", "caloClus_Z", "caloClus_edep",
            "caloClus_innr_X", "caloClus_innr_Y", "caloClus_innr_Z", "caloClus_innr_edep",
            "PrimaryG4P_Pt"
        ]

        with open(list_file, "r") as f:
            root_files = [line.strip() for line in f if line.strip()]

        for root_file in root_files:
            try:
                # Context manager ensures the ROOT file handle is closed
                # (the original leaked one handle per input file).
                with uproot.open(root_file) as file:
                    tree = file[tree_name]
                    data = tree.arrays(branches_to_load, library="np")
                n_entries = len(data["trk_system"])

                for i in range(n_entries):
                    trk_layer = data["trk_layer"][i]
                    trk_x = data["trk_X"][i]
                    trk_y = data["trk_Y"][i]
                    trk_z = data["trk_Z"][i]
                    trk_hits = list(zip(trk_layer, trk_x, trk_y, trk_z))

                    # Require exactly one hit in layers 3/4 and exactly one
                    # in layers 5/6; these define the reference line.
                    clu_34 = [p for p in trk_hits if p[0] in (3, 4)]
                    clu_56 = [p for p in trk_hits if p[0] in (5, 6)]
                    if len(clu_34) != 1 or len(clu_56) != 1:
                        fail_34 += 1
                        continue

                    # (x, y) of the two outer reference hits.
                    p34 = np.array(clu_34[0][1:3])
                    p56 = np.array(clu_56[0][1:3])

                    # For each inner layer keep the single hit closest to
                    # the p34-p56 line.
                    track_points = []
                    success = True
                    for layer_id in [0, 1, 2]:
                        layer_hits = [p for p in trk_hits if p[0] == layer_id]
                        if len(layer_hits) == 0:
                            success = False
                            break
                        dists = [point_line_distance(p[1:3], p34, p56) for p in layer_hits]
                        min_idx = int(np.argmin(dists))
                        # BUG FIX: the original computed min_idx but never
                        # stored the selected hit, so samples held only 3
                        # points instead of max_hits (= 6).
                        track_points.append(np.array(layer_hits[min_idx][1:3]))

                    if not success:
                        fail_012 += 1
                        continue

                    track_points.append(p34)
                    track_points.append(p56)

                    # Require exactly one calo cluster overall, with
                    # system == 0, and use its (x, y) position.
                    calo_system = data["caloClus_system"][i]
                    if len(calo_system) == 0 or not (calo_system[0] == 0 and np.sum(calo_system == 0) == 1):
                        fail_calo += 1
                        continue

                    calo_x = data["caloClus_X"][i][0]
                    calo_y = data["caloClus_Y"][i][0]
                    calo_point = np.array([calo_x, calo_y])
                    track_points.append(calo_point)

                    track_points = np.array(track_points)

                    # Exactly one truth primary particle, inside the pT window.
                    pt_truth = data["PrimaryG4P_Pt"][i]
                    if len(pt_truth) != 1:
                        fail_truth += 1
                        continue

                    pt_value = pt_truth[0]
                    if not (pt_min <= pt_value < pt_max):
                        continue

                    self.samples.append(track_points)
                    self.targets.append(pt_value)

            except Exception as e:
                # Best-effort: a corrupt/unreadable file is reported and skipped.
                print(f"Error reading {root_file}: {e}")
                continue

        print(f"✅ Total usable entries: {len(self.samples)}")
        print(f"[Stats] Events failed 012: {fail_012}")
        print(f"[Stats] Events failed 34/56: {fail_34}")
        print(f"[Stats] Events failed calo: {fail_calo}")
        print(f"[Stats] Events failed truth: {fail_truth}")

        # Stack into dense arrays: samples is (N, max_hits, 2), targets is (N,).
        self.samples = np.array(self.samples)
        self.targets = np.array(self.targets)

    def __len__(self):
        """Return the number of usable events."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (points, pt) as float32 tensors for event *idx*."""
        x = torch.tensor(self.samples[idx], dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.float32)
        return x, y