File indexing completed on 2025-12-16 09:18:05
0001 import numpy as np
0002 import torch
0003 from torch.utils.data import Dataset
0004
0005 import uproot
0006 from collections import Counter
0007 from numpy.linalg import norm
0008 from sklearn.preprocessing import StandardScaler
0009
def point_line_distance(point, line_point1, line_point2):
    """Return the distance from ``point`` to the line through two points.

    The distance is computed as ``|(point - line_point1) x line_vec| / |line_vec|``
    where ``line_vec = line_point2 - line_point1``.

    Args:
        point: Coordinates of the point (array-like).
        line_point1: First point defining the line (array-like).
        line_point2: Second point defining the line (array-like).

    Returns:
        The perpendicular distance as a NumPy scalar. If the two line points
        coincide, the line is undefined and the distance from ``point`` to
        ``line_point1`` is returned instead.
    """
    point = np.asarray(point)
    line_point1 = np.asarray(line_point1)
    line_point2 = np.asarray(line_point2)

    offset = point - line_point1
    line_vec = line_point2 - line_point1
    line_len = norm(line_vec)

    # Degenerate line: dividing by a zero length would yield NaN plus a
    # RuntimeWarning, so fall back to the plain point-to-point distance.
    if line_len == 0:
        return norm(offset)

    return norm(np.cross(offset, line_vec)) / line_len
