Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import torch
0002 import torch.nn as nn
0003 from torch.utils.data import DataLoader, random_split
0004 from data_pointnet import TrackPointNetDataset
0005 from model_pointnet import PointNetRegressor
0006 import os
0007 
def _weighted_reso_loss(pred, target):
    """Return the core-weighted mean squared relative residual for one batch.

    The per-element relative residual is ``(target - pred) / target``. Elements whose
    |residual| is below 0.2 get weight 3.0, all others weight 1.0, and the result is
    the mean of ``residual**2 * weight``.

    NOTE(review): the division by ``target`` assumes no target value is exactly 0;
    with ``pt_min=0.0`` that depends on the dataset's own filtering — TODO confirm.
    """
    if pred.shape != target.shape and pred.numel() == target.numel():
        # e.g. a (B, 1) prediction against a (B,) target would otherwise broadcast to
        # (B, B) and silently compare every prediction with every target.
        pred = pred.reshape(target.shape)
    reso = (target - pred) / target
    # Comparison results carry no gradient, so the weights only rescale the loss of
    # elements already inside the |reso| < 0.2 core; they are not optimised directly.
    weights = (reso.abs() < 0.2).float() * 2.0 + 1.0
    return (reso ** 2 * weights).mean()


def train(list_file, pt_min=0.0, pt_max=2.0, batch_size=512, epochs=100, lr=5e-5, val_ratio=0.3,
          device="cuda" if torch.cuda.is_available() else "cpu", seed=None):
    """Train a PointNetRegressor on one pt bin and checkpoint the best epoch.

    The dataset built from ``list_file`` (with ``pt_min``/``pt_max`` forwarded to
    ``TrackPointNetDataset``) is split at random into training and validation subsets.
    After each epoch the sample-weighted validation loss is computed, and whenever it
    improves, the model's ``state_dict`` is written to
    ``model_weight/best_pointnet_pt_{pt_min:.1f}_{pt_max:.1f}.pt``.

    Args:
        list_file: Input list passed to ``TrackPointNetDataset``.
        pt_min, pt_max: Bin edges (GeV, per the log message) forwarded to the dataset
            and used in the checkpoint file name.
        batch_size: Mini-batch size for both data loaders.
        epochs: Number of training epochs.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
        val_ratio: Fraction of the dataset held out for validation.
        device: Torch device string. The default is evaluated once, when this
            function is defined.
        seed: Optional integer seed for the train/validation split. ``None`` (the
            default) keeps the unseeded, non-reproducible split.

    Returns:
        The lowest validation loss observed.

    Raises:
        ValueError: If ``val_ratio`` leaves the training or validation split empty.
    """
    print(f"Training pt range: [{pt_min}, {pt_max}) GeV")
    print(f"Using device: {device}")

    dataset = TrackPointNetDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    if train_size <= 0 or val_size <= 0:
        # Fail fast: an empty split would otherwise end in a ZeroDivisionError in the
        # epoch-loss averaging (for an empty validation split, after a whole epoch).
        raise ValueError(
            f"val_ratio={val_ratio} on {len(dataset)} samples gives "
            f"train/val sizes {train_size}/{val_size}; both must be positive"
        )

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = PointNetRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    best_val_loss = float("inf")
    ckpt_path = f"model_weight/best_pointnet_pt_{pt_min:.1f}_{pt_max:.1f}.pt"

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _weighted_reso_loss(model(xb), yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Scale each batch mean by its size so a short final batch is not over-weighted.
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                val_loss += _weighted_reso_loss(model(xb), yb).item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs("model_weight", exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    # Message reads: "Training complete. The best model is saved in the model_weight/ folder."
    print("✅ 训练完成。最优模型保存在 model_weight/ 文件夹。")
    return best_val_loss
0062 
if __name__ == "__main__":
    # Script entry point: train one independent model for each (pt_min, pt_max) bin
    # below, all reading the same input list.
    training_list = "../train50k.list"
    for bin_lo, bin_hi in [(0, 10)]:
        train(training_list, pt_min=bin_lo, pt_max=bin_hi)