Skip to content

Commit

Permalink
Added Gumbel VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Jan 24, 2020
1 parent a316df7 commit f71e454
Show file tree
Hide file tree
Showing 26 changed files with 294 additions and 29 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ logging_params:
| Model | Paper |Reconstruction | Samples |
|-----------------------|--------------------------------------------------|---------------|---------|
| VAE |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] |
| Conditional VAE |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] |
| Conditional VAE |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] |
| WAE - MMD (RBF Kernel)|[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] |
| WAE - MMD (IMQ Kernel)|[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
Expand All @@ -107,6 +107,7 @@ logging_params:
- [ ] Vamp VAE (in progress)
- [ ] HVAE (VAE with Vamp Prior) (in progress)
- [ ] FactorVAE (in progress)
- [ ] Categorical VAE (Gumbel-Softmax VAE)
- [ ] InfoVAE
- [ ] TwoStageVAE
- [ ] VAE-GAN
Expand Down Expand Up @@ -146,8 +147,8 @@ I would be happy to include your result (along with your config file) in this re
[12]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DFCVAE_49.png
[13]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MSSIMVAE_29.png
[14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png
[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_4.png
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_4.png
[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png
[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
Expand Down
Binary file added assets/ConditionalVAE_20.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed assets/ConditionalVAE_4.png
Binary file not shown.
Binary file added assets/recons_ConditionalVAE_20.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed assets/recons_ConditionalVAE_4.png
Binary file not shown.
27 changes: 27 additions & 0 deletions configs/cat_vae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Training configuration for the CategoricalVAE (Gumbel-Softmax VAE) model.
# The keys under model_params mirror the arguments of CategoricalVAE.__init__
# in models/cat_vae.py.
model_params:
# presumably looked up in models.vae_models to pick the model class -- confirm against the runner
name: 'CategoricalVAE'
in_channels: 3
# not a named argument of CategoricalVAE.__init__; presumably absorbed by its **kwargs -- TODO confirm
num_classes: 40
latent_dim: 128
categorical_dim: 40 # Equal to Num classes
# Gumbel-Softmax temperature (the `temperature` argument of CategoricalVAE)
temperature: 0.5
anneal_rate: 3e-5
# (sic) the spelling matches the `annela_interval` argument of
# CategoricalVAE.__init__; rename both together or not at all.
annela_interval: 100

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: 1
max_nb_epochs: 50


logging_params:
save_dir: "logs/"
name: "CategoricalVAE"
manual_seed: 1265
8 changes: 5 additions & 3 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ def training_step(self, batch, batch_idx, optimizer_idx = 0):
results = self.forward(real_img, labels = labels)

real_img2 = None

try:
# Required for factor VAE
if self.params['require_secondary_input']:
real_img2,_ = next(iter(self.sample_dataloader))
real_img2 = real_img.to(self.curr_device)
except:
pass

train_loss = self.model.loss_function(*results,
M_N = self.params['batch_size']/ self.num_train_imgs,
optimizer_idx=optimizer_idx,
secondary_input = real_img2)
secondary_input = real_img2,
batch_idx = batch_idx)

self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})

Expand All @@ -57,7 +58,8 @@ def validation_step(self, batch, batch_idx, optimizer_idx = 0):
results = self.forward(real_img, labels = labels)
val_loss = self.model.loss_function(*results,
M_N = self.params['batch_size']/ self.num_train_imgs,
optimizer_idx = optimizer_idx)
optimizer_idx = optimizer_idx,
batch_idx = batch_idx)

return val_loss

Expand Down
5 changes: 4 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from .dfcvae import *
from .mssim_vae import MSSIMVAE
from .fvae import *
from .cat_vae import *

# Aliases
VAE = VanillaVAE
GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
GUMBELVAE = CategoricalVAE

vae_models = {'VanillaVAE':VanillaVAE,
'WAE_MMD':WAE_MMD,
Expand All @@ -26,4 +28,5 @@
'IWAE':IWAE,
'DFCVAE':DFCVAE,
'MSSIMVAE':MSSIMVAE,
'FactorVAE':FactorVAE}
'FactorVAE':FactorVAE,
'CategoricalVAE':CategoricalVAE}
Binary file modified models/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/base.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/beta_vae.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/gamma_vae.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/vanilla_vae.cpython-36.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def encode(self, input: Tensor) -> List[Tensor]:
def decode(self, input: Tensor) -> Any:
raise NotImplementedError

