Updated trainer

AntixK committed Jan 14, 2020
1 parent 670a8af commit 7f72790

Showing 8 changed files with 66 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
 
 data/
 logs/
+
+VanillaVAE/version_0/
26 changes: 13 additions & 13 deletions README.md
@@ -1,17 +1,17 @@
 PyTorch-VAE
 
 TODO
-[ ] VanillaVAE
-[ ] Conditional VAE
-[ ] Gamma VAE
-[ ] Beta VAE
-[ ] InfoVAE
-[ ] WAE
-[ ] AAE
-[ ] TwoStageVAE
-[ ] MMD-VAE
-[ ] VAE-GAN
-[ ] VAE with Vamp Prior
-[ ] IWAE
-[ ] VLAE
+- [ ] VanillaVAE
+- [ ] Conditional VAE
+- [ ] Gamma VAE
+- [ ] Beta VAE
+- [ ] InfoVAE
+- [ ] WAE
+- [ ] AAE
+- [ ] TwoStageVAE
+- [ ] MMD-VAE
+- [ ] VAE-GAN
+- [ ] VAE with Vamp Prior
+- [ ] IWAE
+- [ ] VLAE
 
Binary file modified __pycache__/trainer.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/vanilla_vae.cpython-37.pyc
Binary file not shown.
14 changes: 6 additions & 8 deletions models/gamma_vae.py
@@ -1,11 +1,15 @@
 import torch
 from models import BaseVAE
 from torch import nn
+from torch.distributions import Gamma
 from torch.nn import functional as F
 from .types_ import *
 
 
 class GammaVAE(BaseVAE):
+    """
+    https://github.com/darleybarreto/vae-pytorch/blob/419d861089ab2a84ff154d550866629e526ff81f/models/gamma_vae.py
+    """
 
     def __init__(self,
                  in_channels: int,

@@ -84,14 +88,8 @@ def decode(self, z: Tensor) -> Tensor:
         result = self.final_layer(result)
         return result
 
-    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
-        """
-        Will a single z be enough ti compute the expectation
-        for the loss??
-        :param mu: (Tensor) Mean of the latent Gaussian
-        :param logvar: (Tensor) Standard deviation of the latent Gaussian
-        :return:
-        """
+    def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor:
 
         std = torch.exp(0.5 * logvar)
         eps = torch.randn_like(std)
         return eps * std + mu
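
The renamed reparameterize signature takes Gamma parameters (alpha, beta), while the body above still samples a Gaussian from mu/logvar. For comparison, a minimal sketch of a reparameterized Gamma draw using the newly imported torch.distributions.Gamma — an assumption about the intended direction, not what this commit implements:

import torch
from torch import Tensor
from torch.distributions import Gamma

def reparameterize_gamma(alpha: Tensor, beta: Tensor) -> Tensor:
    # rsample() draws a reparameterized sample, so gradients can flow
    # back to the concentration (alpha) and rate (beta) parameters.
    return Gamma(alpha, beta).rsample()
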
10 changes: 6 additions & 4 deletions models/vanilla_vae.py
@@ -107,11 +107,13 @@ def loss_function(self,
                       mu: Tensor,
                       log_var: Tensor) -> Tensor:
 
-        bce_loss = F.binary_cross_entropy(recons.view(-1), input.view(-1))
+        recons_loss = F.mse_loss(recons,
+                                 input,
+                                 reduction='mean')
 
-        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
-        kld_loss /= input.view(-1).size(0)
-
-        return bce_loss + kld_loss
+        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
+        kld_loss /= input.size(0)
+        return recons_loss + kld_loss
 
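The hunk above swaps the flattened binary cross-entropy for a mean-reduced MSE reconstruction term and normalizes the closed-form KL divergence, -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), by the batch size rather than the total element count. A self-contained sketch of the resulting loss, mirroring the diff (the helper name vae_loss is illustrative):

import torch
from torch.nn import functional as F

def vae_loss(recons, input, mu, log_var):
    # Reconstruction term: MSE averaged over every element of the batch
    recons_loss = F.mse_loss(recons, input, reduction='mean')
    # KL(N(mu, sigma^2) || N(0, I)) in closed form, summed over batch and
    # latent dimensions, then divided by the batch size only
    kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
    kld_loss /= input.size(0)
    return recons_loss + kld_loss

Note that the two terms use different normalizations (per-element mean vs. per-sample sum), so their relative weight depends on the image and latent dimensionalities.
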
18 changes: 11 additions & 7 deletions run.py
@@ -1,7 +1,9 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.logging import TestTubeLogger
+# from pytorch_lightning.callbacks import ModelCheckpoint
 from trainer import VAETrainer
 from models import VAE
+import torch
 
 
 tt_logger = TestTubeLogger(

@@ -11,26 +13,28 @@
     create_git_tag=False
 )
 
 
 
 class hparams(object):
     def __init__(self):
-        self.LR = 0.001
+        self.LR = 0.0005
         self.momentum = 0.9
         self.scheduler_gamma = 0
         self.gpus = 1
         self.data_path = 'data/'
-        self.batch_size = 32
+        self.batch_size = 144
         self.manual_seed = 1256
 
 hyper_params = hparams()
 
-model = VAE(3, 32)
+torch.manual_seed(hyper_params.manual_seed)
+model = VAE(in_channels=3, latent_dim=128)
 net = VAETrainer(model,
                  hyper_params)
 
 
 trainer = Trainer(gpus=hyper_params.gpus,
                   min_nb_epochs=1,
                   max_nb_epochs=2,
-                  logger=tt_logger)
+                  logger=tt_logger,
+                  log_save_interval=100,
+                  train_percent_check=1.,
+                  val_percent_check=1.)
 trainer.fit(net)
38 changes: 28 additions & 10 deletions trainer.py
@@ -1,3 +1,4 @@
+import torch
 import pytorch_lightning as pl
 from models import BaseVAE
 from torchvision import transforms

@@ -17,10 +18,17 @@ def __init__(self,
 
         self.model = vae_model
         self.params = params
+        torch.manual_seed(self.params.manual_seed)
+        self.curr_device = None
 
+    def forward(self, input: Tensor):
+        return self.model(input)
+
     def training_step(self, batch, batch_idx):
         real_img, _ = batch
 
+        self.curr_device = real_img.device
+
         recons_img, mu, log_var = self.model(real_img)
         loss = self.model.loss_function(recons_img, real_img, mu, log_var)
 

@@ -34,10 +42,27 @@ def validation_step(self, batch, batch_idx):
         loss = self.model.loss_function(recons_img, real_img, mu, log_var)
 
         self.logger.experiment.log({'val_loss': loss.item()})
-        return {'loss': loss}
+        return {'val_loss': loss}
 
     def validation_end(self, outputs):
-        pass
+        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        tensorboard_logs = {'val_loss': avg_loss}
+        return {'val_loss': avg_loss, 'log': tensorboard_logs}
+
+    def on_epoch_end(self):
+        z = torch.randn(self.params.batch_size,
+                        128).view(self.params.batch_size, -1, 1, 1)
+
+        if self.on_gpu:
+            z = z.cuda(self.curr_device)
+
+        samples = self.model.decode(z).cpu()
+        # print(samples.shape)
+        grid = vutils.make_grid(samples, nrow=12)
+        # print(grid.shape)
+        self.logger.experiment.add_image(f'Samples', grid, self.current_epoch)
+        vutils.save_image(samples.data, f"sample_{self.current_epoch}.png", normalize=True, nrow=12)
+
 
     def configure_optimizers(self):
         optimizer = optim.Adam(self.model.parameters(), lr=self.params.LR)

@@ -67,11 +92,4 @@ def val_dataloader(self):
                                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
                            download=True),
                           batch_size= self.params.batch_size,
-                          drop_last=True)
-
-    # Utils
-    def save_samples(self):
-        recons_img = 0
-        vutils.save_image(recons_img,
-                          f"{self.logger.save_dir}/fake_samples.png",
-                          normalize=True)
+                          drop_last=True)

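The new on_epoch_end hook above samples latent vectors, decodes them, and logs an image grid every epoch. A condensed sketch of the same sampling flow outside of Lightning, assuming a trained model object that exposes decode() and the batch size and latent size used in this commit's run.py:

import torch
import torchvision.utils as vutils

batch_size, latent_dim = 144, 128  # values from run.py above
z = torch.randn(batch_size, latent_dim).view(batch_size, -1, 1, 1)
with torch.no_grad():                # sampling needs no gradients
    samples = model.decode(z).cpu()  # model: a trained VAE instance
grid = vutils.make_grid(samples, nrow=12)  # 12 images per row
vutils.save_image(samples, "sample.png", normalize=True, nrow=12)
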