forked from chenxingqiang/PyTorch-VAE
Commit: This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 26 changed files with 294 additions and 29 deletions. Most of the changed files are binary and are not rendered; the two new text files are shown below.
New file: experiment config for CategoricalVAE (27 lines added):
model_params:
  name: 'CategoricalVAE'
  in_channels: 3
  num_classes: 40
  latent_dim: 128
  categorical_dim: 40  # Equal to num_classes
  temperature: 0.5
  anneal_rate: 3e-5
  anneal_interval: 100

exp_params:
  data_path: "../../shared/Data/"
  img_size: 64
  batch_size: 144  # Better to have a square number
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95

trainer_params:
  gpus: 1
  max_nb_epochs: 50

logging_params:
  save_dir: "logs/"
  name: "CategoricalVAE"
  manual_seed: 1265
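
For context, a minimal sketch of how a config like this is typically consumed: parse the YAML and unpack model_params into the model constructor. The path configs/cat_vae.yaml and the direct unpacking are assumptions for illustration; the repository's experiment runner may wire this up differently.

import yaml

# Hypothetical path for this config file.
with open('configs/cat_vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Keys the constructor does not declare (e.g. name, num_classes)
# are absorbed by its **kwargs.
model = CategoricalVAE(**config['model_params'])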
New file: the CategoricalVAE model implementation, inside the models package (200 lines added):
import torch
import numpy as np
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class CategoricalVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 categorical_dim: int = 40,  # Num classes
                 hidden_dims: List = None,
                 temperature: float = 0.5,
                 anneal_rate: float = 3e-5,
                 anneal_interval: int = 100,  # every 100 batches
                 **kwargs) -> None:
        super(CategoricalVAE, self).__init__()
        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.temp = temperature
        self.min_temp = temperature
        self.anneal_rate = anneal_rate
        self.anneal_interval = anneal_interval

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_z = nn.Linear(hidden_dims[-1] * 4,
                              self.latent_dim * self.categorical_dim)
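        # Shape note (assumes the 64x64 img_size from the config): five
        # stride-2 convolutions downsample 64x64 to 2x2, so the flattened
        # encoder output has hidden_dims[-1] * 2 * 2 = 512 * 4 features,
        # matching the in_features of fc_z above.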

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim,
                                       hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())
    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [B x C x H x W]
        :return: (Tensor) Latent code [B x D x Q]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Reshape the flat code into one Q-way categorical
        # distribution (logits) per latent dimension
        z = self.fc_z(result)
        z = z.view(-1, self.latent_dim, self.categorical_dim)
        return [z]
    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) Flattened latent code [B x D*Q]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
    def reparameterize(self, z: Tensor, eps: float = 1e-7) -> Tensor:
        """
        Gumbel-softmax trick to sample from the categorical distribution
        :param z: (Tensor) Latent codes (logits) [B x D x Q]
        :return: (Tensor) Relaxed one-hot samples, flattened to [B x D*Q]
        """
        # Sample Gumbel(0, 1) noise: g = -log(-log(u)) for u ~ Uniform(0, 1)
        u = torch.rand_like(z)
        g = - torch.log(- torch.log(u + eps) + eps)

        # Gumbel-Softmax sample
        s = F.softmax((z + g) / self.temp, dim=-1)
        s = s.view(-1, self.latent_dim * self.categorical_dim)
        return s
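    # Temperature note: as self.temp -> 0, softmax((z + g) / temp) sharpens
    # toward a one-hot argmax of the noisy logits, recovering discrete
    # categorical samples; large temperatures flatten it toward uniform.
    # E.g. logits (1.0, 2.0, 0.5) without noise give ~(0.32, 0.39, 0.29)
    # at temp = 5.0 but ~(0, 1, 0) at temp = 0.05.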
    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        q = self.encode(input)[0]
        z = self.reparameterize(q)
        return [self.decode(z), input, q]
    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss: reconstruction plus the KL divergence
        between the categorical posterior and a uniform prior,
        KL(q(z|x) || p(z)) = \sum q \log \frac{q}{1/Q},
        where Q is the number of categories (categorical_dim).
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        q = args[2]

        q_p = F.softmax(q, dim=-1)  # Convert the categorical codes into probabilities

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        batch_idx = kwargs['batch_idx']

        # Anneal the temperature at regular intervals
        # if batch_idx % self.anneal_interval == 0:
        #     self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
        #                            self.min_temp)

        recons_loss = F.mse_loss(recons, input)

        # KL divergence between the Gumbel-softmax posterior and the uniform prior
        eps = 1e-7

        # Entropy of the posterior probabilities
        h1 = q_p * torch.log(q_p + eps)

        # Cross entropy with the uniform categorical distribution
        h2 = q_p * np.log(1. / self.categorical_dim + eps)
        kld_loss = torch.mean(torch.sum(h1 - h2, dim=(1, 2)), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
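    # KL sanity check: the prior is uniform over Q = categorical_dim classes.
    # For a perfectly uniform posterior, h1 - h2 sums to zero and the KL
    # vanishes; for a one-hot posterior it reaches log(Q) (about 3.69 when
    # Q = 40) per latent dimension.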
    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        # The prior is categorical: draw a uniformly random one-hot code for
        # each latent dimension, so z matches the latent_dim * categorical_dim
        # input expected by decode(). (A Gaussian sample of shape
        # [num_samples, latent_dim] would not fit decoder_input.)
        idx = torch.randint(0, self.categorical_dim,
                            (num_samples, self.latent_dim))
        z = F.one_hot(idx, num_classes=self.categorical_dim).float()
        z = z.view(num_samples, self.latent_dim * self.categorical_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples
    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
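
To tie the new files together, a minimal usage sketch under the config above (CPU, random data; M_N is normally batch_size / dataset_size, and 0.005 here is just an illustrative weight):

import torch

model = CategoricalVAE(in_channels=3, latent_dim=128, categorical_dim=40)
x = torch.randn(16, 3, 64, 64)  # dummy batch at the config's img_size

recons, inp, q = model(x)  # forward returns [reconstruction, input, logits]
losses = model.loss_function(recons, inp, q, M_N=0.005, batch_idx=0)
print(losses['loss'], losses['Reconstruction_Loss'], losses['KLD'])

samples = model.sample(num_samples=4, current_device='cpu')
print(samples.shape)  # torch.Size([4, 3, 64, 64])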