Commit
Added Beta TC VAE
AntixK committed Feb 20, 2020
1 parent d4978f5 commit 215bb79
Showing 7 changed files with 312 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -117,10 +117,10 @@ logging_params:
- [x] LogCosh VAE
- [x] SWAE
- [x] VQVAE
- [ ] Beta TC-VAE (in progress)
- [ ] Ladder VAE (Doesn't work well)
- [ ] Gamma VAE (Doesn't work well)
- [ ] Vamp VAE (Doesn't work well)
- [ ] Beta TC-VAE
- [ ] PixelVAE
27 changes: 27 additions & 0 deletions configs/betatc_vae.yaml
@@ -0,0 +1,27 @@
model_params:
  name: 'BetaTCVAE'
  in_channels: 3
  latent_dim: 128
  anneal_steps: 100
  alpha: 1.
  beta: 0.5
  gamma: 1.

exp_params:
  dataset: celeba
  data_path: "../../shared/momo/Data/"
  img_size: 64
  batch_size: 144  # Better to have a square number
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.97

trainer_params:
  gpus: 1
  max_nb_epochs: 50
  max_epochs: 30

logging_params:
  save_dir: "logs/"
  name: "BetaTCVAE"
  manual_seed: 1265
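
A quick way to sanity-check the new config is to parse it and read back the model parameters; a minimal sketch, assuming PyYAML is installed and the repository root is the working directory:

import yaml

with open('configs/betatc_vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

print(config['model_params']['name'])  # 'BetaTCVAE'
print(config['model_params']['beta'])  # 0.5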
30 changes: 16 additions & 14 deletions models/__init__.py
@@ -19,6 +19,7 @@
from .swae import *
from .miwae import *
from .vq_vae import *
from .betatc_vae import *


# Aliases
@@ -27,22 +28,23 @@
CVAE = ConditionalVAE
GumbelVAE = CategoricalVAE

vae_models = {'VanillaVAE':VanillaVAE,
              'WAE_MMD':WAE_MMD,
              'ConditionalVAE':ConditionalVAE,
              'BetaVAE':BetaVAE,
              'GammaVAE':GammaVAE,
              'HVAE':HVAE,
              'VampVAE':VampVAE,
vae_models = {'HVAE':HVAE,
              'LVAE':LVAE,
              'IWAE':IWAE,
              'SWAE':SWAE,
              'MIWAE':MIWAE,
              'VQVAE':VQVAE,
              'DFCVAE':DFCVAE,
              'BetaVAE':BetaVAE,
              'InfoVAE':InfoVAE,
              'WAE_MMD':WAE_MMD,
              'VampVAE': VampVAE,
              'GammaVAE':GammaVAE,
              'MSSIMVAE':MSSIMVAE,
              'FactorVAE':FactorVAE,
              'CategoricalVAE':CategoricalVAE,
              'JointVAE':JointVAE,
              'InfoVAE':InfoVAE,
              'LVAE':LVAE,
              'BetaTCVAE':BetaTCVAE,
              'FactorVAE':FactorVAE,
              'LogCoshVAE':LogCoshVAE,
              'SWAE':SWAE,
              'MIWAE':MIWAE,
              'VQVAE':VQVAE}
              'VanillaVAE':VanillaVAE,
              'ConditionalVAE':ConditionalVAE,
              'CategoricalVAE':CategoricalVAE}
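
With the registry entry in place, the model can be constructed by the name given in the config; a minimal sketch reusing the config parsed in the YAML sketch above:

model = vae_models[config['model_params']['name']](**config['model_params'])
# The extra 'name' key is absorbed by BetaTCVAE's **kwargs.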
223 changes: 223 additions & 0 deletions models/betatc_vae.py
@@ -0,0 +1,223 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
import math


class BetaTCVAE(BaseVAE):
    num_iter = 0  # Counter of training iterations, used to anneal the KL weight

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 anneal_steps: int = 200,
                 alpha: float = 1.,
                 beta: float = 6.,
                 gamma: float = 1.,
                 **kwargs) -> None:
        super(BetaTCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.anneal_steps = anneal_steps

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Mean and log variance of the latent Gaussian
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) using
        N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var, z]

    def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
        """
        Computes the log pdf of the Gaussian with parameters mu and logvar at x
        :param x: (Tensor) Point at which the Gaussian PDF is evaluated
        :param mu: (Tensor) Mean of the Gaussian distribution
        :param logvar: (Tensor) Log variance of the Gaussian distribution
        :return: (Tensor) Element-wise log density
        """
        norm = -0.5 * (math.log(2 * math.pi) + logvar)
        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
        return log_density
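
As a quick sanity check (not part of this commit), the density above can be compared against torch.distributions.Normal, whose scale parameter is exp(0.5 * logvar); the tensors and the model instance below are illustrative only:

x, mu, logvar = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
ref = torch.distributions.Normal(mu, torch.exp(0.5 * logvar)).log_prob(x)
model = BetaTCVAE(in_channels=3, latent_dim=8)
assert torch.allclose(model.log_density_gaussian(x, mu, logvar), ref, atol=1e-4)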

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the Beta-TC-VAE loss function, which decomposes the KL term
        of the ELBO into three parts:

        E_{p(x)}[KL(q(z|x) || p(z))] = I(x;z) + TC(z) + \sum_d KL(q(z_d) || p(z_d))

        i.e. index-code mutual information, total correlation, and
        dimension-wise KL, weighted by alpha, beta, and gamma respectively.
        :param args: outputs of forward(): [recons, input, mu, log_var, z]
        :param kwargs:
        :return:
        """
        if self.training:
            self.num_iter += 1

        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        recons_loss = F.mse_loss(recons, input)

        log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim=1)

        zeros = torch.zeros_like(z)
        log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim=1)

        batch_size, latent_dim = z.shape
        # mat_log_q_z[i, j, d] = log q(z_i^(d) | x_j), obtained by broadcasting
        # each latent sample against every posterior in the batch
        mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
                                                mu.view(1, batch_size, latent_dim),
                                                log_var.view(1, batch_size, latent_dim))

        # Monte Carlo estimates of log q(z) and log prod_d q(z_d); the -log(M*N)
        # weight of Minibatch Weighted Sampling is omitted here (see the sketch
        # after this function)
        log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
        log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

        kld_loss = (log_prod_q_z - log_p_z).mean()  # Dimension-wise KL
        tc_loss = (log_q_z - log_prod_q_z).mean()   # Total correlation
        mi_loss = (log_q_zx - log_q_z).mean()       # Index-code mutual information

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        # Linearly anneal the weight of the dimension-wise KL term from 0 to 1
        anneal_rate = min(self.num_iter / self.anneal_steps, 1.)
        loss = recons_loss + \
               self.alpha * mi_loss + \
               self.beta * tc_loss + \
               anneal_rate * self.gamma * kld_loss

        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'KLD': kld_loss,
                'TC_Loss': tc_loss,
                'MI_Loss': mi_loss}
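
The Beta-TC-VAE paper (Chen et al., 2018) estimates log q(z) with Minibatch Weighted Sampling, which subtracts log(M*N); without it, log_q_z and log_prod_q_z are shifted by different constants. A minimal sketch of the weighted estimator, assuming the caller knows the dataset size N; the function name is illustrative and not part of this commit:

import math
import torch

def mws_log_densities(mat_log_q_z: torch.Tensor, dataset_size: int):
    # Minibatch Weighted Sampling: each of the M minibatch samples is weighted
    # by 1 / (M * N) when estimating the aggregate posterior q(z).
    batch_size = mat_log_q_z.shape[0]
    log_mn = math.log(batch_size * dataset_size)
    log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1) - log_mn
    log_prod_q_z = (torch.logsumexp(mat_log_q_z, dim=1) - log_mn).sum(1)
    return log_q_z, log_prod_q_z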

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
1 change: 1 addition & 0 deletions models/vanilla_vae.py
@@ -7,6 +7,7 @@

class VanillaVAE(BaseVAE):


    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
43 changes: 43 additions & 0 deletions tests/test_betatcvae.py
@@ -0,0 +1,43 @@
import torch
import unittest
from models import BetaTCVAE
from torchsummary import summary


class TestBetaTCVAE(unittest.TestCase):

    def setUp(self) -> None:
        # self.model2 = VAE(3, 10)
        self.model = BetaTCVAE(3, 64, anneal_steps=100)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))
        # print(summary(self.model2, (3, 64, 64), device='cpu'))

    def test_forward(self):
        print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())
        # print("Model2 Output size:", self.model2(x)[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N=0.005)
        print(loss)

    def test_sample(self):
        self.model.cuda()
        y = self.model.sample(8, 'cuda')
        print(y.shape)

    def test_generate(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model.generate(x)
        print(y.shape)


if __name__ == '__main__':
    unittest.main()
2 changes: 1 addition & 1 deletion tests/test_vq_vae.py
@@ -4,7 +4,7 @@
from torchsummary import summary


class TestMIWAE(unittest.TestCase):
class TestVQVAE(unittest.TestCase):

    def setUp(self) -> None:
        # self.model2 = VAE(3, 10)
