File indexing completed on 2025-12-16 09:18:06
0001 import torch
0002 import torch.nn as nn
0003
class PointNetRegressor(nn.Module):
    """PointNet-style regressor mapping a set of points to one scalar per set.

    Each point is embedded independently by a shared MLP. The per-point
    embeddings are max-pooled over the point axis (dim 1), and the pooled
    vector is passed through a small MLP head that outputs a single value.

    Args:
        input_dim: Number of features per point (last dimension of ``x``).
        emb_dim: Size of the per-point embedding that is max-pooled.
        hidden_dim: Width of the hidden layer in both MLPs. The default of 64
            matches the previously hard-coded value, so layer shapes and
            ``state_dict`` keys are unchanged for existing checkpoints.
    """

    def __init__(self, input_dim=2, emb_dim=64, hidden_dim=64):
        super().__init__()
        # Shared per-point MLP. The trailing ReLU makes every embedding
        # non-negative, which ``forward`` relies on when masking padding.
        self.point_mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
            nn.ReLU(),
        )
        # Head mapping the pooled global feature to a single scalar.
        self.regressor = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, mask=None):
        """Regress one scalar per point set.

        Args:
            x: Point features. Pooling is over dim 1, so this is expected to
                be ``(batch, num_points, input_dim)``.
            mask: Optional tensor of shape ``(batch, num_points)``, converted
                to bool. ``True`` marks real points and ``False`` marks padding
                to exclude from the max-pool. ``None`` (default) pools over
                all points, as before.

        Returns:
            The regressor output with dim 1 squeezed, i.e. ``(batch,)`` for
            3-D input.
        """
        feat = self.point_mlp(x)
        if mask is not None:
            # Embeddings are >= 0 (trailing ReLU), so zeroing padded rows can
            # never beat the max over real points. A set with no real points
            # pools to all zeros rather than -inf/NaN.
            feat = feat.masked_fill(~mask.bool().unsqueeze(-1), 0.0)
        # Symmetric max-pool over the point axis gives permutation invariance.
        global_feat = torch.max(feat, dim=1).values
        out = self.regressor(global_feat)
        return out.squeeze(1)