File indexing completed on 2025-12-17 09:18:16
0001 import torch
0002 import torch.nn as nn
0003
class TrackCaloRegressor(nn.Module):
    """Fully connected ReLU network that maps a feature vector to one scalar.

    Layer layout:

    * ``Linear(input_dim, hidden_dim)`` followed by ``ReLU``;
    * ``num_hidden_layers - 1`` further ``Linear(hidden_dim, hidden_dim)``
      layers, each followed by ``ReLU``;
    * a final ``Linear(hidden_dim, 1)`` with no activation, so the output is
      not bounded or squashed.

    With the default ``num_hidden_layers=3`` this is exactly the original
    fixed architecture. Module order inside ``self.net`` (and therefore
    state_dict keys ``net.0``, ``net.2``, ``net.4``, ``net.6``) and parameter
    initialization order are unchanged, so existing checkpoints still load.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of Linear+ReLU hidden layers (keyword-only).

    Raises:
        ValueError: If ``num_hidden_layers`` is less than 1.
    """

    def __init__(self, input_dim=7, hidden_dim=256, *, num_hidden_layers=3):
        super().__init__()
        if num_hidden_layers < 1:
            raise ValueError(
                f"num_hidden_layers must be >= 1, got {num_hidden_layers}"
            )

        layers = []
        in_features = input_dim
        for _ in range(num_hidden_layers):
            # Build modules in the same order as the original literal list
            # so seeded initialization produces identical weights.
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(nn.ReLU())
            in_features = hidden_dim
        # Single-output linear head; no activation on the prediction.
        layers.append(nn.Linear(hidden_dim, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        return self.net(x)