Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import matplotlib.pyplot as plt
0002 import torch
0003 import numpy as np
0004 from data import TrackCaloDataset
0005 from model import TrackCaloRegressor
0006 import joblib
0007 from sklearn.preprocessing import StandardScaler
0008 
# === Load preprocessing, data and model ===
# The scaler must be the one saved by the training script (keep this path in
# sync with where training wrote it), so test features are standardized the
# same way as the training features.
# NOTE(review): joblib.load unpickles the file -- only load trusted artifacts.
scaler = joblib.load("scaler.pkl")

dataset = TrackCaloDataset("../test50k.list", scaler=scaler)
loader = torch.utils.data.DataLoader(dataset, batch_size=128)

# Choose the device before loading the checkpoint so it can be mapped there.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TrackCaloRegressor()
# map_location remaps tensors that were saved on a GPU; without it, loading a
# CUDA-saved checkpoint on a CPU-only machine raises at deserialization time,
# defeating the CPU fallback selected above.
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch-norm
0020 
# === Run inference and collect predictions alongside the truth values ===
# x and y come from the same batch, so the two lists stay element-aligned.
all_preds = []
all_truths = []

with torch.no_grad():
    for x, y in loader:
        # Only the inputs need to live on the model's device. The targets are
        # just collected for plotting, so they stay on the CPU instead of
        # making a pointless host -> device -> host round trip.
        pred_log = model(x.to(device))
        pred = torch.exp(pred_log)  # the model outputs log(pt); exponentiate to recover pt

        all_preds.append(pred.cpu().numpy())
        all_truths.append(y.cpu().numpy())  # .cpu() is a no-op for CPU tensors

# === Merge per-batch results into flat 1-D arrays ===
if not all_preds:
    # np.concatenate([]) would otherwise fail with an unhelpful message.
    raise RuntimeError("Test loader produced no batches; nothing to evaluate.")
all_preds = np.concatenate(all_preds).ravel()
all_truths = np.concatenate(all_truths).ravel()
0040 
# === Plot truth vs predicted pt as a 2-D histogram ===
# Upper edge of the plotted pt window. It is shared by the histogram range and
# the y = x reference line so the two cannot drift apart. Entries outside
# [0, PT_MAX] on either axis are not counted in the histogram.
PT_MAX = 15

fig, ax = plt.subplots(figsize=(6, 6))
# hist2d returns (counts, xedges, yedges, mesh); the mesh feeds the colorbar.
_, _, _, mesh = ax.hist2d(
    all_truths,
    all_preds,
    bins=200,
    range=[[0, PT_MAX], [0, PT_MAX]],
    cmap='viridis',
)
# y = x reference: a perfect regressor would put every entry on this line.
ax.plot([0, PT_MAX], [0, PT_MAX], color='red', linestyle='--', label='Ideal')
ax.set_xlabel("Truth pt")
ax.set_ylabel("Predicted pt")
ax.set_title("Truth vs Predicted pt")
fig.colorbar(mesh, ax=ax, label='Counts')
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig("truth_vs_pred.png", dpi=300)
plt.show()