Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:05

0001 ### train_combined.py
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import ROOT

from data_combined import FusionDataset
from model_combined import FusionRegressor
0009 
0010 
0011 def train_fusion(
0012     list_file,
0013     pt_min=0.0,
0014     pt_max=10.0,
0015     batch_size=512,
0016     epochs=1,
0017     lr=5e-5,
0018     weight_decay=1e-5,
0019     hidden_dim=32,
0020     device="cuda" if torch.cuda.is_available() else "cpu"
0021 ):
0022     print(f"=== Training FusionRegressor [{pt_min},{pt_max}) GeV on {device} ===")
0023 
0024     # —— 准备数据集并划分 —— #
0025     dataset     = FusionDataset(list_file, pt_min=pt_min, pt_max=pt_max)
0026     train_size  = int(0.7 * len(dataset))
0027     val_size    = len(dataset) - train_size
0028     train_set, val_set = random_split(dataset, [train_size, val_size])
0029 
0030     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
0031     val_loader   = DataLoader(val_set,   batch_size=batch_size)
0032 
0033     # —— 模型与优化器 —— #
0034     model     = FusionRegressor(input_dim=2, hidden_dim=hidden_dim).to(device)
0035     optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
0036     criterion = nn.MSELoss(reduction="none")
0037 
0038     best_val = float('inf')
0039     for epoch in range(1, epochs+1):
0040         model.train()
0041         train_loss = 0
0042         for xb, yb in train_loader:
0043             xb, yb = xb.to(device), yb.to(device)
0044             pred = model(xb)
0045 
0046             pt_reso = (pred - yb) / (yb)
0047             weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
0048             loss = ((pt_reso) ** 2 * weights).mean()
0049             # loss = (pt_reso ** 2).mean()
0050 
0051             optimizer.zero_grad()
0052             loss.backward()
0053             optimizer.step()
0054             train_loss += loss.item() * xb.size(0)
0055         train_loss /= len(train_loader.dataset)
0056 
0057         model.eval()
0058         val_loss = 0.0
0059         with torch.no_grad():
0060             for xb, yb in val_loader:
0061                 xb, yb = xb.to(device), yb.to(device)
0062                 pred = model(xb)
0063                 pt_reso = (pred - yb) / (yb)
0064                 weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
0065                 loss = ((pt_reso) ** 2 * weights).mean()
0066 
0067                 val_loss += loss.item() * xb.size(0)
0068 
0069         val_loss /= len(val_loader.dataset)
0070         
0071         print(f"Epoch {epoch:03d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")
0072 
0073         if val_loss < best_val:
0074             best_val = val_loss
0075             torch.save(model.state_dict(), "model_weight/best_model_combined.pt")
0076             print(f"✓ Saved best model (Val {val_loss:.4f})")
0077 
0078     print("✅ Training complete.")
0079 
0080     # —— 在训练结束后,用 dataset.X (pred1,pred2) 和 dataset.Y (truth) 填充两个 TH2D —— #
0081     # 创建直方图
0082     nbins_pt = int((pt_max - pt_min) * 10)
0083     h2_dphi = ROOT.TH2D(
0084         "h2_dphi_diff",
0085         "dphi-model: truth pt vs (truth-pred);truth pt [GeV];truth-pred [GeV]",
0086         nbins_pt, pt_min, pt_max, 200, -2.0, 2.0
0087     )
0088     h2_energy = ROOT.TH2D(
0089         "h2_energy_diff",
0090         "energy-model: truth pt vs (truth-pred);truth pt [GeV];truth-pred [GeV]",
0091         nbins_pt, pt_min, pt_max, 200, -2.0, 2.0
0092     )
0093 
0094     # 从 FusionDataset 中获取预测和真实值填图,并打印前几条用于调试
0095     for i in range(len(dataset.X)):
0096         truth = dataset.Y[i].item()
0097         pred1 = dataset.X[i, 0].item()
0098         pred2 = dataset.X[i, 1].item()
0099         # 打印前10条样本的预测和真实值
0100         if i < 2:
0101             print(f"[Sample {i}] pred1={pred1:.4f}, pred2={pred2:.4f}, truth={truth:.4f}")
0102         # 填充相对误差直方图
0103         h2_dphi.Fill(truth, (pred1 - truth) / truth)
0104         h2_energy.Fill(truth, (pred2 - truth) / truth)
0105 
0106     # 写入 ROOT 文件
0107     fout = ROOT.TFile("outputFile/submodel_performance.root", "RECREATE")
0108     h2_dphi.Write()
0109     h2_energy.Write()
0110     fout.Close()
0111     print("✅ Sub-model histograms saved to submodel_performance.root")
0112 
if __name__ == "__main__":
    # Production run: 500k-event list, full pt window, 100 epochs.
    hyperparams = {
        "pt_min": 0.0,
        "pt_max": 10.0,
        "batch_size": 512,
        "epochs": 100,
        "lr": 5e-5,
        "weight_decay": 1e-5,
        "hidden_dim": 256,
    }
    train_fusion("../train500k.list", **hyperparams)