Skip to content

Commit

Permalink
Added VAE Results
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Jan 16, 2020
1 parent 7f72790 commit 291a37d
Show file tree
Hide file tree
Showing 34 changed files with 1,241 additions and 159 deletions.
1 change: 1 addition & 0 deletions .idea/PyTorch-VAE.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

233 changes: 233 additions & 0 deletions .ipynb_checkpoints/Run-checkpoint.ipynb

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
PyTorch-VAE

TODO
- [ ] VanillaVAE
- [x] VanillaVAE
- [ ] Conditional VAE
- [ ] Gamma VAE
- [ ] Beta VAE
- [ ] InfoVAE
- [ ] WAE
- [ ] DFC VAE
- [ ] InfoVAE (MMD-VAE)
- [ ] WAE-MMD
- [ ] AAE
- [ ] TwoStageVAE
- [ ] MMD-VAE
- [ ] VAE-GAN
- [ ] VAE with Vamp Prior
- [ ] IWAE
- [ ] VLAE
- [ ] FactorVAE

226 changes: 226 additions & 0 deletions Run.ipynb

Large diffs are not rendered by default.

Binary file added __pycache__/experiment.cpython-36.pyc
Binary file not shown.
Binary file removed __pycache__/trainer.cpython-37.pyc
Binary file not shown.
Binary file added assets/Vanilla VAE_25.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 23 additions & 0 deletions configs/vae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
model_params:
name: 'VanillaVAE'
in_channels: 3
latent_dim: 128

data_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144

optimizer_params:
LR: 5e-3
scheduler_gamma: 0.95
gpus: 1
num_epochs: 50

logging_params:
save_dir: "logs/",
name: "VanillaVAE",




114 changes: 114 additions & 0 deletions experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
import pytorch_lightning as pl
from models import BaseVAE
from torchvision import transforms
from torchvision.datasets import CelebA
from torch import optim
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from models.types_ import *
import math


class VAEXperiment(pl.LightningModule):

def __init__(self,
vae_model: BaseVAE,
params: dict) -> None:
super(VAEXperiment, self).__init__()

self.model = vae_model
self.params = params
self.curr_device = None

def forward(self, input: Tensor):
return self.model(input)

def training_step(self, batch, batch_idx):
real_img, _ = batch
self.curr_device = real_img.device

recons_img, mu, log_var = self.forward(real_img)

train_loss = self.model.loss_function(recons_img,
real_img,
mu,
log_var,
M_N = self.params.batch_size/ self.num_train_imgs )

self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})

return train_loss

def validation_step(self, batch, batch_idx):
real_img, _ = batch
recons_img, mu, log_var = self.forward(real_img)
val_loss = self.model.loss_function(recons_img,
real_img,
mu,
log_var,
M_N = self.params.batch_size/ self.num_train_imgs )

# self.logger.experiment.log({key: val.item() for key, val in val_loss.items()})
return val_loss

def validation_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
tensorboard_logs = {'avg_val_loss': avg_loss}
self.sample_images()
return {'val_loss': avg_loss, 'log': tensorboard_logs}

def sample_images(self):
z = torch.randn(self.params.batch_size,
self.model.latent_dim)

if self.on_gpu:
z = z.cuda(self.curr_device)

samples = self.model.decode(z).cpu()

vutils.save_image(samples.data,
f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params.batch_size)))


def configure_optimizers(self):
optimizer = optim.Adam(self.model.parameters(), lr=self.params.LR)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params.scheduler_gamma)
return [optimizer] #, [scheduler]

@pl.data_loader
def train_dataloader(self):
transform = self.data_transforms()
dataset = CelebA(root = self.params.data_path,
split = "train",
transform=transform,
download=False)
self.num_train_imgs = len(dataset)
return DataLoader(dataset,
batch_size= self.params.batch_size,
shuffle = True,
drop_last=True)

@pl.data_loader
def val_dataloader(self):
transform = self.data_transforms()

return DataLoader(CelebA(root = self.params.data_path,
split = "test",
transform=transform,
download=False),
batch_size= self.params.batch_size,
shuffle = True,
drop_last=True)

def data_transforms(self):
SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.CenterCrop(148),
transforms.Resize(self.params.img_size),
transforms.ToTensor(),
SetRange])
return transform

119 changes: 119 additions & 0 deletions models/.ipynb_checkpoints/vanilla_vae-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class VanillaVAE(BaseVAE):

def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None) -> None:
super(VanillaVAE, self).__init__()

self.latent_dim = latent_dim

modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.ReLU())
)
in_channels = h_dim

modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels= 2*latent_dim,
kernel_size=3, stride=1, padding = 1),
nn.BatchNorm2d(2*latent_dim),
nn.ReLU())
)

self.encoder = nn.Sequential(*modules)

# Build Decoder
modules = []
in_channels = latent_dim

for _ in range(len(hidden_dims)):
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=64,
kernel_size= 3, padding= 1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear',
align_corners=True))
)
in_channels = 64

self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
nn.Conv2d(64, out_channels= 3,
kernel_size= 3, padding= 1),
nn.Sigmoid())

def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)

# Split the result into mu and var components
# of the latent Gaussian distribution
mu = result[:, :self.latent_dim, :, :]
log_var = result[:, self.latent_dim:, :, :]

return [mu, log_var]

def decode(self, z: Tensor) -> Tensor:
result = self.decoder(z)
result = self.final_layer(result)
return result

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
Will a single z be enough ti compute the expectation
for the loss??
:param mu: (Tensor) Mean of the latent Gaussian
:param logvar: (Tensor) Standard deviation of the latent Gaussian
:return:
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor) -> Tensor:
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var

def loss_function(self,
recons: Tensor,
input: Tensor,
mu: Tensor,
log_var: Tensor) -> Tensor:

recons_loss =F.mse_loss(recons,
input,
reduction='mean')


kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
kld_loss /= input.size(0)
return recons_loss + kld_loss


5 changes: 5 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from .base import *
from .vanilla_vae import *
from .gamma_vae import *
from .beta_vae import *
from .wae_mmd import *
from .cvae import *

# Aliases
VAE = VanillaVAE
GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
Binary file added models/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/base.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/base.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/beta_vae.cpython-36.pyc
Binary file not shown.
Binary file added models/__pycache__/beta_vae.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/gamma_vae.cpython-36.pyc
Binary file not shown.
Binary file added models/__pycache__/gamma_vae.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/types_.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/types_.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/vanilla_vae.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/vanilla_vae.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/wae_mmd.cpython-36.pyc
Binary file not shown.
Loading

0 comments on commit 291a37d

Please sign in to comment.