File indexing completed on 2025-12-16 09:18:06
0001
0002 import torch
0003 import torch.nn as nn
0004 from torch.utils.data import DataLoader, random_split
0005 from data import TrackCaloDataset
0006 from model import TrackCaloRegressor
0007 import os
0008 import joblib
0009
def _weighted_relative_loss(pred, target):
    """Return the weighted mean of the squared relative residual.

    The relative residual is ``(pred - target) / target``. Elements whose
    absolute residual is below 0.2 are weighted 3.0, all others 1.0, and the
    weighted squared residuals are averaged over every element.
    """
    # NOTE(review): a target of exactly 0 makes this residual inf/nan; the
    # default pt_min of 0.0 does not obviously exclude it -- confirm that the
    # dataset never yields pt == 0.
    # NOTE(review): assumes pred and target have the same shape; a (B, 1) vs
    # (B,) pair would broadcast to (B, B) without raising -- confirm against
    # the model and dataset outputs.
    pt_reso = (pred - target) / target
    weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
    return (pt_reso ** 2 * weights).mean()


def train(
    list_file,
    pt_min=0.0,
    pt_max=2.0,
    batch_size=1024,
    epochs=300,
    lr=5e-5,
    val_ratio=0.3,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Train a TrackCaloRegressor on one pt bin and keep the best checkpoint.

    The dataset is split randomly into training and validation subsets. After
    every epoch the validation loss is compared with the best value seen so
    far; on improvement the model's state dict is written to
    ``model_weight/``. Once all epochs are done the dataset's ``scaler``
    attribute is dumped with joblib into the same directory.

    Args:
        list_file: Forwarded as the first argument to ``TrackCaloDataset``.
        pt_min: Lower pt bound; forwarded to the dataset and used in the
            output file names.
        pt_max: Upper pt bound; forwarded to the dataset and used in the
            output file names.
        batch_size: Batch size for both data loaders.
        epochs: Number of passes over the training subset.
        lr: Learning rate for the Adam optimizer.
        val_ratio: Fraction of the dataset held out for validation.
        device: Torch device that the model and batches are moved to.

    Raises:
        ValueError: If the split would leave the training or validation
            subset empty.
    """
    print(f"Training pt range: [{pt_min}, {pt_max}) GeV")
    print(f"Using device: {device}")

    dataset = TrackCaloDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    if train_size <= 0 or val_size <= 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the per-sample average is computed at the end of the first epoch.
        raise ValueError(
            f"Cannot split {len(dataset)} samples with val_ratio={val_ratio}: "
            f"train_size={train_size}, val_size={val_size}"
        )
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model_path = f"model_weight/best_model_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pt"
    scaler_path = f"model_weight/scaler_pt_{pt_min:.1f}_{pt_max:.1f}_INTT_CaloIwoE.pkl"
    # torch.save / joblib.dump do not create missing parent directories, so
    # make sure the output directory exists before spending time on training.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    model = TrackCaloRegressor().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_val_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _weighted_relative_loss(model(xb), yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch mean by its size so the epoch value is a true
            # per-sample average even when the last batch is smaller.
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                loss = _weighted_relative_loss(model(xb), yb)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_path)
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    print("✅ 训练完成。最优模型保存在 best_model_*.pt")

    # Persist the dataset's scaler next to the model weights. random_split
    # subsets reference the original dataset, so this is the same object as
    # train_set.dataset.scaler.
    joblib.dump(dataset.scaler, scaler_path)
    print("✅ 标准化器已保存")
0071
if __name__ == "__main__":
    # Script entry point: run one training job per (pt_min, pt_max) bin.
    training_list = "../train500k.list"

    # Currently a single bin spanning 0 to 10; add tuples to train more bins.
    for low_edge, high_edge in ((0, 10),):
        train(training_list, pt_min=low_edge, pt_max=high_edge)