Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import torch
0002 import torch.nn as nn
0003 from torch.utils.data import DataLoader, random_split
0004 from data import TrackCaloDataset
0005 from model import TrackCaloRegressor
0006 import os
0007 import joblib
0008 
0009 def train(list_file, batch_size=1024, epochs=200, lr=5e-5, val_ratio=0.3, device="cuda" if torch.cuda.is_available() else "cpu"):
0010     print(f"Using device: {device}")
0011 
0012     dataset = TrackCaloDataset(list_file)
0013     train_size = int((1 - val_ratio) * len(dataset))
0014     val_size = len(dataset) - train_size
0015     train_set, val_set = random_split(dataset, [train_size, val_size])
0016 
0017     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
0018     val_loader = DataLoader(val_set, batch_size=batch_size)
0019 
0020     model = TrackCaloRegressor().to(device)
0021     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
0022     criterion = nn.MSELoss()
0023 
0024     best_val_loss = float("inf")
0025 
0026     for epoch in range(epochs):
0027         model.train()
0028         train_loss = 0
0029         for xb, yb in train_loader:
0030             xb, yb = xb.to(device), yb.to(device)
0031             pred = model(xb)
0032 
0033             # loss function
0034             # pt_reso = (pred - yb) / (yb + 1e-3)
0035             # pt_weights = torch.clamp(1.0 / (yb + 1e-3), max=10.0)
0036             # loss = (pt_reso**2 * pt_weights).mean()
0037 
0038             pt_reso = (yb - pred) / (yb + 1e-6)
0039             weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0  # 主峰权重变为 3.0,其它为 1.0
0040             loss = ((pt_reso) ** 2 * weights).mean()
0041 
0042             optimizer.zero_grad()
0043             loss.backward()
0044             optimizer.step()
0045             train_loss += loss.item() * xb.size(0)
0046         train_loss /= len(train_loader.dataset)
0047 
0048         model.eval()
0049         val_loss = 0
0050         with torch.no_grad():
0051             for xb, yb in val_loader:
0052                 xb, yb = xb.to(device), yb.to(device)
0053                 pred = model(xb)
0054 
0055                 # pt_reso = (pred - yb) / (yb + 1e-3)
0056                 # pt_weights = torch.clamp(1.0 / (yb + 1e-3), max=10.0)
0057                 # loss = (pt_reso**2 * pt_weights).mean()
0058 
0059                 pt_reso = (yb - pred) / (yb + 1e-6)
0060                 weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0  # 主峰权重变为 3.0,其它为 1.0
0061                 loss = ((pt_reso) ** 2 * weights).mean()
0062 
0063                 val_loss += loss.item() * xb.size(0)
0064         val_loss /= len(val_loader.dataset)
0065 
0066         print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
0067 
0068         if val_loss < best_val_loss:
0069             best_val_loss = val_loss
0070             torch.save(model.state_dict(), "best_model.pt")
0071             print(f"✓ Saved best model (val loss = {val_loss:.4f})")
0072 
0073     torch.save(model.state_dict(), "final_model.pt")
0074     print("✅ 训练完成。最优模型保存在 best_model.pt, and final_model.pt")
0075 
0076     # === 保存 scaler(标准化器) ===
0077     joblib.dump(train_set.dataset.scaler, "scaler.pkl")
0078     print("✅ 标准化器 scaler.pkl 已保存")
0079 
0080 if __name__ == "__main__":
0081     list_file = "../train500k.list"
0082     if not os.path.exists(list_file):
0083         print(f"❌ 找不到 {list_file}")
0084     else:
0085         train(list_file)