Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 import torch
0002 import torch.nn as nn
0003 
class PointNetRegressor(nn.Module):
    """PointNet-style regressor mapping a set of points to one scalar per set.

    Every point is embedded independently by a shared MLP. The per-point
    embeddings are max-pooled over the point axis, so the result does not
    depend on point order. A second MLP maps the pooled feature to a single
    output value.

    Args:
        input_dim: Number of features per point.
        emb_dim: Size of the per-point embedding (and of the pooled feature).
        hidden_dim: Width of the hidden layer in both MLPs. The default of 64
            reproduces the original architecture. The layer layout is
            unchanged, so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, input_dim=2, emb_dim=64, hidden_dim=64):
        super().__init__()
        # Shared per-point encoder. nn.Linear acts on the last axis, so the
        # same weights are applied to every point of every set.
        self.point_mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
            nn.ReLU(),
        )
        # Head mapping the pooled set feature to one scalar.
        self.regressor = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Regress one value per point set.

        Args:
            x: Tensor of shape ``[B, N, input_dim]`` (a batch of ``B`` sets of
                ``N`` points) or ``[N, input_dim]`` (a single, unbatched set).
                ``N`` must be at least 1, because max pooling over an empty
                axis fails.

        Returns:
            Tensor of shape ``[B]`` for batched input, or a 0-dim tensor for
            unbatched input.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() == 2:
            # Unbatched set. Pooling a 2-D input directly would reduce over
            # the feature axis, so add a batch axis of size 1, then drop it.
            return self.forward(x.unsqueeze(0)).squeeze(0)
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape [B, N, input_dim] or [N, input_dim], "
                f"got {tuple(x.shape)}"
            )
        feat = self.point_mlp(x)                      # [B, N, emb_dim]
        global_feat = torch.max(feat, dim=1).values   # [B, emb_dim]
        out = self.regressor(global_feat)             # [B, 1]
        return out.squeeze(1)                         # [B]