Added MIWAE
AntixK committed Feb 12, 2020
1 parent 55e654f commit 4cac7e7
Showing 9 changed files with 274 additions and 6 deletions.
12 changes: 9 additions & 3 deletions README.md
@@ -86,14 +86,15 @@ logging_params:
| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] |
- | IWAE (5 Samples) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
+ | IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
+ | MIWAE (*K = 5, M = 5*) ([Code][miwae_code], [Config][miwae_config]) |[Link](https://arxiv.org/abs/1802.04537) | ![][30] | ![][29] |
| DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
| MSSIM VAE ([Code][mssimvae_code], [Config][mssimvae_config]) |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE ([Code][catvae_code], [Config][catvae_config]) |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Joint VAE ([Code][jointvae_code], [Config][jointvae_config]) |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
| Info VAE ([Code][infovae_code], [Config][infovae_config]) |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] |
| LogCosh VAE ([Code][logcoshvae_code], [Config][logcoshvae_config]) |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] |
- | SWAE (50 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
+ | SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
@@ -104,14 +105,15 @@ logging_params:
- [x] DFC VAE
- [x] MSSIM VAE
- [x] IWAE
+ - [x] MIWAE
- [x] WAE-MMD
- [x] Conditional VAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [x] Disentangled beta-VAE
- [x] InfoVAE
- [x] LogCosh VAE
- - [ ] SWAE (in progress)
+ - [x] SWAE
- [ ] Ladder VAE (in progress)
- [ ] Gamma VAE (in progress)
- [ ] Vamp VAE (in progress)
@@ -155,6 +157,7 @@ I would be happy to include your result (along with your config file) in this re
[bvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
[wae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/wae_mmd.py
[iwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/iwae.py
+ [miwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/miwae.py
[swae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/swae.py
[jointvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/joint_vae.py
[dfcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dfcvae.py
@@ -170,6 +173,7 @@ I would be happy to include your result (along with your config file) in this re
[wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml
[wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml
[iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml
+ [miwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/miwae.yaml
[swae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/swae.yaml
[jointvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/joint_vae.yaml
[dfcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dfc_vae.yaml
@@ -206,6 +210,8 @@ I would be happy to include your result (along with your config file) in this re
[26]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_LogCoshVAE_49.png
[27]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/SWAE_49.png
[28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png
+ [29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png
+ [30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
Binary file added assets/MIWAE_29.png
Binary file added assets/recons_MIWAE_29.png
25 changes: 25 additions & 0 deletions configs/miwae.yaml
@@ -0,0 +1,25 @@
model_params:
name: 'MIWAE'
in_channels: 3
latent_dim: 128
num_samples: 5
num_estimates: 3

exp_params:
dataset: celeba
data_path: "../../shared/momo/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: 1
max_nb_epochs: 50
max_epochs: 30

logging_params:
save_dir: "logs/"
name: "MIWAE"
manual_seed: 1265
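
For orientation, a minimal sketch of how this config reaches the new class (assuming the alias dict added to `models/__init__.py` below is exposed as `vae_models`, which is how the repository's `run.py` resolves model names; the loading code itself is illustrative, not part of the commit):

```python
import yaml
from models import vae_models

with open('configs/miwae.yaml') as f:
    config = yaml.safe_load(f)

# 'name' picks MIWAE out of the alias dict; the remaining model_params
# (in_channels, latent_dim, num_samples, num_estimates) are forwarded to
# MIWAE.__init__, which absorbs unused keys via **kwargs.
model = vae_models[config['model_params']['name']](**config['model_params'])
```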
4 changes: 3 additions & 1 deletion models/__init__.py
@@ -17,6 +17,7 @@
from .lvae import LVAE
from .logcosh_vae import *
from .swae import *
+ from .miwae import *


# Aliases
@@ -41,4 +42,5 @@
'InfoVAE':InfoVAE,
'LVAE':LVAE,
'LogCoshVAE':LogCoshVAE,
- 'SWAE':SWAE}
+ 'SWAE':SWAE,
+ 'MIWAE':MIWAE}
4 changes: 2 additions & 2 deletions models/iwae.py
@@ -146,8 +146,8 @@ def loss_function(self,

kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

- log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss
- kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2)
+ log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]
+ kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) # [B x S]
# Get importance weights
log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data

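For reference, the objectives behind these two losses (the standard forms from the linked IWAE and MIWAE papers; $K$ importance samples, $M$ independent estimates):

$$
\mathcal{L}_{K} = \mathbb{E}_{z_{1:K} \sim q(z \mid x)}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, z_k)}{q(z_k \mid x)}\right],
\qquad
\mathcal{L}_{K,M} = \frac{1}{M} \sum_{m=1}^{M} \log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, z_{m,k})}{q(z_{m,k} \mid x)}
$$

In the code, `log_weight` stands in for the per-sample log-ratio (MSE reconstruction plus weighted KL), and the softmax over the sample dimension yields the normalized importance weights.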
192 changes: 192 additions & 0 deletions models/miwae.py
@@ -0,0 +1,192 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from torch.distributions import Normal


class MIWAE(BaseVAE):

def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
num_samples: int = 5,
num_estimates: int = 5,
**kwargs) -> None:
super(MIWAE, self).__init__()

self.latent_dim = latent_dim
self.num_samples = num_samples # K
self.num_estimates = num_estimates # M

modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim

self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


# Build Decoder
modules = []

self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)



self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels=3,
kernel_size=3, padding=1),
nn.Tanh())

def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)

# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
log_var = self.fc_var(result)

return [mu, log_var]

def decode(self, z: Tensor) -> Tensor:
"""
Maps the given latent codes of the M x S samples
onto the image space.
:param z: (Tensor) [B x M x S x D]
:return: (Tensor) [B x M x S x C x H x W]
"""
B, M, S, D = z.size()
z = z.view(-1, self.latent_dim) #[BMS x D]
result = self.decoder_input(z)
result = result.view(-1, 512, 2, 2)
result = self.decoder(result)
result = self.final_layer(result) #[BMS x C x H x W ]
result = result.view([B, M, S, result.size(-3), result.size(-2), result.size(-1)]) # [B x M x S x C x H x W]
return result

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
:param mu: (Tensor) Mean of the latent Gaussian [B x M x S x D]
:param logvar: (Tensor) Log variance of the latent Gaussian [B x M x S x D]
:return: (Tensor) [B x M x S x D]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
mu, log_var = self.encode(input)
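# Duplicate the posterior statistics across the M estimates and S samples,
# so a single decoder pass covers every estimate/sample pair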
mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
z = self.reparameterize(mu, log_var) # [B x M x S x D]
eps = (z - mu) / torch.exp(0.5 * log_var) # Recover the standard-normal noise: (z - mu) / std
return [self.decode(z), input, mu, log_var, z, eps]

def loss_function(self,
*args,
**kwargs) -> dict:
"""
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
:param args:
:param kwargs:
:return:
"""
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
z = args[4]
eps = args[5]

input = input.repeat(self.num_estimates,
self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W]

kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S]

kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S]
# Get importance weights
log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data

# Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
weight = F.softmax(log_weight, dim=-1) # [B x M x S]
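# MIWAE objective: each of the M estimates is a K-sample importance-weighted
# bound, formed below by pairing the normalized weights with the log-weights
# along the sample dim S; the loss then averages over M and over the batch.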

loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim=-2), dim=0)

return {'loss': loss, 'Reconstruction_Loss': log_p_x_z.mean(), 'KLD': -kld_loss.mean()}

def sample(self,
num_samples:int,
current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and returns the corresponding
image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples, 1, 1,
self.latent_dim)

z = z.to(current_device)

samples = self.decode(z).squeeze()
return samples

def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image.
Returns only the first reconstructed sample
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""

return self.forward(x)[0][:, 0, 0, :]
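
To make the tensor bookkeeping above concrete, a small usage sketch (not part of the commit; batch size and dimensions are illustrative):

```python
import torch
from models import MIWAE

# B=2 images, M=3 estimates, S=5 importance samples, D=10 latent dims
model = MIWAE(in_channels=3, latent_dim=10, num_samples=5, num_estimates=3)
x = torch.randn(2, 3, 64, 64)                # [B x C x H x W]
recons, inp, mu, log_var, z, eps = model(x)  # forward() returns six tensors
print(recons.shape)  # torch.Size([2, 3, 5, 3, 64, 64]) -> [B x M x S x C x H x W]
print(z.shape)       # torch.Size([2, 3, 5, 10])        -> [B x M x S x D]
loss = model.loss_function(recons, inp, mu, log_var, z, eps, M_N=0.005)
print(loss['loss'])  # scalar training loss
```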
1 change: 1 addition & 0 deletions run.py
@@ -47,6 +47,7 @@
log_save_interval=100,
train_percent_check=1.,
val_percent_check=1.,
+ num_sanity_val_steps=0,
early_stop_callback = False,
**config['trainer_params'])

42 changes: 42 additions & 0 deletions tests/test_miwae.py
@@ -0,0 +1,42 @@
import torch
import unittest
from models import MIWAE
from torchsummary import summary


class TestMIWAE(unittest.TestCase):

def setUp(self) -> None:
# self.model2 = VAE(3, 10)
self.model = MIWAE(3, 10)

def test_summary(self):
print(summary(self.model, (3, 64, 64), device='cpu'))
# print(summary(self.model2, (3, 64, 64), device='cpu'))

def test_forward(self):
x = torch.randn(16, 3, 64, 64)
y = self.model(x)
print("Model Output size:", y[0].size())
# print("Model2 Output size:", self.model2(x)[0].size())

def test_loss(self):
x = torch.randn(16, 3, 64, 64)

result = self.model(x)
loss = self.model.loss_function(*result, M_N = 0.005)
print(loss)

def test_sample(self):
self.model.cuda()
y = self.model.sample(144, 0)
print(y.shape)

def test_generate(self):
x = torch.randn(16, 3, 64, 64)
y = self.model.generate(x)
print(y.shape)


if __name__ == '__main__':
unittest.main()
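A note on running these tests: like the repository's other model tests, they assume the repo root is on the import path (for `from models import MIWAE`), and `test_sample` assumes a CUDA device is available, since it calls `self.model.cuda()` before drawing 144 samples (a square number, matching the grid-friendly `batch_size` advice in the config).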
