Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-17 09:18:16

0001 import torch
0002 import torch.nn as nn
0003 
class TrackCaloRegressor(nn.Module):
    """Fully connected regressor mapping a feature vector to one scalar.

    Architecture: three hidden ``Linear -> ReLU`` stages of width
    ``hidden_dim``, followed by a ``Linear`` head with a single output unit.
    The input features are presumably per-track / calorimeter quantities
    (inferred from the class name only -- TODO confirm against the caller).

    Args:
        input_dim: Number of input features, i.e. the size of the last axis
            of the tensor passed to ``forward``.
        hidden_dim: Width of each hidden layer.
        dropout: Dropout probability applied after every hidden ReLU.
            With the default ``0.0`` no ``Dropout`` modules are created, so
            the layer layout and ``state_dict`` keys (``net.0`` ... ``net.6``)
            are the same as in the version without this option. Existing
            checkpoints therefore still load.

    Raises:
        ValueError: If ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(self, input_dim=7, hidden_dim=256, *, dropout=0.0):
        super().__init__()
        # Validate up front: a negative value would otherwise be silently
        # ignored, because the Dropout modules are skipped when dropout <= 0.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")

        layers = []
        in_features = input_dim
        # Three hidden stages. Modules are built in the same order as an
        # explicit nn.Sequential(...) listing, so seeded weight initialisation
        # is unaffected by building them in a loop.
        for _ in range(3):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(nn.ReLU())
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, 1))  # regression output
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the regression prediction for ``x``.

        Args:
            x: Floating-point tensor whose last dimension has size
                ``input_dim``. Any leading (e.g. batch) dimensions are passed
                through unchanged by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of size 1.
        """
        return self.net(x)