Skip to content

Commit

Permalink
Added LogCosh VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Feb 6, 2020
1 parent 09c4451 commit bba5ac5
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 15 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ logging_params:
| Categorical VAE |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Joint VAE |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
| Info VAE |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] |
| LogCosh VAE |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
Expand All @@ -110,6 +111,7 @@ logging_params:
- [x] Joint VAE
- [x] Disentangled beta-VAE
- [x] InfoVAE
- [x] LogCosh VAE
- [ ] Gamma VAE (in progress)
- [ ] Vamp VAE (in progress)
- [ ] HVAE (VAE with Vamp Prior) (in progress)
Expand Down Expand Up @@ -172,6 +174,8 @@ I would be happy to include your result (along with your config file) in this re
[22]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_35.png
[23]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/InfoVAE_31.png
[24]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_InfoVAE_31.png
[25]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/LogCoshVAE_7.png
[26]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_LogCoshVAE_7.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
Expand Down
Binary file added assets/LogCoshVAE_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/recons_LogCoshVAE_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 25 additions & 0 deletions configs/logcosh_vae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
model_params:
  name: 'LogCoshVAE'
  in_channels: 3
  latent_dim: 128
  alpha: 10.0
  beta: 2.0

exp_params:
  dataset: celeba
  data_path: "../../shared/momo/Data/"
  img_size: 64
  batch_size: 144 # Better to have a square number
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95

trainer_params:
  gpus: 1
  max_nb_epochs: 50
  max_epochs: 50

logging_params:
  save_dir: "logs/"
  name: "LogCoshVAE"
  manual_seed: 1265
2 changes: 1 addition & 1 deletion configs/vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50
max_epochs: 50
max_epochs: 30

logging_params:
save_dir: "logs/"
Expand Down
18 changes: 7 additions & 11 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def sample_images(self):
f"recons_{self.logger.name}_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))
#
# vutils.save_image(test_input.data,
# f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
# f"real_img_{self.logger.name}_{self.current_epoch}.png",
# normalize=True,
# nrow=int(math.sqrt(self.params['batch_size'])))

vutils.save_image(test_input.data,
f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
f"real_img_{self.logger.name}_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))

samples = self.model.sample(self.params['batch_size'],
self.curr_device,
Expand All @@ -92,11 +92,6 @@ def sample_images(self):

del test_input, recons, samples

def backward(self, use_amp, loss, optimizer, optimizer_idx):
    """
    Runs the backward pass for ``loss``, keeping the autograd graph alive
    only when ``self.hold_graph`` is truthy and this is optimizer 0.

    NOTE(review): presumably an override of the trainer's backward hook --
    confirm against the framework version in use.

    :param use_amp: Accepted for signature compatibility; not used here
    :param loss: Object exposing ``backward(retain_graph=...)``
    :param optimizer: Accepted for signature compatibility; not used here
    :param optimizer_idx: (Int) Index of the optimizer being stepped
    """
    # A single call replaces the two branches that differed only in the flag;
    # bool() keeps the flag a strict True/False as in the original branches.
    keep_graph = bool(self.hold_graph and optimizer_idx == 0)
    loss.backward(retain_graph=keep_graph)

def configure_optimizers(self):

Expand Down Expand Up @@ -177,6 +172,7 @@ def data_transforms(self):
transforms.CenterCrop(148),
transforms.Resize(self.params['img_size']),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
SetRange])
else:
raise ValueError('Undefined dataset type')
Expand Down
4 changes: 3 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .info_vae import *
# from .twostage_vae import *
from .lvae import LVAE
from .logcosh_vae import *

# Aliases
VAE = VanillaVAE
Expand All @@ -35,4 +36,5 @@
'FactorVAE':FactorVAE,
'CategoricalVAE':CategoricalVAE,
'JointVAE':JointVAE,
'InfoVAE':InfoVAE}
'InfoVAE':InfoVAE,
'LogCoshVAE':LogCoshVAE}
2 changes: 1 addition & 1 deletion models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def forward(self, *inputs: Tensor) -> Tensor:
pass

@abstractmethod
def loss_function(self, *inputs: Any) -> Tensor:
def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
pass


Expand Down
182 changes: 182 additions & 0 deletions models/logcosh_vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import torch
import torch.nn.functional as F
from models import BaseVAE
from torch import nn
from .types_ import *


