# File indexing completed on 2025-12-16 09:18:05
0001 import numpy as np
0002 import torch
0003 from torch.utils.data import Dataset
0004
0005 import uproot
0006 from collections import Counter
0007 from numpy.linalg import norm
0008 from sklearn.preprocessing import StandardScaler
0009
0010 import ROOT
0011
def angle_between(v1, v2):
    """Return the angle, in radians, between the vectors ``v1`` and ``v2``.

    The cosine is clipped to [-1, 1] before ``arccos`` is applied so that
    floating-point rounding cannot push it outside the valid domain.
    """
    magnitude_product = norm(v1) * norm(v2)
    cosine = np.dot(v1, v2) / magnitude_product
    return np.arccos(np.clip(cosine, -1.0, 1.0))
0016
def point_line_distance(point, line_point1, line_point2):
    """Return the distance from ``point`` to the line through two points.

    Args:
        point: Coordinates of the query point (2-D or 3-D sequence).
        line_point1: First point defining the line.
        line_point2: Second point defining the line.

    Returns:
        The perpendicular distance from ``point`` to the infinite line through
        ``line_point1`` and ``line_point2``. If the two line points coincide
        the line is degenerate and the distance to that single point is
        returned instead of NaN.
    """
    point = np.asarray(point, dtype=float)
    start = np.asarray(line_point1, dtype=float)
    end = np.asarray(line_point2, dtype=float)

    offset = point - start
    direction = end - start
    length = norm(direction)

    # Degenerate line: avoid 0/0 and measure to the single point instead.
    if length == 0.0:
        return norm(offset)

    if direction.ndim == 1 and direction.size == 2:
        # 2-D case: use the scalar cross product directly, because np.cross on
        # 2-element vectors is deprecated since NumPy 2.0.
        area = abs(offset[0] * direction[1] - offset[1] * direction[0])
    else:
        area = norm(np.cross(offset, direction))

    # |offset x direction| is the parallelogram area; dividing by the base
    # length leaves the height, i.e. the point-to-line distance.
    return area / length
