add trainer and test
ekorman committed Nov 2, 2023
1 parent 1745f8e commit 6fcf832
Showing 5 changed files with 164 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -9,4 +9,5 @@ __pycache__/
 wandb/
 models/
 .vscode
-*.egg-info
+*.egg-info
+*tfevents*
6 changes: 5 additions & 1 deletion neurve/nn_encoder/dataset.py
@@ -43,4 +43,8 @@ def __getitem__(
         neighbors = self.data[self.indices[index][: self.n_neighbors]]
         non_neighbors = self.data[self.sample_non_neighbor(index)]

-        return point, neighbors, non_neighbors
+        return (
+            point.astype(np.float32),
+            neighbors.astype(np.float32),
+            non_neighbors.astype(np.float32),
+        )
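
The float32 casts are worth a note: numpy's random arrays are float64 by default, while torch modules default to float32 parameters, so returning uncast arrays would cause a dtype mismatch once batches reach the encoder. A minimal sketch of the mismatch this avoids (hypothetical layer, not code from this repository):

    import numpy as np
    import torch

    x = np.random.rand(2, 8)        # float64 by default
    layer = torch.nn.Linear(8, 4)   # float32 parameters
    # layer(torch.from_numpy(x)) raises a RuntimeError (Double vs Float);
    # casting first, as the dataset now does, keeps dtypes aligned:
    out = layer(torch.from_numpy(x.astype(np.float32)))
    assert out.dtype == torch.float32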
6 changes: 3 additions & 3 deletions neurve/nn_encoder/loss.py
@@ -64,9 +64,9 @@ def loss_at_a_point(
     ).sum()

     # encourage non-neighbors to be farther away than neighbors
-    loss_non_neighbors = -(
-        latent_non_neighbors_dist2 - latent_neighbors_dist2.max()
-    )
+    loss_non_neighbors = (
+        -(latent_non_neighbors_dist2 - latent_neighbors_dist2.max())
+    ).sum()

     return {
         "loss_neighbors": loss_neighbors,
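
The change wraps the non-neighbor term in a .sum() so it reduces to a scalar; previously the expression kept one value per sampled non-neighbor, and a non-scalar term cannot be folded cleanly into a total loss that is passed to backward(). A tiny illustration with made-up distances (not values from the repository):

    import torch

    latent_neighbors_dist2 = torch.tensor([0.1, 0.2, 0.3])
    latent_non_neighbors_dist2 = torch.tensor([0.5, 0.9])

    old = -(latent_non_neighbors_dist2 - latent_neighbors_dist2.max())
    new = old.sum()
    print(old.shape)  # torch.Size([2]) -- one value per non-neighbor
    print(new.shape)  # torch.Size([]) -- scalar, safe to call .backward() on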
97 changes: 97 additions & 0 deletions neurve/nn_encoder/trainer.py
@@ -0,0 +1,97 @@
+import torch
+
+from neurve.core import Trainer
+from neurve.nn_encoder.loss import loss
+from neurve.nn_encoder.models import MfldEncoder
+
+# loss: at a point, take the nearest neighbors and look at the squared error between the
+# distance in the original space and the distance in latent space (with a learned scale
+# parameter); then sample points that are not nearest neighbors and add a term that pushes
+# their distance past some margin over the max (or smooth maximum) of the neighbor distances
+# (but capped). Actually, capping may be unnecessary since vectors are restricted to the unit
+# ball, so a single term encouraging these other points to be far away may suffice, with no max.
+
+
+class NNEncoderTrainer(Trainer):
+    def __init__(
+        self,
+        net: MfldEncoder,
+        opt: torch.optim.Optimizer,
+        out_path: str,
+        reg_loss_weight: float,
+        c: float,
+        data_loader: torch.utils.data.DataLoader,
+        net_name: str = "net",
+        eval_data_loader: torch.utils.data.DataLoader = None,
+        device: torch.device = None,
+        q_loss_weight: float = 0.0,
+        use_wandb: bool = False,
+    ):
+        super().__init__(
+            net=net,
+            opt=opt,
+            out_path=out_path,
+            data_loader=data_loader,
+            net_name=net_name,
+            eval_data_loader=eval_data_loader,
+            device=device,
+            use_wandb=use_wandb,
+        )
+        self.reg_loss_weight = reg_loss_weight
+        self.q_loss_weight = q_loss_weight
+        self.scale = torch.rand(1, requires_grad=True, device=self.device)
+        self.c = c
+
+    def _train_step(self, data):
+        batch_size = data[0].shape[0]
+
+        points, neighbors, non_neighbors = data
+
+        batch_size, n_neighbors, _ = neighbors.shape
+
+        all_points = torch.cat(
+            [points, neighbors.flatten(0, 1), non_neighbors.flatten(0, 1)],
+            dim=0,
+        )
+        q, coords = self.net(all_points)
+
+        sections = [
+            batch_size,
+            n_neighbors * batch_size,
+            n_neighbors * batch_size,
+        ]
+
+        q_point, q_neighbors, q_non_neighbors = q.split(sections)
+        q_neighbors = q_neighbors.unflatten(0, (batch_size, n_neighbors))
+        q_non_neighbors = q_non_neighbors.unflatten(
+            0, (batch_size, n_neighbors)
+        )
+
+        coords_point, coords_neighbors, coords_non_neighbors = coords.split(
+            sections
+        )
+        coords_neighbors = coords_neighbors.unflatten(
+            0, (batch_size, n_neighbors)
+        )
+        coords_non_neighbors = coords_non_neighbors.unflatten(
+            0, (batch_size, n_neighbors)
+        )
+
+        loss_dict = loss(
+            point=points,
+            neighbors=neighbors,
+            q_point=q_point,
+            q_neighbors=q_neighbors,
+            q_non_neighbors=q_non_neighbors,
+            coords_point=coords_point,
+            coords_neighbors=coords_neighbors,
+            coords_non_neighbors=coords_non_neighbors,
+            scale=self.scale,
+            c=self.c,
+        )
+
+        self.opt.zero_grad()
+        loss_dict["loss"].backward()
+        self.opt.step()
+
+        return loss_dict
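
To make the comment at the top of trainer.py concrete, here is a rough standalone sketch of the two distance terms it describes, ignoring the chart-membership outputs (the q_* tensors) and the regularization that the real neurve.nn_encoder.loss also handles; names and shapes are illustrative, not the repository's implementation:

    import torch

    def sketch_loss_at_a_point(
        point,                 # (n,) point in the original space
        neighbors,             # (k, n) its k nearest neighbors
        coords_point,          # (z,) latent coordinates of the point
        coords_neighbors,      # (k, z) latent coordinates of the neighbors
        coords_non_neighbors,  # (m, z) latent coordinates of sampled non-neighbors
        scale,                 # learned scalar relating original and latent distances
    ):
        # squared distances in the original space and in the latent space
        og_dist2 = ((neighbors - point) ** 2).sum(dim=1)
        latent_neighbors_dist2 = ((coords_neighbors - coords_point) ** 2).sum(dim=1)
        latent_non_neighbors_dist2 = ((coords_non_neighbors - coords_point) ** 2).sum(dim=1)

        # neighbors: latent distances should match the (scaled) original distances
        loss_neighbors = ((scale * og_dist2 - latent_neighbors_dist2) ** 2).sum()

        # non-neighbors: push them past the farthest neighbor; no cap is needed
        # because latent coordinates are restricted to the unit ball
        loss_non_neighbors = (
            -(latent_non_neighbors_dist2 - latent_neighbors_dist2.max())
        ).sum()

        return loss_neighbors + loss_non_neighbors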
76 changes: 57 additions & 19 deletions tests/test_nn_encoder.py
@@ -1,18 +1,22 @@
 import numpy as np
+import pytest
 import torch

 from neurve.nn_encoder.dataset import NNDataset
 from neurve.nn_encoder.loss import loss, loss_at_a_point
 from neurve.nn_encoder.models import MfldEncoder
+from neurve.nn_encoder.trainer import NNEncoderTrainer

+n = 8
+z = 2
+batch = 2
+n_charts = 3
+c = 0.1

-def test_mfld_encoder():
-    n = 8
-    z = 2
-    batch = 2
-    n_charts = 3

-    net = MfldEncoder(
+@pytest.fixture
+def net() -> MfldEncoder:
+    return MfldEncoder(
         n=n,
         z=z,
         backbone_dim=4,
@@ -21,6 +25,13 @@ def test_mfld_encoder():
         use_batch_norm=False,
     )

+
+@pytest.fixture
+def dataset() -> NNDataset:
+    return NNDataset(data=np.random.rand(15, n), n_neighbors=4)
+
+
+def test_mfld_encoder(net: MfldEncoder):
     x = torch.rand(batch, n)
     q, coords = net(x)

@@ -33,9 +44,6 @@ def test_mfld_encoder():


 def test_loss():
-    n = 8
-    z = 2
-    n_charts = 3
     n_neighbors = 4
     n_non_neighbors = 5

@@ -48,7 +56,6 @@ def test_loss():
     coords_neighbors = torch.rand(n_neighbors, n_charts, z)
     coords_non_neighbors = torch.rand(n_non_neighbors, n_charts, z)
     scale = torch.Tensor([0.75])
-    c = 0.1

     loss_dict = loss_at_a_point(
         point=point,
@@ -89,17 +96,48 @@ def _double(x):
         assert v.allclose(batch_loss_dict[k])


-def test_dataset():
-    data = np.random.rand(15, 2)
-    dset = NNDataset(data=data, n_neighbors=4)
-
-    assert len(dset) == 15
-    pt, nbrs, non_nbrs = dset[2]
-    assert np.array_equal(pt, data[2])
-    assert nbrs.shape == (4, 2)
-    assert non_nbrs.shape == (4, 2)
+def test_dataset(dataset: NNDataset):
+    assert len(dataset) == 15
+    pt, nbrs, non_nbrs = dataset[2]
+    assert pt.dtype == np.float32
+    assert np.array_equal(pt, dataset.data[2].astype(np.float32))
+    assert nbrs.shape == (4, n)
+    assert non_nbrs.shape == (4, n)

     nbr_dists = np.linalg.norm(nbrs - pt, axis=1)
     non_nbr_dists = np.linalg.norm(non_nbrs - pt, axis=1)

     assert nbr_dists.max() < non_nbr_dists.min()
+
+
+def test_trainer(net: MfldEncoder, dataset: NNDataset):
+    batch_size = 2
+    opt = torch.optim.Adam(net.parameters())
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_size=batch_size, drop_last=True
+    )
+
+    initial_params = [p.detach().clone() for p in net.parameters()]
+
+    trainer = NNEncoderTrainer(
+        net=net,
+        opt=opt,
+        out_path="out",
+        reg_loss_weight=0.4,
+        data_loader=data_loader,
+        c=c,
+    )
+
+    trainer.train(n_epochs=2)
+
+    assert trainer.global_steps == 2 * (len(dataset.data) // batch_size)
+
+    final_weights = list(net.parameters())
+
+    # check that at least one weight changed during training
+    a_weight_changed = False
+    for init, final in zip(initial_params, final_weights):
+        if not torch.allclose(init, final):
+            a_weight_changed = True
+            break
+    assert a_weight_changed