0019
class TrackCaloDataset(Dataset):
    """Dataset of tracker + inner-calorimeter features with ``PrimaryG4P_Pt`` targets.

    Each sample is the 7-component vector
    ``[r34, z34, r56, z56, r_innr, z_innr, e_innr]`` where ``r`` is the
    transverse radius ``sqrt(x**2 + y**2)``. It is built from the single
    tracker hit in layers 3/4, the single hit in layers 5/6 and the single
    inner calorimeter cluster of an event. The target is the single
    ``PrimaryG4P_Pt`` value of that event.

    Attributes:
        X: ``float32`` tensor of standardised features, one row per event.
        Y: ``float32`` tensor of targets with shape ``(n_events, 1)``.
        scaler: The scaler used to standardise ``X`` (fitted here when none
            was supplied), so it can be reused for another dataset.
    """

    # Branches read from every tree. The list is kept exactly as before so a
    # file lacking any of them is still reported and skipped as a whole.
    _BRANCHES = (
        "trk_system", "trk_layer", "trk_X", "trk_Y", "trk_Z",
        "caloClus_system", "caloClus_X", "caloClus_Y", "caloClus_Z", "caloClus_edep",
        "caloClus_innr_X", "caloClus_innr_Y", "caloClus_innr_Z", "caloClus_innr_edep",
        "PrimaryG4P_Pt",
    )

    # Events whose signed azimuthal deflection is at or below this value are dropped.
    # NOTE(review): presumably this keeps a single bending direction - confirm.
    _DPHI_MIN = 0.01

    def __init__(self, list_file, tree_name="tree", scaler=None, pt_min=0.0, pt_max=10.0):
        """Load, select and standardise the events listed in ``list_file``.

        Args:
            list_file: Path to a text file with one ROOT file path per line.
            tree_name: Name of the tree to read in every ROOT file.
            scaler: Already-fitted scaler exposing ``transform``. When
                ``None`` a ``StandardScaler`` is created and fitted on the
                selected features.
            pt_min: Inclusive lower bound on the target value.
            pt_max: Exclusive upper bound on the target value.

        Raises:
            ValueError: If no event survives the selection or the
                ``[pt_min, pt_max)`` window.
        """
        super().__init__()

        data_X, data_Y = self.extract_features_from_rootlist(list_file, tree_name)
        data_X = np.asarray(data_X)
        data_Y = np.asarray(data_Y)

        # Fail early with a clear message; an empty array would otherwise
        # only surface as an opaque error inside the scaler.
        if len(data_X) == 0:
            raise ValueError(f"No usable entries found via list file {list_file!r}")

        in_window = (data_Y >= pt_min) & (data_Y < pt_max)
        data_X = data_X[in_window]
        data_Y = data_Y[in_window]

        if len(data_X) == 0:
            raise ValueError(f"No entries with target in [{pt_min}, {pt_max})")

        if scaler is None:
            scaler = StandardScaler()
            data_X = scaler.fit_transform(data_X)
        else:
            data_X = scaler.transform(data_X)

        self.X = torch.tensor(data_X, dtype=torch.float32)
        self.Y = torch.tensor(data_Y, dtype=torch.float32).view(-1, 1)
        self.scaler = scaler

    def __len__(self):
        """Return the number of selected events."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair at ``idx``."""
        return self.X[idx], self.Y[idx]

    @staticmethod
    def extract_features_from_rootlist(list_file, tree_name="tree"):
        """Read every ROOT file in ``list_file`` and build features/targets.

        A file that cannot be read or processed is reported on stdout and
        skipped; events accepted from it before the error are kept.

        Args:
            list_file: Path to a text file with one ROOT file path per line
                (blank lines are ignored).
            tree_name: Name of the tree to read in every ROOT file.

        Returns:
            Tuple ``(X, Y)`` of NumPy arrays with one feature row and one
            target value per accepted event. Both always have the same length.
        """
        # Each rejected event is counted once, under the first cut it fails.
        fails = Counter()
        X_data = []
        Y_data = []

        with open(list_file, "r") as f:
            root_files = [line.strip() for line in f if line.strip()]

        for root_file in root_files:
            try:
                # The arrays are fully materialised in memory, so the file
                # handle can be released before the per-event loop.
                with uproot.open(root_file) as root:
                    data = root[tree_name].arrays(
                        list(TrackCaloDataset._BRANCHES), library="np"
                    )

                for i in range(len(data["trk_system"])):
                    reason, feat, truth_pt = TrackCaloDataset._event_features(data, i)
                    if reason is not None:
                        fails[reason] += 1
                        continue
                    # Features and target are appended together so X and Y
                    # can never drift out of alignment.
                    X_data.append(feat)
                    Y_data.append(truth_pt)

            except Exception as e:
                # Best-effort: one bad file must not abort the whole list.
                print(f"Error reading {root_file}: {e}")
                continue

        print(f"Total usable entries: {len(X_data)}")
        print(f"[Stats] Events failed cond_012: {fails['012']}")
        print(f"[Stats] Events failed cond_3or4: {fails['34']}")
        print(f"[Stats] Events failed cond_5or6: {fails['56']}")
        print(f"[Stats] Events failed cond_calo: {fails['calo']}")
        print(f"[Stats] Events failed cond_dphi: {fails['dphi']}")
        print(f"[Stats] Events failed cond_truth: {fails['truth']}")

        return np.array(X_data), np.array(Y_data)

    @staticmethod
    def _event_features(data, i):
        """Apply the per-event selection to entry ``i`` of ``data``.

        Args:
            data: Mapping from branch name to a per-event sequence of arrays.
            i: Index of the event to process.

        Returns:
            ``(None, features, truth_pt)`` when the event is accepted, or
            ``(reason, None, None)`` where ``reason`` is one of ``"34"``,
            ``"56"``, ``"012"``, ``"calo"``, ``"dphi"`` or ``"truth"``.
        """
        hits = list(zip(
            data["trk_layer"][i], data["trk_X"][i], data["trk_Y"][i], data["trk_Z"][i]
        ))

        # Exactly one hit is required in layers 3/4 and exactly one in layers 5/6.
        hits_34 = [h for h in hits if h[0] in (3, 4)]
        if len(hits_34) != 1:
            return "34", None, None
        hits_56 = [h for h in hits if h[0] in (5, 6)]
        if len(hits_56) != 1:
            return "56", None, None

        # Every one of layers 0, 1 and 2 must contain at least one hit.
        if not all(any(h[0] == layer for h in hits) for layer in (0, 1, 2)):
            return "012", None, None

        p34 = np.array(hits_34[0][1:])
        p56 = np.array(hits_56[0][1:])
        r34 = np.sqrt(p34[0] ** 2 + p34[1] ** 2)
        r56 = np.sqrt(p56[0] ** 2 + p56[1] ** 2)
        trk_feat = np.array([r34, p34[-1], r56, p56[-1]])

        # The first cluster must have system 0 and be the only such cluster.
        calo_system = data["caloClus_system"][i]
        if (
            len(calo_system) == 0
            or calo_system[0] != 0
            or np.sum(calo_system == 0) != 1
        ):
            return "calo", None, None

        # Exactly one inner calorimeter cluster is required.
        innr_x = data["caloClus_innr_X"][i]
        if len(innr_x) != 1:
            return "calo", None, None
        innr_y = data["caloClus_innr_Y"][i]
        innr_z = data["caloClus_innr_Z"][i]
        innr_e = data["caloClus_innr_edep"][i]
        r_innr = np.sqrt(innr_x[0] ** 2 + innr_y[0] ** 2)
        calo_innr_feat = np.array([r_innr, innr_z[0], innr_e[0]])

        # Signed azimuthal angle, in the x-y plane, between the 34->56
        # segment and the 56->calo segment, wrapped into [-pi, pi).
        calo_xy = np.array([data["caloClus_X"][i][0], data["caloClus_Y"][i][0]])
        xy34 = np.array(hits_34[0][1:3])
        xy56 = np.array(hits_56[0][1:3])
        seg_trk = xy56 - xy34
        seg_calo = calo_xy - xy56
        dphi = np.arctan2(seg_calo[1], seg_calo[0]) - np.arctan2(seg_trk[1], seg_trk[0])
        dphi = (dphi + np.pi) % (2 * np.pi) - np.pi
        if dphi <= TrackCaloDataset._DPHI_MIN:
            return "dphi", None, None

        # The target must be unambiguous: exactly one PrimaryG4P_Pt value.
        truth_pt = data["PrimaryG4P_Pt"][i]
        if len(truth_pt) != 1:
            return "truth", None, None

        return None, np.concatenate([trk_feat, calo_innr_feat]), truth_pt[0]