File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import torch.nn as nn
0003 from torch.utils.data import DataLoader, random_split
0004 from data import TrackCaloDataset
0005 from model import TrackCaloRegressor
0006 import os
0007 import joblib
0008
def _weighted_relative_loss(pred, target, eps=1e-6, window=0.2, boost=2.0):
    """Return the weighted mean of the squared relative residual.

    The relative residual is ``(target - pred) / (target + eps)``. Elements
    whose absolute residual is below ``window`` receive weight ``1 + boost``;
    every other element receives weight 1.
    """
    # NOTE(review): assumes pred and target have identical shapes; a
    # (N,) vs (N, 1) mismatch would broadcast silently -- confirm against
    # the model and dataset.
    residual = (target - pred) / (target + eps)
    # The comparison yields a boolean mask, so the weights carry no gradient.
    weights = (residual.abs() < window).float() * boost + 1.0
    return (residual ** 2 * weights).mean()


def _run_epoch(model, loader, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one the model is put in eval mode and gradients are
    disabled for the whole pass.
    """
    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _weighted_relative_loss(model(xb), yb)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Scale by batch size so a short final batch is not over-counted
            # when dividing by the dataset length below.
            total += loss.item() * xb.size(0)
    return total / len(loader.dataset)


def train(list_file, batch_size=1024, epochs=200, lr=5e-5, val_ratio=0.3, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Train a ``TrackCaloRegressor`` on the dataset built from ``list_file``.

    The dataset is randomly split into training and validation subsets, the
    model is optimised with Adam on a weighted relative-residual loss, and
    the weights with the lowest validation loss seen so far are checkpointed
    after each epoch.

    Args:
        list_file: Argument forwarded to ``TrackCaloDataset``.
        batch_size: Batch size for both data loaders.
        epochs: Number of passes over the training subset.
        lr: Learning rate handed to ``torch.optim.Adam``.
        val_ratio: Fraction of the dataset held out for validation; must be
            strictly between 0 and 1.
        device: Torch device string. The default is chosen once, when the
            function is defined.

    Raises:
        ValueError: If ``val_ratio`` is outside ``(0, 1)`` or the dataset is
            too small to give both subsets at least one sample.

    Side effects:
        Writes ``best_model.pt``, ``final_model.pt`` and ``scaler.pkl`` to the
        current working directory and prints progress to stdout.
    """
    # Fail fast: an empty validation split would otherwise surface only after
    # a full training epoch, as a ZeroDivisionError in the loss average.
    if not 0 < val_ratio < 1:
        raise ValueError(
            f"val_ratio must be strictly between 0 and 1, got {val_ratio!r}"
        )

    print(f"Using device: {device}")

    dataset = TrackCaloDataset(list_file)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"dataset of {len(dataset)} samples cannot be split into non-empty "
            f"train/validation subsets with val_ratio={val_ratio!r}"
        )
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = TrackCaloRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float("inf")

    for epoch in range(epochs):
        train_loss = _run_epoch(model, train_loader, device, optimizer)
        val_loss = _run_epoch(model, val_loader, device)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpoint only on a strict improvement of the validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pt")
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    torch.save(model.state_dict(), "final_model.pt")
    print("✅ 训练完成。最优模型保存在 best_model.pt, and final_model.pt")

    # random_split returns Subset objects whose ``.dataset`` is the original
    # dataset, so this is the same scaler object the loaders drew from.
    joblib.dump(train_set.dataset.scaler, "scaler.pkl")
    print("✅ 标准化器 scaler.pkl 已保存")
0079
if __name__ == "__main__":
    # Script entry point: train on the default file list when it is present,
    # otherwise print a "not found" message and do nothing else.
    list_file = "../train500k.list"
    if os.path.exists(list_file):
        train(list_file)
    else:
        print(f"❌ 找不到 {list_file}")