File indexing completed on 2025-12-16 09:18:05
0001
0002 import torch
0003 import torch.nn as nn
0004 from torch.utils.data import DataLoader, random_split
0005 import ROOT
0006
0007 from data_combined import FusionDataset
0008 from model_combined import FusionRegressor
0009
0010
def _weighted_reso_loss(pred, target):
    """Return the weighted mean squared relative residual of ``pred`` vs ``target``.

    The relative residual is ``(pred - target) / target``. Samples whose
    |residual| is below 0.2 get weight 3.0 and all others get weight 1.0, so
    the core of the resolution distribution is emphasised.

    ``pred`` is reshaped to ``target``'s shape first. If the model returns
    ``(B, 1)`` while the target is ``(B,)`` (or vice versa), plain broadcasting
    would silently build a ``(B, B)`` residual matrix instead of raising.
    """
    pred = pred.reshape(target.shape)
    # NOTE(review): a zero target yields inf/nan here; confirm FusionDataset
    # never returns pt == 0 when pt_min is 0.
    pt_reso = (pred - target) / target
    weights = (pt_reso.abs() < 0.2).float() * 2.0 + 1.0
    return (pt_reso ** 2 * weights).mean()


def _run_epoch(model, loader, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    With an ``optimizer`` the model is put in train mode and updated after each
    batch. Without one, it is put in eval mode and autograd is disabled.
    """
    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = _weighted_reso_loss(model(xb), yb)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # The loss is a per-batch mean; weight by batch size so the epoch
            # value is a true per-sample mean even with a short last batch.
            total += loss.item() * xb.size(0)
    return total / len(loader.dataset)


def _fill_submodel_histograms(dataset, pt_min, pt_max, out_path):
    """Histogram the two input sub-model predictions against truth and write them.

    Column 0 of ``dataset.X`` goes into ``h2_dphi_diff`` and column 1 into
    ``h2_energy_diff``. Each is filled at (truth, (pred - truth) / truth), with
    the truth value taken from ``dataset.Y``. Entries with zero truth are
    skipped because their relative residual is undefined.

    Raises:
        OSError: if the ROOT output file cannot be opened.
    """
    import os

    # round() avoids float truncation such as int(2.9999...) == 2; keep >= 1 bin.
    nbins_pt = max(1, round((pt_max - pt_min) * 10))
    h2_dphi = ROOT.TH2D(
        "h2_dphi_diff",
        "dphi-model: truth pt vs (pred-truth)/truth;truth pt [GeV];(pred-truth)/truth",
        nbins_pt, pt_min, pt_max, 200, -2.0, 2.0,
    )
    h2_energy = ROOT.TH2D(
        "h2_energy_diff",
        "energy-model: truth pt vs (pred-truth)/truth;truth pt [GeV];(pred-truth)/truth",
        nbins_pt, pt_min, pt_max, 200, -2.0, 2.0,
    )

    for i, (x_row, y_val) in enumerate(zip(dataset.X, dataset.Y)):
        truth = y_val.item()
        pred1 = x_row[0].item()
        pred2 = x_row[1].item()

        if i < 2:
            print(f"[Sample {i}] pred1={pred1:.4f}, pred2={pred2:.4f}, truth={truth:.4f}")

        if truth == 0:
            continue
        h2_dphi.Fill(truth, (pred1 - truth) / truth)
        h2_energy.Fill(truth, (pred2 - truth) / truth)

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fout = ROOT.TFile(out_path, "RECREATE")
    if fout.IsZombie():
        raise OSError(f"could not open ROOT output file {out_path!r}")
    try:
        h2_dphi.Write()
        h2_energy.Write()
    finally:
        fout.Close()
    print("✅ Sub-model histograms saved to submodel_performance.root")


def train_fusion(
    list_file,
    pt_min=0.0,
    pt_max=10.0,
    batch_size=512,
    epochs=1,
    lr=5e-5,
    weight_decay=1e-5,
    hidden_dim=32,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Train a FusionRegressor on the two sub-model predictions and save diagnostics.

    The dataset is split 70/30 (unseeded) into train/validation. The model is
    trained with Adam on a weighted squared relative-residual loss. The
    state_dict with the lowest validation loss is saved to
    ``model_weight/best_model_combined.pt``. Afterwards, truth-vs-residual 2D
    histograms of the two *input* sub-model predictions over the whole dataset
    are written to ``outputFile/submodel_performance.root``.

    Args:
        list_file: Passed straight to ``FusionDataset`` (presumably a list of
            input files; see that class).
        pt_min, pt_max: pt window forwarded to ``FusionDataset``; also the
            x-axis range of the histograms.
        batch_size: Mini-batch size for both loaders.
        epochs: Number of training epochs.
        lr, weight_decay: Adam hyper-parameters.
        hidden_dim: Hidden width of ``FusionRegressor``.
        device: Torch device string. The default is chosen once, when the
            function is defined.

    Raises:
        ValueError: if the dataset is too small to give non-empty train and
            validation splits.
        OSError: if the histogram output file cannot be opened.
    """
    import os

    print(f"=== Training FusionRegressor [{pt_min},{pt_max}) GeV on {device} ===")

    dataset = FusionDataset(list_file, pt_min=pt_min, pt_max=pt_max)
    train_size = int(0.7 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"FusionDataset has {len(dataset)} samples; need at least one in "
            "each of the 70/30 train/validation splits"
        )
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    # Two inputs: the dphi-model and energy-model predictions (dataset.X cols 0/1).
    model = FusionRegressor(input_dim=2, hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_model_path = "model_weight/best_model_combined.pt"
    # Create the checkpoint directory up front so the first save cannot fail
    # after a full epoch of work.
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        train_loss = _run_epoch(model, train_loader, device, optimizer)
        val_loss = _run_epoch(model, val_loader, device)

        print(f"Epoch {epoch:03d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"✓ Saved best model (Val {val_loss:.4f})")

    print("✅ Training complete.")

    _fill_submodel_histograms(
        dataset, pt_min, pt_max, "outputFile/submodel_performance.root"
    )
0111 print("✅ Sub-model histograms saved to submodel_performance.root")
0112
if __name__ == "__main__":
    # Production run over ../train500k.list: full 0-10 GeV window,
    # 100 epochs and a 256-unit hidden layer.
    run_settings = {
        "pt_min": 0.0,
        "pt_max": 10.0,
        "batch_size": 512,
        "epochs": 100,
        "lr": 5e-5,
        "weight_decay": 1e-5,
        "hidden_dim": 256,
    }
    train_fusion("../train500k.list", **run_settings)