0026
class TrackCaloDataset(Dataset):
    """Dataset of (feature, truth pT) pairs extracted from a list of ROOT files.

    Each sample's feature vector is ``[1 / dphi, 0]``, where ``dphi`` is the
    signed azimuthal difference (wrapped to [-pi, pi)) between the segment
    from the layer-3/4 hit to the layer-5/6 hit and the segment from the
    layer-5/6 hit to the first ``caloClus`` cluster. The target is the entry's
    single ``PrimaryG4P_Pt`` value.
    """

    def __init__(self, list_file, tree_name="tree", scaler=None, pt_min=0.0, pt_max=10.0):
        """Load the features and keep only samples inside the pT window.

        Args:
            list_file: Path to a text file with one ROOT file path per line.
            tree_name: Name of the tree to read from every ROOT file.
            scaler: Accepted for interface compatibility; it is currently not
                used and no standardisation is applied to the features.
            pt_min: Inclusive lower bound on the truth pT.
            pt_max: Exclusive upper bound on the truth pT.
        """
        data_X, data_Y = self.extract_features_from_rootlist(list_file, tree_name)

        data_X = np.asarray(data_X)
        data_Y = np.asarray(data_Y)

        # Keep samples with pt_min <= pT < pt_max.
        in_window = (data_Y >= pt_min) & (data_Y < pt_max)

        self.X = torch.tensor(data_X[in_window], dtype=torch.float32)
        # Targets are stored as a column vector, one row per sample.
        self.Y = torch.tensor(data_Y[in_window], dtype=torch.float32).view(-1, 1)
        self.valid_indices = []

    def __len__(self):
        """Return the number of samples that passed the pT window."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(feature, target)`` tensors for sample ``idx``."""
        return self.X[idx], self.Y[idx]

    @staticmethod
    def extract_features_from_rootlist(list_file, tree_name="tree"):
        """Read every ROOT file named in ``list_file`` and build the samples.

        A file that cannot be read or processed is reported on stdout and
        skipped; samples collected from it before the error are kept.

        Args:
            list_file: Path to a text file with one ROOT file path per line
                (blank lines are ignored).
            tree_name: Name of the tree to read from every ROOT file.

        Returns:
            A tuple ``(X, Y)`` of NumPy arrays. ``X`` has shape ``(N, 2)`` and
            ``Y`` has shape ``(N,)``; row ``k`` of ``X`` always belongs to
            element ``k`` of ``Y``.
        """
        branches_to_load = [
            "trk_system", "trk_layer", "trk_X", "trk_Y", "trk_Z",
            "caloClus_system", "caloClus_X", "caloClus_Y", "caloClus_Z", "caloClus_edep",
            "caloClus_innr_X", "caloClus_innr_Y", "caloClus_innr_Z", "caloClus_innr_edep",
            "PrimaryG4P_Pt",
        ]

        with open(list_file, "r") as f:
            root_files = [line.strip() for line in f if line.strip()]

        X_data = []
        Y_data = []
        fails = Counter()

        for root_file in root_files:
            try:
                # The arrays are fully materialised before the file is closed.
                with uproot.open(root_file) as file:
                    data = file[tree_name].arrays(branches_to_load, library="np")

                for i in range(len(data["trk_system"])):
                    sample = TrackCaloDataset._event_sample(data, i, fails)
                    if sample is None:
                        continue
                    # Append feature and target together so X and Y can never
                    # get out of step.
                    feat, truth_pt = sample
                    X_data.append(feat)
                    Y_data.append(truth_pt)

            except Exception as e:
                # Best effort: report the problem and move on to the next file.
                print(f"Error reading {root_file}: {e}")
                continue

        print(f"Total usable entries: {len(X_data)}")
        print(f"[Stats] Events failed cond_012: {fails['012']}")
        print(f"[Stats] Events failed cond_3or4: {fails['34']}")
        print(f"[Stats] Events failed cond_5or6: {fails['56']}")
        print(f"[Stats] Events failed cond_calo: {fails['calo']}")
        print(f"[Stats] Events failed cond_truth: {fails['truth']}")

        # reshape keeps the (N, 2) layout even when no entry was selected.
        return np.array(X_data).reshape(-1, 2), np.array(Y_data)

    @staticmethod
    def _event_sample(data, i, fails):
        """Build the ``(feature, truth_pt)`` pair for entry ``i``.

        Args:
            data: Mapping from branch name to per-entry arrays, as returned by
                ``tree.arrays(..., library="np")``.
            i: Index of the entry to process.
            fails: ``Counter`` updated in place with the reason whenever the
                entry is rejected (keys ``"34"``, ``"56"``, ``"012"``,
                ``"calo"``, ``"truth"``).

        Returns:
            ``(feature, truth_pt)`` for an accepted entry, otherwise ``None``.
        """
        layers = data["trk_layer"][i]
        hits = list(zip(layers, data["trk_X"][i], data["trk_Y"][i]))

        # Exactly one hit is required in layers 3/4 and exactly one in 5/6.
        clu_34 = [(x, y) for layer, x, y in hits if layer in (3, 4)]
        clu_56 = [(x, y) for layer, x, y in hits if layer in (5, 6)]
        if len(clu_34) != 1 or len(clu_56) != 1:
            if len(clu_34) != 1:
                fails["34"] += 1
            if len(clu_56) != 1:
                fails["56"] += 1
            return None

        # Layers 0, 1 and 2 must each contain at least one hit.
        if not all(any(layer == wanted for layer in layers) for wanted in (0, 1, 2)):
            fails["012"] += 1
            return None

        # The first cluster must be the one and only cluster with system == 0.
        calo_system = data["caloClus_system"][i]
        if len(calo_system) == 0:
            fails["calo"] += 1
            return None
        if not (calo_system[0] == 0 and np.sum(calo_system == 0) == 1):
            fails["calo"] += 1
            return None

        # Exactly one "innr" cluster is required as well.
        if len(data["caloClus_innr_X"][i]) != 1:
            fails["calo"] += 1
            return None

        p34 = np.array(clu_34[0])
        p56 = np.array(clu_56[0])
        pcalo = np.array([data["caloClus_X"][i][0], data["caloClus_Y"][i][0]])

        seg_trk = p56 - p34
        seg_calo = pcalo - p56
        dphi = np.arctan2(seg_calo[1], seg_calo[0]) - np.arctan2(seg_trk[1], seg_trk[0])
        # Wrap the difference into [-pi, pi).
        dphi = (dphi + np.pi) % (2 * np.pi) - np.pi

        # Entries with dphi <= 0.01 rad are dropped; this includes every
        # negative dphi and keeps 1 / dphi bounded.
        if dphi <= 0.01:
            return None

        # Validate the truth value BEFORE emitting anything, so a rejected
        # entry contributes neither a feature nor a target.
        truth_pt = data["PrimaryG4P_Pt"][i]
        if len(truth_pt) != 1:
            fails["truth"] += 1
            return None

        return np.array([1 / dphi, 0]), truth_pt[0]