File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import torch.nn as nn
0003 from torch.utils.data import DataLoader, random_split
0004 from data_pointnet import TrackPointNetDataset
0005 from model_pointnet import PointNetRegressor
0006 import os
0007
def _weighted_reso_loss(pred, target):
    """Return the weighted mean squared relative residual for one batch.

    The residual is ``(target - pred) / target``. Elements whose absolute
    residual is below 0.2 get weight 3.0 and all others get weight 1.0. The
    weighted squared residuals are then averaged over every element.

    NOTE(review): ``target`` is used as a divisor directly, so a zero target
    yields inf/NaN. ``pred`` and ``target`` are assumed to share a shape; for
    example, a ``(B, 1)`` prediction against a ``(B,)`` target would silently
    broadcast to ``(B, B)``. Confirm both against the model and the dataset.
    """
    pt_reso = (target - pred) / target
    # Presumably meant to emphasise the resolution core (|reso| < 20%) over
    # outliers; the comparison itself carries no gradient.
    weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
    return (pt_reso ** 2 * weights).mean()


def train(list_file, pt_min=0.0, pt_max=2.0, batch_size=512, epochs=100, lr=5e-5, val_ratio=0.3,
          device="cuda" if torch.cuda.is_available() else "cpu", seed=None):
    """Train a ``PointNetRegressor`` on tracks in ``[pt_min, pt_max)`` and keep the best checkpoint.

    The dataset is randomly split into training and validation parts. After
    every epoch, the model state is written to
    ``model_weight/best_pointnet_pt_{pt_min:.1f}_{pt_max:.1f}.pt`` whenever the
    validation loss improves.

    Args:
        list_file: Input list passed straight to ``TrackPointNetDataset``.
        pt_min: Lower edge of the pt window (GeV) forwarded to the dataset.
        pt_max: Upper edge of the pt window (GeV, exclusive per the log line).
        batch_size: Mini-batch size used for both data loaders.
        epochs: Number of passes over the training split.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
        val_ratio: Fraction of the dataset held out for validation.
        device: Torch device string; defaults to CUDA when available.
        seed: Optional integer seed for the train/validation split. ``None``
            (the default) keeps the split unseeded, as before.

    Returns:
        The lowest sample-weighted validation loss seen. Returns ``inf`` if no
        epoch produced a finite validation loss, in which case no checkpoint
        was written.

    Raises:
        ValueError: If the split would leave the training or the validation
            set empty.
    """
    print(f"Training pt range: [{pt_min}, {pt_max}) GeV")
    print(f"Using device: {device}")

    dataset = TrackPointNetDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    # Fail fast: an empty (or negative-sized) split would otherwise only show
    # up as a ZeroDivisionError when the epoch loss is averaged.
    if train_size <= 0 or val_size <= 0:
        raise ValueError(
            f"Cannot split {len(dataset)} samples with val_ratio={val_ratio}: "
            f"train={train_size}, val={val_size}; both splits must be non-empty."
        )
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    model = PointNetRegressor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    best_val_loss = float("inf")
    ckpt_path = f"model_weight/best_pointnet_pt_{pt_min:.1f}_{pt_max:.1f}.pt"

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _weighted_reso_loss(model(xb), yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch mean by its size, so a short final batch does
            # not skew the epoch average.
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                loss = _weighted_reso_loss(model(xb), yb)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # NaN never compares less than anything, so a diverged run stops
        # producing checkpoints rather than overwriting a good one.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs("model_weight", exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
            print(f"✓ Saved best model (val loss = {val_loss:.4f})")

    if best_val_loss == float("inf"):
        # Previously the success message below was printed even in this case.
        print("⚠ Training finished but no checkpoint was saved (no finite validation loss).")
    else:
        print("✅ 训练完成。最优模型保存在 model_weight/ 文件夹。")
    return best_val_loss
0062
if __name__ == "__main__":
    # Script entry point: train one independent model per (pt_min, pt_max)
    # window, in GeV. Add more tuples to scan additional pt ranges.
    data_list = "../train50k.list"
    for low, high in ((0, 10),):
        train(data_list, pt_min=low, pt_max=high)