Commit eeca2b2: Updated assets

AntixK committed Jan 20, 2020
1 parent 5658086

Showing 24 changed files with 247 additions and 36 deletions.
Binary file modified assets/WAE_IMQ_15.png
Binary file removed assets/WAE_RBF_17.png
Binary file added assets/WAE_RBF_18.png
Binary file added assets/recons_WAE_IMQ_15.png
Binary file added assets/recons_WAE_RBF_19.png
4 changes: 2 additions & 2 deletions configs/wae_mmd_imq.yaml
@@ -13,13 +13,13 @@ exp_params:
   scheduler_gamma: 0.95
 
 trainer_params:
-  gpus: [1]
+  gpus: [2]
   max_nb_epochs: 50
 
 
 logging_params:
   save_dir: "logs/"
-  name: "WassersteinVAE"
+  name: "WassersteinVAE_IMQ"
   manual_seed: 1265
4 changes: 2 additions & 2 deletions configs/wae_mmd_rbf.yaml
@@ -13,13 +13,13 @@ exp_params:
   scheduler_gamma: 0.95
 
 trainer_params:
-  gpus: [2]
+  gpus: [1]
   max_nb_epochs: 50
 
 
 logging_params:
   save_dir: "logs/"
-  name: "WassersteinVAE"
+  name: "WassersteinVAE_RBF"
   manual_seed: 1265
29 changes: 16 additions & 13 deletions experiment.py
@@ -40,6 +40,8 @@ def training_step(self, batch, batch_idx):
 
     def validation_step(self, batch, batch_idx):
         real_img, labels = batch
+        self.curr_device = real_img.device
+
         results = self.forward(real_img, labels = labels)
         val_loss = self.model.loss_function(*results,
                                             M_N = self.params['batch_size']/ self.num_train_imgs)
@@ -53,18 +55,19 @@ def validation_end(self, outputs):
         return {'val_loss': avg_loss, 'log': tensorboard_logs}
 
     def sample_images(self):
-        z = torch.randn(self.params['batch_size'],
-                        self.model.latent_dim)
-
-        if self.on_gpu:
-            z = z.cuda(self.curr_device)
-
-        samples = self.model.decode(z).cpu()
-
-        vutils.save_image(samples.data,
-                          f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
-                          normalize=True,
-                          nrow=int(math.sqrt(self.params['batch_size'])))
+        # # samples = self.model.sample(self.params['batch_size'], self.curr_device).cpu()
+        # z = torch.randn(self.params['batch_size'],
+        #                 self.model.latent_dim)
+        #
+        # if self.on_gpu:
+        #     z = z.cuda(self.curr_device)
+        #
+        # samples = self.model.decode(z).cpu()
+        #
+        # vutils.save_image(samples.data,
+        #                   f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
+        #                   normalize=True,
+        #                   nrow=int(math.sqrt(self.params['batch_size'])))
 
         # Get sample reconstruction image
         test_input, _ = next(iter(self.sample_dataloader))
@@ -75,7 +78,7 @@ def sample_images(self):
                           f"{self.logger.save_dir}/{self.logger.name}/recons_{self.current_epoch}.png",
                           normalize=True,
                           nrow=int(math.sqrt(self.params['batch_size'])))
-        del test_input, recons, samples, z
+        del test_input, recons #, samples, z
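Note: the first commented line in the new block, `# # samples = self.model.sample(self.params['batch_size'], self.curr_device).cpu()`, hints at where this refactor is headed: sampling moves out of the experiment and into each model via the `sample(batch_size, current_device)` hook added to `models/base.py` below. A minimal sketch of how `sample_images` could look once that hook is live; the save path and `vutils.save_image` arguments mirror the existing code, the rest is an assumption about the intended follow-up, not code from this commit:

    # Hypothetical follow-up (not in this commit): delegate sampling to the model.
    # Assumes self.curr_device was cached in validation_step, as added above.
    def sample_images(self):
        samples = self.model.sample(self.params['batch_size'],
                                    self.curr_device).cpu()
        vutils.save_image(samples.data,
                          f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
                          normalize=True,
                          nrow=int(math.sqrt(self.params['batch_size'])))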
1 change: 1 addition & 0 deletions models/__init__.py
@@ -5,6 +5,7 @@
 from .wae_mmd import *
 from .cvae import *
 from .hvae import *
+from .vampvae import *
 
 # Aliases
 VAE = VanillaVAE
Binary file modified models/__pycache__/__init__.cpython-36.pyc
Binary file modified models/__pycache__/base.cpython-36.pyc
Binary file modified models/__pycache__/beta_vae.cpython-36.pyc
Binary file modified models/__pycache__/gamma_vae.cpython-36.pyc
Binary file modified models/__pycache__/types_.cpython-36.pyc
Binary file modified models/__pycache__/vanilla_vae.cpython-36.pyc
Binary file modified models/__pycache__/wae_mmd.cpython-36.pyc
3 changes: 3 additions & 0 deletions models/base.py
@@ -13,6 +13,9 @@ def encode(self, input: Tensor) -> List[Tensor]:
     def decode(self, input: Tensor) -> Any:
         raise NotImplementedError
 
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        raise NotImplementedError
+
     @abstractmethod
     def forward(self, *inputs: Tensor) -> Tensor:
         pass
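The concrete `sample` bodies added in the model files below copy the old `sample_images` almost verbatim, so they still reference `self.on_gpu` and `self.model.decode`, which are attributes of the Lightning experiment, not of a `BaseVAE` subclass (a subclass defines `decode` directly and has no `on_gpu` flag). A self-contained variant, offered as an assumption about the intended behaviour rather than what this commit ships:

    # Sketch of a device-safe sample() for a BaseVAE subclass (an assumption,
    # not this commit's code): draw z from the standard-normal prior, decode it.
    def sample(self, batch_size: int, current_device: int) -> Tensor:
        z = torch.randn(batch_size, self.latent_dim)
        z = z.to(current_device)   # works for CPU and GPU devices alike
        return self.decode(z)      # subclasses implement decode()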
12 changes: 11 additions & 1 deletion models/beta_vae.py
@@ -129,4 +129,14 @@ def loss_function(self,
         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
 
         loss = recons_loss + self.beta * kld_loss
-        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':kld_loss}
+        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':kld_loss}
+
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        z = torch.randn(batch_size,
+                        self.latent_dim)
+
+        if self.on_gpu:
+            z = z.cuda(current_device)
+
+        samples = self.model.decode(z)
+        return samples
12 changes: 11 additions & 1 deletion models/cvae.py
@@ -144,4 +144,14 @@ def loss_function(self,
         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
 
         loss = recons_loss + kld_weight * kld_loss
-        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
+        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
+
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        z = torch.randn(batch_size,
+                        self.latent_dim)
+
+        if self.on_gpu:
+            z = z.cuda(current_device)
+
+        samples = self.model.decode(z)
+        return samples
10 changes: 10 additions & 0 deletions models/gamma_vae.py
@@ -129,3 +129,13 @@ def loss_function(self,
 
         loss = recons_loss + kld_weight * kld_loss
         return {'loss': loss, 'Reconstruction Loss': recons_loss, 'KLD': -kld_loss}
+
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        z = torch.randn(batch_size,
+                        self.latent_dim)
+
+        if self.on_gpu:
+            z = z.cuda(current_device)
+
+        samples = self.model.decode(z)
+        return samples
43 changes: 28 additions & 15 deletions models/hvae.py
@@ -105,12 +105,12 @@ def __init__(self,
 
         # ========================================================================#
         # Pseudo Input for the Vamp-Prior
-        self.pseudo_input = torch.eye(pseudo_input_size,
-                                      requires_grad=False).view(1, 1, pseudo_input_size, -1)
-
-
-        self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels,
-                                      kernel_size=3, stride=2, padding=1)
+        # self.pseudo_input = torch.eye(pseudo_input_size,
+        #                               requires_grad=False).view(1, 1, pseudo_input_size, -1)
+        #
+        #
+        # self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels,
+        #                               kernel_size=3, stride=2, padding=1)
 
     def encode_z2(self, input: Tensor) -> List[Tensor]:
         """
@@ -209,9 +209,6 @@ def loss_function(self,
         z1_p_mu = self.recons_z1_mu(z2)
         z1_p_log_var = self.recons_z1_log_var(z2)
 
-        # Compute the z2 for the pseudo_inputs
-        x = self.pseudo_layer(self.pseudo_input)
-        z2_p_mu, z2_p_log_var = self.encode_z2(x)
 
         kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
         recons_loss =F.mse_loss(recons, input)
@@ -225,13 +222,29 @@
                                               dim = 1),
                                    dim = 0)
 
-        z2_p_kld = -0.5 * torch.sum(1 + z2_p_log_var - (z2 - z2_p_mu) ** 2 - z2_p_log_var.exp(),
-                                    dim = 1)
+        z2_p_kld = torch.mean(-0.5*(z2**2), dim = 0)
 
-        z2_p_kld = torch.logsumexp(z2_p_kld, dim=0)
-        # z2_p_kld = torch.mean(z2_p_kld, dim = 0)
-        kld_loss = -(z1_p_kld - z1_kld - z2_kld)
-        loss = recons_loss + kld_weight * kld_loss
+        # print(z2_p_kld)
+
+        kld_loss = -(z1_p_kld + z2_p_kld - z1_kld - z2_kld)
+        loss = recons_loss + kld_weight * kld_loss
         return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
+
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        z2 = torch.randn(batch_size,
+                         self.latent2_dim)
+
+        z2 = z2.cuda(current_device)
+
+        z1_mu = self.recons_z1_mu(z2)
+        z1_log_var = self.recons_z1_log_var(z2)
+        z1 = self.reparameterize(z1_mu, z1_log_var)
+
+        debedded_z1 = self.debed_z1_code(z1)
+        debedded_z2 = self.debed_z2_code(z2)
+
+        result = torch.cat([debedded_z1, debedded_z2], dim=1)
+        result = result.view(-1, 512, 2, 2)
+        samples = self.decode(result)
+
+        return samples
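A hedged reading of the loss change: the removed lines scored z2 against pseudo-input encodings (the VampPrior machinery now commented out in `__init__` above), while the new `z2_p_kld = torch.mean(-0.5*(z2**2), dim = 0)` is, up to an additive constant, the log-density of a standard normal prior at the sampled z2. The objective then matches the usual two-level hierarchical ELBO KL term with p(z2) = N(0, I):

    % KL term implied by kld_loss = -(z1_p_kld + z2_p_kld - z1_kld - z2_kld),
    % additive constants dropped
    \mathcal{L}_{\mathrm{KL}}
      = \mathbb{E}_{q}\big[ \log q(z_1 \mid x) + \log q(z_2 \mid x)
                            - \log p(z_1 \mid z_2) - \log p(z_2) \big],
    \qquad p(z_2) = \mathcal{N}(0, I).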
141 changes: 141 additions & 0 deletions models/vampvae.py
@@ -0,0 +1,141 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class VampVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super(VampVAE, self).__init__()

        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                      kernel_size= 3, padding= 1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Will a single z be enough to compute the expectation
        for the loss??
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log variance of the latent Gaussian
        :return:
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self, batch_size:int, current_device: int) -> Tensor:
        z = torch.randn(batch_size,
                        self.latent_dim)

        if self.on_gpu:
            z = z.cuda(current_device)

        samples = self.model.decode(z)
        return samples
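As committed, VampVAE is still the vanilla-VAE skeleton: `loss_function` and `sample` use a standard normal prior, and no pseudo-inputs are wired in yet despite the name. A quick smoke test under stated assumptions: the `hidden_dims[-1]*4` flattened encoder size implies 64x64 inputs, and `M_N` is the KL weight the experiment passes as `batch_size / num_train_imgs` (0.005 below is an arbitrary example value):

    # Sketch: exercising the new model end to end (assumes 64x64 RGB inputs,
    # which the hidden_dims[-1]*4 = 2048 flattened encoder size implies).
    import torch
    from models import VampVAE   # exported via the new line in models/__init__.py

    model = VampVAE(in_channels=3, latent_dim=128)
    x = torch.randn(16, 3, 64, 64)         # dummy batch
    recons, inp, mu, log_var = model(x)    # forward -> [recons, input, mu, log_var]
    losses = model.loss_function(recons, inp, mu, log_var, M_N=0.005)
    print(losses['loss'], losses['Reconstruction Loss'], losses['KLD'])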
12 changes: 11 additions & 1 deletion models/vanilla_vae.py
@@ -128,4 +128,14 @@ def loss_function(self,
         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
 
         loss = recons_loss + kld_weight * kld_loss
-        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
+        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
+
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        z = torch.randn(batch_size,
+                        self.latent_dim)
+
+        if self.on_gpu:
+            z = z.cuda(current_device)
+
+        samples = self.model.decode(z)
+        return samples
12 changes: 11 additions & 1 deletion models/wae_mmd.py
@@ -195,4 +195,14 @@ def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor:
         mmd = reg_weight * prior_z__kernel.mean() + \
               reg_weight * z__kernel.mean() - \
               2 * reg_weight * priorz_z__kernel.mean()
-        return mmd
+        return mmd
+
+    def sample(self, batch_size:int, current_device: int) -> Tensor:
+        z = torch.randn(batch_size,
+                        self.latent_dim)
+
+        if self.on_gpu:
+            z = z.cuda(current_device)
+
+        samples = self.model.decode(z)
+        return samples
