From d282f67682e8a826b3d5d97e5180ced83d8dfe93 Mon Sep 17 00:00:00 2001 From: Anand Date: Mon, 27 Jan 2020 16:55:16 +0900 Subject: [PATCH] Added Joint VAE --- README.md | 9 +- configs/cat_vae.yaml | 11 +- configs/joint_vae.yaml | 36 +++ experiment.py | 2 +- models/__init__.py | 6 +- models/__pycache__/__init__.cpython-36.pyc | Bin 689 -> 729 bytes models/cat_vae.py | 4 +- models/joint_vae.py | 268 +++++++++++++++++++++ tests/test_cat_vae.py | 4 +- tests/test_joint_Vae.py | 38 +++ 10 files changed, 364 insertions(+), 14 deletions(-) create mode 100644 configs/joint_vae.yaml create mode 100644 models/joint_vae.py create mode 100644 tests/test_joint_Vae.py diff --git a/README.md b/README.md index c35a24e8..21f973f6 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,8 @@ logging_params: | IWAE (5 Samples) |[Link](https://arxiv.org/abs/1804.03599) | ![][10] | ![][9] | | DFCVAE |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] | | MSSIM VAE |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] | -| Categorical VAE (CIFAR10)|[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] | +| Categorical VAE |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] | +| Joint VAE |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] | @@ -104,6 +105,7 @@ logging_params: - [x] WAE-MMD - [x] Conditional VAE - [x] Categorical VAE (Gumbel-Softmax VAE) +- [x] Joint VAE - [ ] Gamma VAE (in progress) - [ ] Beta TC-VAE (in progress) - [ ] Vamp VAE (in progress) @@ -117,7 +119,7 @@ logging_params: - [ ] VQVAE - [ ] StyleVAE - [ ] Sequential VAE -- [ ] Joint VAE + ### Contributing If you have trained a better model, using these implementations, by fine-tuning the hyper-params in the config file, @@ -154,7 +156,8 @@ I would be happy to include your result (along with your config file) in this re [16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png [17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_20.png [18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_20.png - +[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_20.png +[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_20.png [python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg [python-url]: https://www.python.org/ diff --git a/configs/cat_vae.yaml b/configs/cat_vae.yaml index 23316896..e1cb3d0d 100644 --- a/configs/cat_vae.yaml +++ b/configs/cat_vae.yaml @@ -1,14 +1,15 @@ model_params: name: 'CategoricalVAE' in_channels: 3 - latent_dim: 256 - categorical_dim: 10 # Equal to Num classes + latent_dim: 512 + categorical_dim: 40 temperature: 0.5 anneal_rate: 0.00003 anneal_interval: 100 + alpha: 8.0 exp_params: - dataset: cifar10 + dataset: celeba data_path: "../../shared/Data/" img_size: 64 batch_size: 144 # Better to have a square number @@ -17,9 +18,9 @@ exp_params: scheduler_gamma: 0.95 trainer_params: - gpus: 1 + gpus: [1] max_nb_epochs: 50 - max_epochs: 250 + max_epochs: 50 logging_params: save_dir: "logs/" diff --git a/configs/joint_vae.yaml b/configs/joint_vae.yaml new file mode 100644 index 00000000..7c2daabf --- /dev/null +++ b/configs/joint_vae.yaml @@ -0,0 +1,36 @@ +model_params: + name: 'JointVAE' + in_channels: 3 + latent_dim: 512 + categorical_dim: 40 + latent_min_capacity: 0.0 + latent_max_capacity: 20.0 + latent_gamma: 30. + latent_num_iter: 25000 + categorical_min_capacity: 0.0 + categorical_max_capacity: 20.0 + categorical_gamma: 30. + categorical_num_iter: 25000 + temperature: 0.5 + anneal_rate: 0.00003 + anneal_interval: 100 + alpha: 20.0 + +exp_params: + dataset: celeba + data_path: "../../shared/Data/" + img_size: 64 + batch_size: 144 # Better to have a square number + LR: 0.005 + weight_decay: 0.0 + scheduler_gamma: 0.95 + +trainer_params: + gpus: [1] + max_nb_epochs: 50 + max_epochs: 50 + +logging_params: + save_dir: "logs/" + name: "JointVAE" + manual_seed: 1265 diff --git a/experiment.py b/experiment.py index 9a8b335d..67ea8bc3 100644 --- a/experiment.py +++ b/experiment.py @@ -205,7 +205,7 @@ def data_transforms(self): transforms.ToTensor(), SetRange]) - if self.params['dataset'] == 'cifar10': + elif self.params['dataset'] == 'cifar10': transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Lambda(lambda img: diff --git a/models/__init__.py b/models/__init__.py index 14c5e614..1a3ceed9 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -11,12 +11,13 @@ from .mssim_vae import MSSIMVAE from .fvae import * from .cat_vae import * +from .joint_vae import * # Aliases VAE = VanillaVAE GaussianVAE = VanillaVAE CVAE = ConditionalVAE -GUMBELVAE = CategoricalVAE +GumbelVAE = CategoricalVAE vae_models = {'VanillaVAE':VanillaVAE, 'WAE_MMD':WAE_MMD, @@ -29,4 +30,5 @@ 'DFCVAE':DFCVAE, 'MSSIMVAE':MSSIMVAE, 'FactorVAE':FactorVAE, - 'CategoricalVAE':CategoricalVAE} + 'CategoricalVAE':CategoricalVAE, + 'JointVAE':JointVAE} diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc index ca0e3751c5e5ccebf23abbd503145b3b2bd0d822..765a97dcc90de3590722b3768f105436aa04c807 100644 GIT binary patch delta 213 zcmdnUdXrVzn3tDJb&_6O88ZXJV+JI^2V^?{aq)_Y%J~*43@MB`d?5^}{80ixESM@3 zC7db}C7Q~gDwZmgDxNBlDw!&kDxE5m%AG2kDv~Of!aRpJN`B%2Sw`N8_eyzhad_ot z=9Ppwy83ApF delta 156 zcmcb~x{+1cn3tF9^HQxiD`p0U#|%h-7sz%1;^KJ|mGkvd`J(uNSRhp}N+?x0N+gvp zRWwyFRV-CJRU%a~RVr0Fl{-}?RXBxt4o{Tq#8a}2JQH7(GW%)DP0nI0p8SGQi#62O v$<=4FIFq*MEv~Y})cD-|l+>K!l?+9KKwF9gC&w_^GFnWY#UwL%2a`Mi*@!0^ diff --git a/models/cat_vae.py b/models/cat_vae.py index 629081ee..84212980 100644 --- a/models/cat_vae.py +++ b/models/cat_vae.py @@ -16,6 +16,7 @@ def __init__(self, temperature: float = 0.5, anneal_rate: float = 3e-5, anneal_interval: int = 100, # every 100 batches + alpha: float = 30., **kwargs) -> None: super(CategoricalVAE, self).__init__() @@ -25,6 +26,7 @@ def __init__(self, self.min_temp = temperature self.anneal_rate = anneal_rate self.anneal_interval = anneal_interval + self.alpha = alpha modules = [] if hidden_dims is None: @@ -171,7 +173,7 @@ def loss_function(self, kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0) # kld_weight = 1.2 - loss = 30. * recons_loss + kld_weight * kld_loss + loss = self.alpha * recons_loss + kld_weight * kld_loss return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} def sample(self, diff --git a/models/joint_vae.py b/models/joint_vae.py new file mode 100644 index 00000000..4192669a --- /dev/null +++ b/models/joint_vae.py @@ -0,0 +1,268 @@ +import torch +import numpy as np +from models import BaseVAE +from torch import nn +from torch.nn import functional as F +from .types_ import * + + +class JointVAE(BaseVAE): + num_iter = 1 + + def __init__(self, + in_channels: int, + latent_dim: int, + categorical_dim: int, + latent_min_capacity: float =0., + latent_max_capacity: float = 25., + latent_gamma: float = 30., + latent_num_iter: int = 25000, + categorical_min_capacity: float =0., + categorical_max_capacity: float = 25., + categorical_gamma: float = 30., + categorical_num_iter: int = 25000, + hidden_dims: List = None, + temperature: float = 0.5, + anneal_rate: float = 3e-5, + anneal_interval: int = 100, # every 100 batches + alpha: float = 30., + **kwargs) -> None: + super(JointVAE, self).__init__() + + self.latent_dim = latent_dim + self.categorical_dim = categorical_dim + self.temp = temperature + self.min_temp = temperature + self.anneal_rate = anneal_rate + self.anneal_interval = anneal_interval + self.alpha = alpha + + self.cont_min = latent_min_capacity + self.cont_max = latent_max_capacity + + self.disc_min = categorical_min_capacity + self.disc_max = categorical_max_capacity + + self.cont_gamma = latent_gamma + self.disc_gamma = categorical_gamma + + self.cont_iter = latent_num_iter + self.disc_iter = categorical_num_iter + + modules = [] + if hidden_dims is None: + hidden_dims = [32, 64, 128, 256, 512] + + # Build Encoder + for h_dim in hidden_dims: + modules.append( + nn.Sequential( + nn.Conv2d(in_channels, out_channels=h_dim, + kernel_size= 3, stride= 2, padding = 1), + nn.BatchNorm2d(h_dim), + nn.LeakyReLU()) + ) + in_channels = h_dim + + self.encoder = nn.Sequential(*modules) + self.fc_mu = nn.Linear(hidden_dims[-1]*4, self.latent_dim) + self.fc_var = nn.Linear(hidden_dims[-1]*4, self.latent_dim) + self.fc_z = nn.Linear(hidden_dims[-1]*4, self.categorical_dim) + + # Build Decoder + modules = [] + + self.decoder_input = nn.Linear(self.latent_dim + self.categorical_dim, + hidden_dims[-1] * 4) + + hidden_dims.reverse() + + for i in range(len(hidden_dims) - 1): + modules.append( + nn.Sequential( + nn.ConvTranspose2d(hidden_dims[i], + hidden_dims[i + 1], + kernel_size=3, + stride = 2, + padding=1, + output_padding=1), + nn.BatchNorm2d(hidden_dims[i + 1]), + nn.LeakyReLU()) + ) + + + + self.decoder = nn.Sequential(*modules) + + self.final_layer = nn.Sequential( + nn.ConvTranspose2d(hidden_dims[-1], + hidden_dims[-1], + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + nn.BatchNorm2d(hidden_dims[-1]), + nn.LeakyReLU(), + nn.Conv2d(hidden_dims[-1], out_channels= 3, + kernel_size= 3, padding= 1), + nn.Tanh()) + self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1))) + + def encode(self, input: Tensor) -> List[Tensor]: + """ + Encodes the input by passing through the encoder network + and returns the latent codes. + :param input: (Tensor) Input tensor to encoder [B x C x H x W] + :return: (Tensor) Latent code [B x D x Q] + """ + result = self.encoder(input) + result = torch.flatten(result, start_dim=1) + + # Split the result into mu and var components + # of the latent Gaussian distribution + mu = self.fc_mu(result) + log_var = self.fc_var(result) + z = self.fc_z(result) + z = z.view(-1, self.categorical_dim) + return [mu, log_var, z] + + def decode(self, z: Tensor) -> Tensor: + """ + Maps the given latent codes + onto the image space. + :param z: (Tensor) [B x D x Q] + :return: (Tensor) [B x C x H x W] + """ + result = self.decoder_input(z) + result = result.view(-1, 512, 2, 2) + result = self.decoder(result) + result = self.final_layer(result) + return result + + def reparameterize(self, + mu: Tensor, + log_var: Tensor, + q: Tensor, + eps:float = 1e-7) -> Tensor: + """ + Gumbel-softmax trick to sample from Categorical Distribution + :param mu: (Tensor) mean of the latent Gaussian [B x D] + :param log_var: (Tensor) Log variance of the latent Gaussian [B x D] + :param q: (Tensor) Categorical latent Codes [B x Q] + :return: (Tensor) [B x (D + Q)] + """ + + std = torch.exp(0.5 * log_var) + e = torch.randn_like(std) + z = e * std + mu + + # Sample from Gumbel + u = torch.rand_like(q) + g = - torch.log(- torch.log(u + eps) + eps) + + # Gumbel-Softmax sample + s = F.softmax((q + g) / self.temp, dim=-1) + s = s.view(-1, self.categorical_dim) + + return torch.cat([z, s], dim=1) + + + def forward(self, input: Tensor, **kwargs) -> List[Tensor]: + mu, log_var, q = self.encode(input) + z = self.reparameterize(mu, log_var, q) + return [self.decode(z), input, q, mu, log_var] + + def loss_function(self, + *args, + **kwargs) -> dict: + """ + Computes the VAE loss function. + KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} + :param args: + :param kwargs: + :return: + """ + recons = args[0] + input = args[1] + q = args[2] + mu = args[3] + log_var = args[4] + + q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities + + + kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset + batch_idx = kwargs['batch_idx'] + + # Anneal the temperature at regular intervals + if batch_idx % self.anneal_interval == 0 and self.training: + self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx), + self.min_temp) + + recons_loss =F.mse_loss(recons, input, reduction='mean') + + # Adaptively increase the discrinimator capacity + disc_curr = (self.disc_max - self.disc_min) * \ + self.num_iter/ float(self.disc_iter) + self.disc_min + disc_curr = min(disc_curr, np.log(self.categorical_dim)) + + # KL divergence between gumbel-softmax distribution + eps = 1e-7 + + # Entropy of the logits + h1 = q_p * torch.log(q_p + eps) + # Cross entropy with the categorical distribution + h2 = q_p * np.log(1. / self.categorical_dim + eps) + kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim =1), dim=0) + + # Compute Continuous loss + # Adaptively increase the continuous capacity + cont_curr = (self.cont_max - self.cont_min) * \ + self.num_iter/ float(self.cont_iter) + self.cont_min + cont_curr = min(cont_curr, self.cont_max) + + kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), + dim=1), + dim=0) + capacity_loss = self.disc_gamma * torch.abs(disc_curr - kld_disc_loss) + \ + self.cont_gamma * torch.abs(cont_curr - kld_cont_loss) + # kld_weight = 1.2 + loss = self.alpha * recons_loss + kld_weight * capacity_loss + + if self.training: + self.num_iter += 1 + return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'Capacity_Loss':capacity_loss} + + def sample(self, + num_samples:int, + current_device: int, **kwargs) -> Tensor: + """ + Samples from the latent space and return the corresponding + image space map. + :param num_samples: (Int) Number of samples + :param current_device: (Int) Device to run the model + :return: (Tensor) + """ + # [S x D] + z = torch.randn(num_samples, + self.latent_dim) + + M = num_samples + np_y = np.zeros((M, self.categorical_dim), dtype=np.float32) + np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1 + np_y = np.reshape(np_y, [M , self.categorical_dim]) + q = torch.from_numpy(np_y) + + # z = self.sampling_dist.sample((num_samples * self.latent_dim, )) + z = torch.cat([z, q], dim = 1).to(current_device) + samples = self.decode(z) + return samples + + def generate(self, x: Tensor, **kwargs) -> Tensor: + """ + Given an input image x, returns the reconstructed image + :param x: (Tensor) [B x C x H x W] + :return: (Tensor) [B x C x H x W] + """ + + return self.forward(x)[0] \ No newline at end of file diff --git a/tests/test_cat_vae.py b/tests/test_cat_vae.py index b89da842..30ab8032 100644 --- a/tests/test_cat_vae.py +++ b/tests/test_cat_vae.py @@ -1,6 +1,6 @@ import torch import unittest -from models import GUMBELVAE +from models import GumbelVAE from torchsummary import summary @@ -8,7 +8,7 @@ class TestVAE(unittest.TestCase): def setUp(self) -> None: # self.model2 = VAE(3, 10) - self.model = GUMBELVAE(3, 10) + self.model = GumbelVAE(3, 10) def test_summary(self): print(summary(self.model, (3, 64, 64), device='cpu')) diff --git a/tests/test_joint_Vae.py b/tests/test_joint_Vae.py new file mode 100644 index 00000000..541c1354 --- /dev/null +++ b/tests/test_joint_Vae.py @@ -0,0 +1,38 @@ +import torch +import unittest +from models import JointVAE +from torchsummary import summary + + +class TestVAE(unittest.TestCase): + + def setUp(self) -> None: + # self.model2 = VAE(3, 10) + self.model = JointVAE(3, 10, 40, 0.0) + + def test_summary(self): + print(summary(self.model, (3, 64, 64), device='cpu')) + # print(summary(self.model2, (3, 64, 64), device='cpu')) + + def test_forward(self): + x = torch.randn(16, 3, 64, 64) + y = self.model(x) + print("Model Output size:", y[0].size()) + # print("Model2 Output size:", self.model2(x)[0].size()) + + def test_loss(self): + x = torch.randn(128, 3, 64, 64) + + result = self.model(x) + loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5) + print(loss) + + + def test_sample(self): + self.model.cuda() + y = self.model.sample(144, 0) + print(y.shape) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file