class LogCoshVAE(BaseVAE):
    """
    Convolutional VAE whose reconstruction term is a log-cosh loss,
    (1 / alpha) * log(cosh(alpha * (recons - input))), combined with a
    beta-weighted KL divergence to a standard normal prior.

    The encoder is a stack of stride-2 convolutions; the linear heads expect
    the encoder output to be 2 x 2 spatially (``hidden_dims[-1] * 4``
    features), so with the default five stages the input must be 64 x 64.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha: float = 100.,
                 beta: float = 10.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Number of channels of the input images
        :param latent_dim: (Int) Dimensionality of the latent code
        :param hidden_dims: (List) Channel widths of the encoder stages;
                            defaults to [32, 64, 128, 256, 512]
        :param alpha: (Float) Scale inside the log-cosh reconstruction loss
        :param beta: (Float) Weight applied to the KL term
        :param kwargs: Ignored; accepted so extra config keys do not fail
        """
        super(LogCoshVAE, self).__init__()

        self.latent_dim = latent_dim
        self.alpha = alpha
        self.beta = beta

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy so the caller's list is never modified.
            hidden_dims = list(hidden_dims)

        # Build Encoder: every stage halves the spatial resolution.
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder: mirror of the encoder, widest stage first.
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        decoder_dims = hidden_dims[::-1]

        modules = []
        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i],
                                       decoder_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # One last upsampling stage, then a projection to 3 output channels
        # (fixed at 3 regardless of in_channels) squashed by Tanh.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1],
                               decoder_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and log-variance components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # decoder_input emits (widest channel count) * 2 * 2 features, so the
        # channel count is recovered from the layer instead of assuming 512.
        channels = self.decoder_input.out_features // 4
        result = result.view(-1, channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        Encodes the input, samples a latent code and decodes it.
        :param input: (Tensor) [N x C x H x W]
        :return: (List[Tensor]) [reconstruction, input, mu, log_var]
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        The reconstruction term is (1 / alpha) * mean(log cosh(alpha * t))
        with t = recons - input, evaluated through the identity
        log cosh(x) = x + log(1 + exp(-2x)) - log 2.

        :param args: (recons, input, mu, log_var) as returned by forward()
        :param kwargs: must contain 'M_N', the weight applied to the KL term
        :return: (dict) with keys 'loss', 'Reconstruction_Loss' and 'KLD'
                 ('KLD' holds the negated KL term)
        """
        import math

        recons, target, mu, log_var = args[0], args[1], args[2], args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        scaled_err = self.alpha * (recons - target)

        # softplus(-2x) == log(1 + exp(-2x)) but does not overflow to inf
        # when x is a large negative number, unlike the explicit exp/log.
        log_cosh = scaled_err + F.softplus(-2. * scaled_err) - math.log(2.0)
        recons_loss = (1. / self.alpha) * log_cosh.mean()

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + self.beta * kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
7 changes: 6 additions & 1 deletion models/vanilla_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def sample(self,
z = torch.randn(num_samples,
self.latent_dim)

z = z.to(current_device)
z = z.to(current_device) #
# vutils.save_image(test_input.data,
# f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
# f"real_img_{self.logger.name}_{self.current_epoch}.png",
# normalize=True,
# nrow=int(math.sqrt(self.params['batch_size'])))

samples = self.decode(z)
return samples
Expand Down
32 changes: 32 additions & 0 deletions tests/test_logcosh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import unittest
from models import LogCoshVAE
from torchsummary import summary


class TestVAE(unittest.TestCase):
    """Smoke tests for LogCoshVAE on random 64 x 64 three-channel batches."""

    def setUp(self) -> None:
        # 3 input channels, 10 latent dimensions.
        self.model = LogCoshVAE(3, 10, alpha=10)

    def test_summary(self):
        """The model can be summarised layer by layer on CPU."""
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        """forward() returns a reconstruction shaped like its input."""
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())
        # Assert instead of only printing, so a shape regression fails the test.
        self.assertEqual(y[0].size(), x.size())
        self.assertEqual(tuple(y[2].size()), (16, 10))
        self.assertEqual(tuple(y[3].size()), (16, 10))

    def test_loss(self):
        """loss_function() yields finite values under the expected keys."""
        x = torch.rand(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N = 0.005)
        print(loss)
        self.assertEqual(set(loss), {'loss', 'Reconstruction_Loss', 'KLD'})
        for name, value in loss.items():
            self.assertTrue(torch.isfinite(value).all().item(), msg=name)


if __name__ == '__main__':
unittest.main()

0 comments on commit bba5ac5

Please sign in to comment.