import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from data import TrackCaloDataset
from model import TrackCaloRegressor
import joblib
import os

def train(list_file, batch_size=1024, epochs=200, lr=5e-5, val_ratio=0.2,
          device="cuda" if torch.cuda.is_available() else "cpu"):
    print(f"Using device: {device}")

    dataset = TrackCaloDataset(list_file)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = TrackCaloRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float("inf")

    for epoch in range(epochs):
        # Training pass: regress the log of the target energy to tame its dynamic range.
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            yb_log = torch.log(yb + 1e-6)  # small offset avoids log(0)
            pred_log = model(xb)
            loss = criterion(pred_log, yb_log)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        # Validation pass with gradients disabled.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                yb_log = torch.log(yb + 1e-6)
                pred_log = model(xb)
                loss = criterion(pred_log, yb_log)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}")

        # Checkpoint the model whenever the validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pt")
            print(f"✓ Saved best model (val loss = {val_loss:.5f})")

    torch.save(model.state_dict(), "final_model.pt")
    print("✅ Training finished; saved best_model.pt and final_model.pt")

    # Persist the input scaler so inference can apply the same normalization.
    joblib.dump(train_set.dataset.scaler, "scaler.pkl")
    print("✅ Saved scaler.pkl (feature scaler)")

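# Illustrative only: a minimal inference sketch showing how the artifacts saved
# above (best_model.pt and scaler.pkl) could be loaded back. It assumes the
# scaler is a scikit-learn-style transformer exposing .transform(), and that the
# model predicts log-energy, mirroring the log target used during training.
# `predict_energy` and `features` are hypothetical names, not part of the
# original pipeline.
def predict_energy(features, device="cpu"):
    import numpy as np

    scaler = joblib.load("scaler.pkl")
    model = TrackCaloRegressor()
    model.load_state_dict(torch.load("best_model.pt", map_location=device))
    model.to(device)
    model.eval()

    # Apply the same feature scaling used at training time.
    x = torch.tensor(scaler.transform(np.asarray(features)), dtype=torch.float32).to(device)
    with torch.no_grad():
        pred_log = model(x)
    # Invert the log transform applied to the targets during training.
    return torch.exp(pred_log).cpu().numpy()
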

if __name__ == "__main__":
    list_file = "../train50k.list"
    if not os.path.exists(list_file):
        print(f"❌ {list_file} not found")
    else:
        train(list_file)