Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import torch
0002 import torch.nn as nn
0003 from torch.utils.data import DataLoader, random_split
0004 from data import TrackCaloDataset
0005 from model import TrackCaloRegressor
0006 import joblib
0007 import os
0008 
def _log_target(y, eps=1e-6):
    """Return the log-space regression target ``log(y + eps)``.

    ``eps`` keeps the logarithm finite when a target is exactly zero.
    NOTE(review): a negative target would still yield NaN; confirm that
    TrackCaloDataset only produces non-negative pt values.
    """
    return torch.log(y + eps)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode and gradients are
    disabled (a validation pass).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            target = _log_target(yb)  # log(pt) target
            # Force the prediction into the target's shape so MSELoss cannot
            # silently broadcast e.g. (B, 1) against (B,) into a (B, B) grid.
            pred = model(xb).reshape(target.shape)
            loss = criterion(pred, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so a short final batch counts correctly.
            total_loss += loss.item() * xb.size(0)
            n_samples += xb.size(0)
    if n_samples == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / n_samples


def train(list_file, batch_size=1024, epochs=200, lr=5e-5, val_ratio=0.2,
          device="cuda" if torch.cuda.is_available() else "cpu", seed=None):
    """Train a TrackCaloRegressor to regress log(pt) and save the results.

    The dataset built from ``list_file`` is randomly split into training and
    validation subsets. The model is trained with Adam on an MSE loss in
    log(pt) space. The weights with the lowest validation loss are written to
    ``best_model.pt``, the last-epoch weights to ``final_model.pt``, and the
    dataset's ``scaler`` attribute to ``scaler.pkl`` (all in the current
    working directory).

    Args:
        list_file: path passed straight to ``TrackCaloDataset``.
        batch_size: mini-batch size for both loaders.
        epochs: number of training epochs.
        lr: Adam learning rate.
        val_ratio: fraction of samples held out for validation; must leave
            both splits non-empty.
        device: torch device string. The default is chosen once, when the
            module is imported.
        seed: optional integer seed that makes the train/val split
            reproducible. ``None`` keeps torch's global RNG behaviour.

    Raises:
        ValueError: if ``val_ratio`` leaves the train or validation split
            empty (or negative), e.g. ``val_ratio=0`` or an empty dataset.
    """
    print(f"Using device: {device}")

    dataset = TrackCaloDataset(list_file)
    n_total = len(dataset)
    train_size = int((1 - val_ratio) * n_total)
    val_size = n_total - train_size
    # Fail before training rather than dividing by zero after the first epoch.
    if train_size <= 0 or val_size <= 0:
        raise ValueError(
            f"val_ratio={val_ratio} on {n_total} samples gives an empty split "
            f"(train={train_size}, val={val_size})"
        )
    if seed is None:
        train_set, val_set = random_split(dataset, [train_size, val_size])
    else:
        split_gen = torch.Generator().manual_seed(seed)
        train_set, val_set = random_split(
            dataset, [train_size, val_size], generator=split_gen
        )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = TrackCaloRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float("inf")
    for epoch in range(epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pt")
            print(f"✓ Saved best model (val loss = {val_loss:.5f})")

    torch.save(model.state_dict(), "final_model.pt")
    print("✅ 模型训练完成,保存 best_model.pt 和 final_model.pt")

    # Save the feature scaler so inference can apply the same normalization.
    # NOTE(review): if TrackCaloDataset fits this scaler on the full dataset,
    # it has also seen the validation samples -- confirm in data.py.
    joblib.dump(train_set.dataset.scaler, "scaler.pkl")
    print("✅ scaler.pkl 标准化器已保存")
0065 
if __name__ == "__main__":
    import sys

    # An optional first CLI argument overrides the default list file.
    # The default path is resolved relative to the current working directory.
    list_file = sys.argv[1] if len(sys.argv) > 1 else "../train50k.list"
    if not os.path.exists(list_file):
        # sys.exit(msg) prints msg to stderr and exits with status 1, so batch
        # jobs and shell scripts see the failure instead of a silent success.
        sys.exit(f"❌ 找不到 {list_file}")
    train(list_file)