Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 # === train.py ===
0002 import torch
0003 import torch.nn as nn
0004 from torch.utils.data import DataLoader, random_split
0005 from data import TrackCaloDataset
0006 from model import TrackCaloRegressor
0007 import os
0008 import joblib
0009 
def _weighted_resolution_loss(pred, target):
    """Weighted mean squared relative-pT residual.

    Residuals with |(pred - target) / target| < 0.2 are up-weighted
    (weight 3 vs. 1) so the optimizer concentrates on sharpening the
    resolution core rather than chasing tails.

    NOTE(review): divides by ``target``; entries with target == 0 yield
    inf/nan, and pt_min=0.0 admits such tracks — confirm upstream
    filtering in TrackCaloDataset.
    """
    pt_reso = (pred - target) / target
    weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
    return (pt_reso ** 2 * weights).mean()

def train(list_file, pt_min=0.0, pt_max=2.0, batch_size=1024, epochs=300, lr=5e-5, val_ratio=0.3, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Train a TrackCaloRegressor on tracks with pT in [pt_min, pt_max).

    Saves the best-validation-loss model weights and the dataset scaler
    under ``model_weight/``.

    Args:
        list_file: path to the file list consumed by TrackCaloDataset.
        pt_min, pt_max: half-open pT selection window in GeV.
        batch_size: mini-batch size for both train and validation loaders.
        epochs: number of training epochs.
        lr: Adam learning rate.
        val_ratio: fraction of the dataset held out for validation.
        device: torch device string; defaults to CUDA when available.
    """
    print(f"Training pt range: [{pt_min}, {pt_max}) GeV")
    print(f"Using device: {device}")

    dataset = TrackCaloDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = TrackCaloRegressor().to(device)
    # Small weight_decay adds L2 regularization on top of plain Adam.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    # Fix: torch.save/joblib.dump below raise FileNotFoundError when the
    # output directory is missing; create it up front.
    os.makedirs("model_weight", exist_ok=True)
    tag = f"pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE"

    best_val_loss = float("inf")

    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _weighted_resolution_loss(model(xb), yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate per-sample so the epoch average is exact even
            # when the last batch is smaller.
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        # --- validation pass ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                loss = _weighted_resolution_loss(model(xb), yb)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpoint only on improvement (early-stopping-style best model).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"model_weight/best_model_{tag}.pt")
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    print("✅ 训练完成。最优模型保存在 best_model_*.pt")

    # random_split returns Subsets; the fitted scaler lives on the
    # underlying dataset.
    joblib.dump(train_set.dataset.scaler, f"model_weight/scaler_{tag}.pkl")
    print("✅ 标准化器已保存")
if __name__ == "__main__":
    # Train a single model over the full window; for per-bin models use e.g.
    # [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)] instead.
    input_list = "../train500k.list"
    for lower, upper in [(0, 10)]:
        train(input_list, pt_min=lower, pt_max=upper)