File indexing completed on 2025-12-16 09:18:05
0001
0002 import torch
0003 import torch.nn as nn
0004 import torch.nn.functional as F
0005
class FusionRegressor(nn.Module):
    """Blend two input features with a learned, input-dependent convex weighting.

    A small MLP ("gate") reads the whole feature vector and emits two logits.
    Their softmax gives two non-negative weights that sum to 1, and the output
    is the weighted sum of input features 0 and 1.

    NOTE(review): presumably features 0 and 1 are two competing predictions of
    the same target and any further features are context for the gate only --
    TODO confirm against callers.

    Args:
        input_dim: Size of the last dimension of the input. The gate consumes
            all ``input_dim`` features, but only features 0 and 1 are blended,
            so this must be at least 2.
        hidden_dim: Width of each hidden layer of the gate MLP.

    Raises:
        ValueError: If ``input_dim < 2`` (there would be no feature 1 to fuse,
            and the blend would silently collapse to an empty tensor).
    """

    def __init__(self, input_dim=2, hidden_dim=256):
        super().__init__()
        if input_dim < 2:
            raise ValueError(
                f"input_dim must be >= 2 (features 0 and 1 are fused), got {input_dim}"
            )
        # Layer order is unchanged so parameter names (gate.0, gate.2, gate.4,
        # gate.6) match previously saved state_dicts.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x):
        """Return the gated blend of features 0 and 1 of ``x``.

        Args:
            x: Tensor of shape ``(..., input_dim)``; any number of leading
                batch dimensions is accepted.

        Returns:
            Tensor of shape ``(..., 1)`` equal to
            ``w0 * x[..., 0:1] + w1 * x[..., 1:2]`` where
            ``(w0, w1) = softmax(gate(x))`` along the last axis.
        """
        logits = self.gate(x)
        # Normalize over the logit (last) axis. A fixed dim=1 only coincides
        # with it for rank-2 input and would mix positions for (B, T, D) input.
        weights = F.softmax(logits, dim=-1)
        return weights[..., 0:1] * x[..., 0:1] + weights[..., 1:2] * x[..., 1:2]