Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:05

0001 ### model_combined.py
0002 import torch
0003 import torch.nn as nn
0004 import torch.nn.functional as F
0005 
class FusionRegressor(nn.Module):
    """Fuse several scalar estimates with an input-dependent softmax gate.

    An MLP ("gate") reads the whole feature vector and emits one logit per
    fused source. The logits are normalised with a softmax, so the output is
    a convex combination (non-negative weights summing to one) of the first
    ``num_sources`` features of the input. Any features beyond those only
    influence the gate, not the fused value.

    NOTE(review): the leading features are presumably the per-source
    predictions being fused (e.g. outputs of upstream regressors) -- confirm
    against the caller.

    Args:
        input_dim: Size of the last dimension of the input. Must be at least
            ``num_sources``.
        hidden_dim: Width of each of the gate's three hidden layers.
        num_sources: Number of leading input features to fuse. The default of
            2 reproduces the original two-way fusion, including parameter
            shapes and ``state_dict`` keys, so existing checkpoints still load.

    Raises:
        ValueError: If ``num_sources < 1`` or ``input_dim < num_sources``.
    """

    def __init__(self, input_dim=2, hidden_dim=256, num_sources=2):
        super().__init__()
        if num_sources < 1:
            raise ValueError(f"num_sources must be >= 1, got {num_sources}")
        if input_dim < num_sources:
            # Previously this broadcast an empty slice and silently produced
            # an empty/garbage output instead of failing.
            raise ValueError(
                f"input_dim ({input_dim}) must be >= num_sources ({num_sources})"
            )
        self.num_sources = num_sources
        # Layer order is kept identical to the original so state_dict keys
        # (gate.0, gate.2, gate.4, gate.6) remain checkpoint-compatible.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_sources),
        )

    def forward(self, x):
        """Return the gated combination of the leading ``num_sources`` features.

        Args:
            x: Tensor whose last dimension has size ``input_dim``. Any number
                of leading (batch) dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of size 1.
        """
        # Normalise over the last axis (not a fixed dim=1) so inputs with more
        # than one leading dimension get one weight vector per feature vector.
        weights = F.softmax(self.gate(x), dim=-1)
        sources = x[..., : self.num_sources]
        return (weights * sources).sum(dim=-1, keepdim=True)