def sample(self, batch_size:int, current_device: int) -> Tensor:
def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
raise NotImplementedError

def generate(self, x: Tensor) -> Tensor:
def generate(self, x: Tensor, **kwargs) -> Tensor:
raise NotImplementedError

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions models/beta_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def loss_function(self,

def sample(self,
num_samples:int,
current_device: int) -> Tensor:
current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
Expand All @@ -169,7 +169,7 @@ def sample(self,
samples = self.decode(z)
return samples

def generate(self, x: Tensor) -> Tensor:
def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
Expand Down
200 changes: 200 additions & 0 deletions models/cat_vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import torch
import numpy as np
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class CategoricalVAE(BaseVAE):
    """
    Convolutional VAE whose latent code is a set of categorical variables.

    The encoder produces logits for ``latent_dim`` categorical variables with
    ``categorical_dim`` classes each. Latent samples are drawn with the
    Gumbel-Softmax relaxation so that sampling stays differentiable.

    NOTE: the linear layers are sized for a 2 x 2 encoder feature map
    (``hidden_dims[-1] * 4``), e.g. 64 x 64 inputs with the default five
    stride-2 convolutions.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 categorical_dim: int = 40,  # Num classes
                 hidden_dims: List = None,
                 temperature: float = 0.5,
                 anneal_rate: float = 3e-5,
                 annela_interval: int = 100,  # every 100 batches
                 **kwargs) -> None:
        """
        Builds the encoder, the latent projection and the decoder.
        :param in_channels: (Int) Number of channels of the input images
        :param latent_dim: (Int) Number of categorical latent variables (D)
        :param categorical_dim: (Int) Number of classes per variable (Q)
        :param hidden_dims: (List) Output channels of each encoder stage;
                            defaults to [32, 64, 128, 256, 512]
        :param temperature: (Float) Gumbel-Softmax temperature
        :param anneal_rate: (Float) Stored as ``self.anneal_rate``
        :param annela_interval: (Int) Stored as ``self.anneal_interval``
                                (parameter spelling kept for compatibility)
        :param kwargs: Extra configuration entries; ignored
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.temp = temperature
        # NOTE(review): the floor equals the starting temperature, so any
        # annealing schedule clamped to ``min_temp`` could never lower it.
        self.min_temp = temperature
        self.anneal_rate = anneal_rate
        self.anneal_interval = annela_interval

        # Work on a copy so that a caller-supplied list is never mutated.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Channel count of the deepest feature map; decode() reshapes to it.
        self.bottleneck_channels = hidden_dims[-1]
        flat_latent_dim = self.latent_dim * self.categorical_dim

        # Build Encoder: every stage halves the spatial resolution.
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_z = nn.Linear(hidden_dims[-1] * 4, flat_latent_dim)

        # Build Decoder: mirror of the encoder.
        self.decoder_input = nn.Linear(flat_latent_dim, hidden_dims[-1] * 4)

        hidden_dims = hidden_dims[::-1]

        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # Last upsampling stage followed by a projection to 3 output channels.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the categorical logits.
        :param input: (Tensor) Input tensor to encoder [B x C x H x W]
        :return: (List[Tensor]) Single-element list with logits [B x D x Q]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Project to one logit per (latent variable, class) pair.
        z = self.fc_z(result)
        z = z.view(-1, self.latent_dim, self.categorical_dim)
        return [z]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) Flattened latent codes [B x (D * Q)]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Reshape to the 2 x 2 bottleneck feature map the decoder expects.
        result = result.view(-1, self.bottleneck_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, z: Tensor, eps: float = 1e-7) -> Tensor:
        """
        Gumbel-softmax trick to sample from Categorical Distribution
        :param z: (Tensor) Latent logits [B x D x Q]
        :param eps: (Float) Small constant keeping the logarithms finite
        :return: (Tensor) Relaxed one-hot samples, flattened to [B x (D * Q)]
        """
        # Sample Gumbel(0, 1) noise through the inverse CDF of a uniform draw.
        u = torch.rand_like(z)
        g = - torch.log(- torch.log(u + eps) + eps)

        # Gumbel-Softmax sample over the class axis.
        s = F.softmax((z + g) / self.temp, dim=-1)
        s = s.view(-1, self.latent_dim * self.categorical_dim)
        return s

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        Encodes, samples and decodes a batch of images.
        :param input: (Tensor) [B x C x H x W]
        :return: (List[Tensor]) [reconstruction, input, logits]
        """
        q = self.encode(input)[0]
        z = self.reparameterize(q)
        return [self.decode(z), input, q]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss: MSE reconstruction error plus the weighted
        KL divergence between the encoder's categorical distribution q and
        the uniform categorical prior,
        KL(q || U) = sum_k q_k * (log q_k - log(1 / Q)).
        :param args: (recons, input, q) as returned by ``forward``
        :param kwargs: Must contain 'M_N' (weight of the KL term); other
                       entries such as 'batch_idx' are accepted and ignored
        :return: (dict) 'loss', 'Reconstruction_Loss' and 'KLD'
                 ('KLD' holds the negated KL term)
        """
        recons, target, q = args[0], args[1], args[2]

        q_p = F.softmax(q, dim=-1)  # Convert the categorical codes into probabilities

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        # NOTE: temperature annealing is currently not applied here;
        # ``self.temp`` stays at its initial value.

        recons_loss = F.mse_loss(recons, target)

        eps = 1e-7

        # Negative entropy of q.
        h1 = q_p * torch.log(q_p + eps)

        # Cross term against the uniform categorical prior.
        h2 = q_p * np.log(1. / self.categorical_dim + eps)
        kld_loss = torch.mean(torch.sum(h1 - h2, dim=(1, 2)), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples one-hot codes from the uniform categorical prior and
        returns the corresponding image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor) [num_samples x C x H x W]
        """
        # Pick one class uniformly at random for every latent variable.
        class_idx = torch.randint(0, self.categorical_dim,
                                  (num_samples, self.latent_dim))
        z = F.one_hot(class_idx, num_classes=self.categorical_dim).float()

        # Flatten to the [B x (D * Q)] layout that decode() expects.
        z = z.view(num_samples, self.latent_dim * self.categorical_dim)
        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
4 changes: 2 additions & 2 deletions models/dfcvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def loss_function(self,

def sample(self,
num_samples:int,
current_device: int) -> Tensor:
current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
Expand All @@ -207,7 +207,7 @@ def sample(self,
samples = self.decode(z)
return samples

def generate(self, x: Tensor) -> Tensor:
def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
Expand Down
4 changes: 2 additions & 2 deletions models/fvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def loss_function(self,

def sample(self,
num_samples:int,
current_device: int) -> Tensor:
current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
Expand All @@ -220,7 +220,7 @@ def sample(self,
samples = self.decode(z)
return samples

def generate(self, x: Tensor) -> Tensor:
def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
Expand Down
4 changes: 2 additions & 2 deletions models/gamma_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def loss_function(self,

def sample(self,
num_samples:int,
current_device: int) -> Tensor:
current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
Expand All @@ -227,7 +227,7 @@ def sample(self,
samples = self.decode(z)
return samples

def generate(self, x: Tensor) -> Tensor:
def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
Expand Down
Loading

0 comments on commit f71e454

Please sign in to comment.