
Updated assets
AntixK committed Jan 28, 2020
1 parent 91ee099 commit f658659
Showing 16 changed files with 298 additions and 10 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -90,6 +90,7 @@ logging_params:
| MSSIM VAE |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Joint VAE |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
| Info VAE |[Link](https://arxiv.org/abs/1706.02262) | ![][22] | ![][21] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
<!--| Disentangled Beta-VAE |[Link](https://arxiv.org/abs/1804.03599) | ![][10] | ![][9] |-->
@@ -106,12 +107,12 @@
- [x] Conditional VAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [ ] InfoVAE (in progress)
- [ ] Gamma VAE (in progress)
- [ ] Beta TC-VAE (in progress)
- [ ] Vamp VAE (in progress)
- [ ] HVAE (VAE with Vamp Prior) (in progress)
- [ ] FactorVAE (in progress)
- [ ] InfoVAE
- [ ] TwoStageVAE
- [ ] VAE-GAN
- [ ] VLAE
@@ -154,10 +155,10 @@ I would be happy to include your result (along with your config file) in this re
[14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png
[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png
[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_12.png
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_12.png
[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_6.png
[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_6.png
[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_49.png
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_49.png
[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_49.png
[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_49.png
[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
Binary file removed assets/CategoricalVAE_12.png
Binary file added assets/CategoricalVAE_49.png
Binary file added assets/JointVAE_49.png
Binary file removed assets/JointVAE_6.png
Binary file removed assets/recons_CategoricalVAE_12.png
Binary file added assets/recons_CategoricalVAE_49.png
Binary file added assets/recons_JointVAE_49.png
Binary file removed assets/recons_JointVAE_6.png
2 changes: 1 addition & 1 deletion configs/bhvae.yaml
@@ -3,7 +3,7 @@ model_params:
in_channels: 3
latent_dim: 128
loss_type: 'H'
gamma: 1000.0
gamma: 10.0
max_capacity: 25
Capacity_max_iter: 10000

2 changes: 1 addition & 1 deletion configs/cat_vae.yaml
@@ -6,7 +6,7 @@ model_params:
temperature: 0.5
anneal_rate: 0.00003
anneal_interval: 100
alpha: 8.0
alpha: 1.0

exp_params:
dataset: celeba
31 changes: 31 additions & 0 deletions configs/infovae.yaml
@@ -0,0 +1,31 @@
model_params:
name: 'InfoVAE'
in_channels: 3
latent_dim: 128
reg_weight: 110 # Lambda factor
kernel_type: 'imq'
alpha: -9.0

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: [3]
max_nb_epochs: 50
max_epochs: 50
gradient_clip_val: 0.8

logging_params:
save_dir: "logs/"
name: "InfoVAE"
manual_seed: 1265




4 changes: 2 additions & 2 deletions configs/wae_mmd_rbf.yaml
@@ -2,7 +2,7 @@ model_params:
name: 'WAE_MMD'
in_channels: 3
latent_dim: 128
reg_weight: 100
reg_weight: 1000
kernel_type: 'rbf'

exp_params:
@@ -15,7 +15,7 @@ exp_params:
scheduler_gamma: 0.95

trainer_params:
gpus: [1]
gpus: [2]
max_nb_epochs: 50
max_epochs: 50

4 changes: 3 additions & 1 deletion models/__init__.py
@@ -12,6 +12,7 @@
from .fvae import *
from .cat_vae import *
from .joint_vae import *
from .info_vae import *

# Aliases
VAE = VanillaVAE
@@ -31,4 +32,5 @@
'MSSIMVAE':MSSIMVAE,
'FactorVAE':FactorVAE,
'CategoricalVAE':CategoricalVAE,
'JointVAE':JointVAE}
'JointVAE':JointVAE,
'InfoVAE':InfoVAE}
Binary file modified models/__pycache__/__init__.cpython-36.pyc
254 changes: 254 additions & 0 deletions models/info_vae.py
@@ -0,0 +1,254 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class InfoVAE(BaseVAE):

def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
alpha: float = -0.5,
reg_weight: int = 100,
kernel_type: str = 'imq',
latent_var: float = 2.,
**kwargs) -> None:
super(InfoVAE, self).__init__()

self.latent_dim = latent_dim
self.reg_weight = reg_weight
self.kernel_type = kernel_type
self.z_var = latent_var

assert alpha <= 0, 'alpha must be negative or zero.'

self.alpha = alpha

modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim

self.encoder = nn.Sequential(*modules)
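        # For 64x64 inputs, five stride-2 convolutions leave a 2x2 feature map,
        # so the flattened encoder output has hidden_dims[-1] * 4 features.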
self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

# Build Decoder
modules = []

self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)



self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
nn.Tanh())

def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)

# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
log_var = self.fc_var(result)
return [mu, log_var]

def decode(self, z: Tensor) -> Tensor:
result = self.decoder_input(z)
result = result.view(-1, 512, 2, 2)
result = self.decoder(result)
result = self.final_layer(result)
return result

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
Reparameterization trick to sample from N(mu, var) from
N(0,1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return [self.decode(z), input, z, mu, log_var]

def loss_function(self,
*args,
**kwargs) -> dict:
recons = args[0]
input = args[1]
z = args[2]
mu = args[3]
log_var = args[4]

batch_size = input.size(0)
bias_corr = batch_size * (batch_size - 1)
kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
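        # Weighted InfoVAE objective (Zhao et al., 2017):
        #   recon + (1 - alpha) * KL(q(z|x) || p(z)) + (alpha + reg_weight - 1) * MMD(q(z), p(z))
        # where reg_weight is the lambda factor from the config; the MMD term is
        # rescaled by the batch-size bias correction computed above.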

        recons_loss = F.mse_loss(recons, input)
mmd_loss = self.compute_mmd(z)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

loss = recons_loss + \
(1. - self.alpha) * kld_weight * kld_loss + \
(self.alpha + self.reg_weight - 1.)/bias_corr * mmd_loss
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss, 'KLD':-kld_loss}

def compute_kernel(self,
x1: Tensor,
x2: Tensor) -> Tensor:
# Convert the tensors into row and column vectors
D = x1.size(1)
N = x1.size(0)

x1 = x1.unsqueeze(-2) # Make it into a column tensor
x2 = x2.unsqueeze(-3) # Make it into a row tensor

"""
Usually the below lines are not required, especially in our case,
but this is useful when x1 and x2 have different sizes
along the 0th dimension.
"""
x1 = x1.expand(N, N, D)
x2 = x2.expand(N, N, D)

if self.kernel_type == 'rbf':
result = self.compute_rbf(x1, x2)
elif self.kernel_type == 'imq':
result = self.compute_inv_mult_quad(x1, x2)
else:
raise ValueError('Undefined kernel type.')

return result


def compute_rbf(self,
x1: Tensor,
x2: Tensor,
eps: float = 1e-7) -> Tensor:
"""
Computes the RBF Kernel between x1 and x2.
:param x1: (Tensor)
:param x2: (Tensor)
:param eps: (Float)
:return:
"""
z_dim = x2.size(-1)
sigma = 2. * z_dim * self.z_var

result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
return result

def compute_inv_mult_quad(self,
x1: Tensor,
x2: Tensor,
eps: float = 1e-7) -> Tensor:
"""
Computes the Inverse Multi-Quadratics Kernel between x1 and x2,
given by
k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}
:param x1: (Tensor)
:param x2: (Tensor)
:param eps: (Float)
:return:
"""
z_dim = x2.size(-1)
C = 2 * z_dim * self.z_var
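        # The kernel scale C grows with the latent dimensionality and the prior variance z_var.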
        kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

# Exclude diagonal elements
result = kernel.sum() - kernel.diag().sum()

return result

def compute_mmd(self, z: Tensor) -> Tensor:
# Sample from prior (Gaussian) distribution
prior_z = torch.randn_like(z)
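        # Sample estimate of MMD^2(q(z), p(z)):
        #   E[k(z_p, z_p)] + E[k(z, z)] - 2 E[k(z_p, z)]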

prior_z__kernel = self.compute_kernel(prior_z, prior_z)
z__kernel = self.compute_kernel(z, z)
priorz_z__kernel = self.compute_kernel(prior_z, z)

mmd = prior_z__kernel.mean() + \
z__kernel.mean() - \
2 * priorz_z__kernel.mean()
return mmd

def sample(self,
num_samples:int,
current_device: int, **kwargs) -> Tensor:
"""
        Samples from the latent space and returns the corresponding
        image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples,
self.latent_dim)

z = z.to(current_device)

samples = self.decode(z)
return samples

def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""

return self.forward(x)[0]
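
For reference, a minimal usage sketch of the new model (hypothetical; it assumes the vae_models registry from models/__init__.py and the configs/infovae.yaml added in this commit, since the repository's experiment runner is not part of this diff):

import yaml
import torch
from models import vae_models

with open('configs/infovae.yaml', 'r') as f:
    config = yaml.safe_load(f)

# 'name' is passed through **kwargs and ignored by InfoVAE.__init__
model = vae_models[config['model_params']['name']](**config['model_params'])

x = torch.randn(16, 3, 64, 64)  # a fake batch of 64x64 RGB images
recons, inp, z, mu, log_var = model(x)

# M_N is the minibatch weight (batch_size / dataset_size); 0.005 is illustrative
losses = model.loss_function(recons, inp, z, mu, log_var, M_N=0.005)
print(losses['loss'], losses['Reconstruction_Loss'], losses['MMD'], losses['KLD'])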
