Skip to content

Commit

Permalink
manifold encoder and test
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Oct 6, 2023
1 parent 9bf3adb commit 5a1eca2
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
61 changes: 61 additions & 0 deletions neurve/nn_encoder/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn


def create_basic_layer(
in_dim: int, out_dim: int, use_batch_norm: bool
) -> nn.Module:
layers = [nn.Linear(in_dim, out_dim), nn.ReLU()]
if use_batch_norm:
layers.append(nn.BatchNorm1d(out_dim))
return nn.Sequential(*layers)


class MfldEncoder(nn.Module):
def __init__(
self,
n: int,
z: int,
backbone_dim: int,
hidden_dim: int,
n_charts: int,
use_batch_norm: bool,
):
super().__init__()
self.backbone = nn.Sequential(
create_basic_layer(n, hidden_dim, use_batch_norm),
create_basic_layer(hidden_dim, backbone_dim, use_batch_norm),
)

self.q = nn.Sequential(
create_basic_layer(
in_dim=backbone_dim,
out_dim=n_charts,
use_batch_norm=use_batch_norm,
),
nn.Softmax(1),
)

self.coord_maps = [
create_basic_layer(
in_dim=backbone_dim,
out_dim=z,
use_batch_norm=use_batch_norm,
)
for _ in range(n_charts)
]

def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:
"""encodes the input tensor x
Returns
-------
first tensor returned is the chart membership probabilities, shape (batch, n_charts)
second tensor returned is the coordinates in each chart, shape (batch, n_charts, z)
"""
x = self.backbone(x)
coords = [c(x) for c in self.coord_maps]
coords = torch.stack(coords, 1)
q = self.q(x)

return q, coords
28 changes: 28 additions & 0 deletions tests/test_nn_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
from neurve.nn_encoder.models import MfldEncoder


def test_mfld_encoder():
n = 8
z = 2
batch = 2
n_charts = 3

net = MfldEncoder(
n=n,
z=z,
backbone_dim=4,
hidden_dim=6,
n_charts=n_charts,
use_batch_norm=False,
)

x = torch.rand(batch, n)
q, coords = net(x)

assert q.shape == (batch, n_charts)
assert coords.shape == (batch, n_charts, z)

assert q.max() <= 1
assert q.min() >= 0
assert q.sum(1).allclose(torch.ones(batch))

0 comments on commit 5a1eca2

Please sign in to comment.