File indexing completed on 2025-12-16 09:18:06
0001 import numpy as np
0002 import torch
0003 import torch.nn as nn
0004 from torch.utils.data import DataLoader, random_split
0005 from data import TrackCaloDataset
0006 from model import TrackCaloRegressor
0007 import os
0008 import joblib
0009 import ROOT
0010 import matplotlib.pyplot as plt
0011
def monotonic_loss(y_sorted):
    """Penalty for non-monotonicity of *y_sorted*.

    For a tensor that should be non-decreasing, sums the magnitude of every
    downward step between consecutive elements; zero iff the sequence is
    already monotonically non-decreasing.
    """
    # Each positive entry of (y[i] - y[i+1]) is a violation of monotonicity.
    drops = torch.clamp(y_sorted[:-1] - y_sorted[1:], min=0)
    return drops.sum()
0017
def _boundary_anchors(device):
    """Build the fixed proxy -> pt calibration anchors once.

    Returns (x_boundary, y_boundary_target): ten (proxy, 0) feature rows and
    their target pt values. These are constants, so they are constructed a
    single time instead of inside every batch iteration.
    """
    proxies = np.array([0, 0.5, 1, 2, 10, 15, 25, 50, 100, 200])
    features = np.stack([proxies, np.zeros_like(proxies)], axis=1)
    x_boundary = torch.tensor(features, dtype=torch.float32).to(device)
    targets = [0, 0.0961, 0.1922, 0.3844, 1.922, 2.883, 4.805, 9.61, 19.22, 38.44]
    y_boundary_target = torch.tensor(targets, dtype=torch.float32).unsqueeze(1).to(device)
    return x_boundary, y_boundary_target


def _batch_loss(model, xb, yb, lambda_mono, lambda_boundary, x_boundary, y_boundary_target):
    """Combined loss for one batch (shared by the train and validation loops).

    loss = weighted relative-residual MSE
           + lambda_mono * monotonicity penalty
           + lambda_boundary * anchor-point MSE
    """
    pred = model(xb)
    # Relative pt residual. NOTE(review): assumes yb is never zero — confirm
    # the dataset guarantees pt > pt_min >= 0 strictly.
    pt_reso = (pred - yb) / yb
    # Samples already within 20% relative error get 3x weight (2.0 + 1.0),
    # tightening the core of the resolution distribution.
    weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
    main_loss = (pt_reso ** 2 * weights).mean()
    boundary_loss = nn.MSELoss()(model(x_boundary), y_boundary_target)
    # Penalize any decrease of the prediction when ordered by the first
    # input feature (the proxy).
    _, order = torch.sort(xb[:, 0])
    mono_penalty = monotonic_loss(pred[order])
    return main_loss + lambda_mono * mono_penalty + lambda_boundary * boundary_loss


def train(list_file, pt_min=0.0, pt_max=2.0, batch_size=1024, epochs=500, lr=5e-5, val_ratio=0.2, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Train the proxy -> pt regressor on tracks in [pt_min, pt_max).

    Parameters
    ----------
    list_file : path of the file list consumed by TrackCaloDataset.
    pt_min, pt_max : pt selection window in GeV (max is exclusive).
    batch_size, epochs, lr, val_ratio : usual training hyperparameters.
    device : torch device string; defaults to CUDA when available.

    Side effects: writes diagnostic histograms to
    outputFile/proxy_vs_pt.root and checkpoints the best model (by
    validation loss, after a 200-epoch burn-in) under model_weight/.
    """
    print(f"Training pt range: [{pt_min}, {pt_max}) GeV")
    print(f"Using device: {device}")

    # NOTE(review): these histograms are created and written but never filled
    # anywhere in this function — dead diagnostics, kept for output-file
    # compatibility.
    out_file = ROOT.TFile("outputFile/proxy_vs_pt.root", "RECREATE")
    hist2d = ROOT.TH2D("h2_proxy_vs_pt", "proxy vs pt;proxy;pt", 510, 2, 200, 100, 0, 10)
    hist_prof = ROOT.TProfile("h_profile_proxy_vs_pt", "proxy vs pt;proxy;pt", 510, -0.1, 0.5)
    h2_p34 = ROOT.TH2D("h2_p34", "XY of p34;X (cm);Y (cm)", 300, -30, 30, 300, -30, 30)
    h2_p56 = ROOT.TH2D("h2_p56", "XY of p56;X (cm);Y (cm)", 300, -30, 30, 300, -30, 30)
    h2_calo = ROOT.TH2D("h2_calo", "XY of calo;X (cm);Y (cm)", 600, -150, 150, 600, -150, 150)
    h2_SiCalo = ROOT.TH2D("h2_SiCalo", "XY of calo;X (cm);Y (cm)", 600, -150, 150, 600, -150, 150)

    dataset = TrackCaloDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = TrackCaloRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_val_loss = float("inf")
    lambda_mono = 0.3

    # Hoisted: the anchor tensors are loop-invariant, so build them once
    # instead of re-creating numpy arrays + device tensors for every batch.
    x_boundary, y_boundary_target = _boundary_anchors(device)

    for epoch in range(epochs):
        # Ramp the boundary-anchor weight linearly from 0 to 0.2 over the
        # first 40 epochs; depends only on epoch, so computed once per epoch.
        lambda_boundary = min(0.005 * epoch, 0.2)

        model.train()
        train_loss = 0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _batch_loss(model, xb, yb, lambda_mono, lambda_boundary,
                               x_boundary, y_boundary_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                loss = _batch_loss(model, xb, yb, lambda_mono, lambda_boundary,
                                   x_boundary, y_boundary_target)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpoint only after a 200-epoch burn-in, on validation improvement.
        if epoch >= 200 and val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt")
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    print("✅ 训练完成。最优模型保存在 best_model_*.pt")

    out_file.cd()
    hist2d.Write()
    hist_prof.Write()
    h2_p34.Write()
    h2_p56.Write()
    h2_calo.Write()
    h2_SiCalo.Write()
    out_file.Close()
0190
if __name__ == "__main__":
    # Single inclusive pt window; extend this list to train per-bin models.
    list_file = "../train500k.list"
    for pt_min, pt_max in [(0, 10)]:
        train(list_file, pt_min=pt_min, pt_max=pt_max)