import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from data import TrackCaloDataset
from model import TrackCaloRegressor
import os
import joblib
import ROOT
import matplotlib.pyplot as plt

def monotonic_loss(y_sorted):
    # Penalize any decrease in a sequence that should be non-decreasing:
    # relu(-dy) is positive only where y_sorted drops between neighbours.
    dy = y_sorted[1:] - y_sorted[:-1]
    # penalty = torch.relu(dy).sum()
    penalty = torch.relu(-dy).sum()
    return penalty

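# A minimal illustration of monotonic_loss (toy values, not from the training data):
# only a drop between neighbouring entries contributes to the penalty.
#
#   monotonic_loss(torch.tensor([[0.1], [0.3], [0.2]]))  # -> tensor(0.1000), the 0.3 -> 0.2 drop
#   monotonic_loss(torch.tensor([[0.1], [0.2], [0.4]]))  # -> tensor(0.),     already monotonic
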
def train(list_file, pt_min=0.0, pt_max=2.0, batch_size=1024, epochs=500, lr=5e-5, val_ratio=0.2, device="cuda" if torch.cuda.is_available() else "cpu"):
    print(f"Training pt range: [{pt_min}, {pt_max}) GeV")
    print(f"Using device: {device}")

    # make sure the output directories exist before anything is written into them
    os.makedirs("outputFile", exist_ok=True)
    os.makedirs("model_weight", exist_ok=True)

    out_file = ROOT.TFile("outputFile/proxy_vs_pt.root", "RECREATE")
    hist2d = ROOT.TH2D("h2_proxy_vs_pt", "proxy vs pt;proxy;pt", 510, 2, 200, 100, 0, 10)
    hist_prof = ROOT.TProfile("h_profile_proxy_vs_pt", "proxy vs pt;proxy;pt", 510, -0.1, 0.5)
    h2_p34 = ROOT.TH2D("h2_p34", "XY of p34;X (cm);Y (cm)", 300, -30, 30, 300, -30, 30)
    h2_p56 = ROOT.TH2D("h2_p56", "XY of p56;X (cm);Y (cm)", 300, -30, 30, 300, -30, 30)
    h2_calo = ROOT.TH2D("h2_calo", "XY of calo;X (cm);Y (cm)", 600, -150, 150, 600, -150, 150)
    h2_SiCalo = ROOT.TH2D("h2_SiCalo", "XY of SiCalo;X (cm);Y (cm)", 600, -150, 150, 600, -150, 150)

    dataset = TrackCaloDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = TrackCaloRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.MSELoss()  # note: the batch loss below is built by hand, so this criterion is never used

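    # Assumed tensor shapes, inferred from how xb/yb are used below (data.py is
    # not shown here): xb is (batch, n_features) with the proxy variable in
    # column 0, and yb is (batch, 1) holding the track pt in GeV.
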
    best_val_loss = float("inf")

    lambda_mono = 0.3
    # lambda_boundary = 0.1

    max_draw_event = 10
    drawn_event_count = 0

    for epoch in range(epochs):
        # === train ===
        model.train()
        train_loss = 0

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)

            # relative pt residual; samples already within 20% get triple weight
            pt_reso = (pred - yb) / (yb)
            weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
            main_loss = ((pt_reso) ** 2 * weights).mean()

            # === boundary loss ===
            # Anchor points pinning the response at fixed proxy values; the targets
            # lie on the line pt = 0.1922 * proxy.  These constant tensors could be
            # built once outside the batch loop, since they never change.
            x1 = np.array([0, 0.5, 1, 2, 10, 15, 25, 50, 100, 200])
            x2 = np.zeros_like(x1)
            # x_boundary_np = np.stack([x1], axis=1)
            x_boundary_np = np.stack([x1, x2], axis=1)
            x_boundary = torch.tensor(x_boundary_np, dtype=torch.float32).to(device)
            y_boundary_target = torch.tensor([0, 0.0961, 0.1922, 0.3844, 1.922, 2.883, 4.805, 9.61, 19.22, 38.44], dtype=torch.float32).unsqueeze(1).to(device)

            y_boundary_pred = model(x_boundary)
            boundary_loss = nn.MSELoss()(y_boundary_pred, y_boundary_target)

            lambda_boundary = min(0.005 * epoch, 0.2)

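            # Sketch of the schedule above (direct evaluation of the formula):
            #   epoch 0  -> lambda_boundary = 0.0
            #   epoch 20 -> 0.1
            #   epoch 40 and beyond -> capped at 0.2
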
            # === monotonic penalty ===
            # sort the batch by the proxy (column 0) and penalize non-monotonic predictions
            _, indices = torch.sort(xb[:, 0])
            pred_sorted = pred[indices]
            mono_penalty = monotonic_loss(pred_sorted)

            # === total loss ===
            loss = main_loss + lambda_mono * mono_penalty + lambda_boundary * boundary_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * xb.size(0)


            # === fill ROOT hist + event display plots (disabled) ===
            # for i in range(xb.shape[0]):
            #     proxy = xb[i][0].item()
            #     pt = yb[i][0].item()
            #     dphi_recovered = proxy  # no feature transform applied here

            #     hist2d.Fill(proxy, pt)
            #     hist_prof.Fill(dphi_recovered, pt)

            #     x34, y34 = xb[i][0].item(), xb[i][1].item()
            #     x56, y56 = xb[i][2].item(), xb[i][3].item()
            #     xcalo, ycalo = xb[i][4].item(), xb[i][5].item()
            #     h2_p34.Fill(x34, y34)
            #     h2_p56.Fill(x56, y56)
            #     h2_calo.Fill(xcalo, ycalo)

            #     if drawn_event_count < max_draw_event:
            #         plt.figure(figsize=(6, 6))
            #         plt.scatter([x34], [y34], c='red', label='p34', s=10)
            #         plt.scatter([x56], [y56], c='green', label='p56', s=10)
            #         plt.scatter([xcalo], [ycalo], c='blue', label='calo', s=10)

            #         dx = x56 - x34
            #         dy = y56 - y34
            #         if dx != 0:
            #             k = dy / dx
            #             b = y34 - k * x34
            #             x_left = -120
            #             x_right = 120
            #             y_left = k * x_left + b
            #             y_right = k * x_right + b
            #             plt.plot([x_left, x_right], [y_left, y_right], color='black', linestyle='-', linewidth=1, label='p34->p56 extended')
            #         else:
            #             plt.plot([x34, x34], [-120, 120], color='black', linestyle='-', linewidth=1, label='p34->p56 extended')

            #         plt.plot([x56, xcalo], [y56, ycalo], color='purple', linestyle='--', linewidth=1)

            #         plt.xlim(-120, 120)
            #         plt.ylim(-120, 120)
            #         plt.xlabel("X (cm)")
            #         plt.ylabel("Y (cm)")
            #         plt.title(f"Train Event {drawn_event_count+1}, Pt = {pt:.2f} GeV")
            #         plt.legend()
            #         plt.grid(True)
            #         plt.savefig(f"event_images/train_event_{drawn_event_count+1:03d}.png")
            #         plt.close()

            #         drawn_event_count += 1

        train_loss /= len(train_loader.dataset)

        # === val ===
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                pt_reso = (pred - yb) / (yb)
                weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
                main_loss = ((pt_reso) ** 2 * weights).mean()

                # same boundary anchors and composite loss as in training, so the
                # validation number is directly comparable
                x1 = np.array([0, 0.5, 1, 2, 10, 15, 25, 50, 100, 200])
                x2 = np.zeros_like(x1)
                # x_boundary_np = np.stack([x1], axis=1)
                x_boundary_np = np.stack([x1, x2], axis=1)
                x_boundary = torch.tensor(x_boundary_np, dtype=torch.float32).to(device)
                y_boundary_target = torch.tensor([0, 0.0961, 0.1922, 0.3844, 1.922, 2.883, 4.805, 9.61, 19.22, 38.44], dtype=torch.float32).unsqueeze(1).to(device)

                y_boundary_pred = model(x_boundary)
                boundary_loss = nn.MSELoss()(y_boundary_pred, y_boundary_target)

                lambda_boundary = min(0.005 * epoch, 0.2)

                _, indices = torch.sort(xb[:, 0])
                pred_sorted = pred[indices]
                mono_penalty = monotonic_loss(pred_sorted)

                loss = main_loss + lambda_mono * mono_penalty + lambda_boundary * boundary_loss

                val_loss += loss.item() * xb.size(0)

        val_loss /= len(val_loader.dataset)


        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # only start checkpointing after epoch 200, keeping the best validation loss
        if epoch >= 200 and val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt")
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    print("✅ Training finished. The best model is saved as best_model_*.pt")

    out_file.cd()
    # the fill calls above are commented out, so these histograms are written empty
    hist2d.Write()
    hist_prof.Write()
    h2_p34.Write()
    h2_p56.Write()
    h2_calo.Write()
    h2_SiCalo.Write()
    out_file.Close()

if __name__ == "__main__":
    list_file = "../train500k.list"
    pt_bins = [(0, 10)]
    for pt_min, pt_max in pt_bins:
        train(list_file, pt_min=pt_min, pt_max=pt_max)
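
# A minimal inference sketch (not part of the original script): it assumes
# TrackCaloRegressor can be constructed with no arguments, as above, and that the
# checkpoint path matches the pt bin used during training; `features` is a
# hypothetical input tensor.
#
#   import torch
#   from model import TrackCaloRegressor
#
#   model = TrackCaloRegressor()
#   model.load_state_dict(torch.load("model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt", map_location="cpu"))
#   model.eval()
#   with torch.no_grad():
#       pred_pt = model(features)  # features: (N, n_features) tensor, proxy in column 0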