forked from AntixK/PyTorch-VAE
Commit
Showing 7 changed files with 312 additions and 16 deletions.
@@ -0,0 +1,27 @@
model_params:
  name: 'BetaTCVAE'
  in_channels: 3
  latent_dim: 128
  anneal_steps: 100
  alpha: 1.
  beta: 0.5
  gamma: 1.

exp_params:
  dataset: celeba
  data_path: "../../shared/momo/Data/"
  img_size: 64
  batch_size: 144  # Better to have a square number
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.97

trainer_params:
  gpus: 1
  max_nb_epochs: 50  # legacy PyTorch Lightning spelling; recent versions read max_epochs below
  max_epochs: 30

logging_params:
  save_dir: "logs/"
  name: "BetaTCVAE"
  manual_seed: 1265
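
The entry point that consumes this config is not part of the diff; below is a minimal loading sketch (the file path and the **-expansion are assumptions, not code from the commit):

import yaml
from models import BetaTCVAE

with open('betatc_vae.yaml', 'r') as f:   # illustrative path
    config = yaml.safe_load(f)

# The 'name' key is absorbed by the model's **kwargs
model = BetaTCVAE(**config['model_params'])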
@@ -0,0 +1,223 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
import math


class BetaTCVAE(BaseVAE):
    num_iter = 0  # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 anneal_steps: int = 200,
                 alpha: float = 1.,
                 beta: float = 6.,
                 gamma: float = 1.,
                 **kwargs) -> None:
        super(BetaTCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.anneal_steps = anneal_steps

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # Five stride-2 convs reduce a 64x64 input to 2x2, hence the factor of 4
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Mean and log variance of the latent Gaussian
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)  # 512 == hidden_dims[-1] before reversal
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) using
        N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var, z]

    def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Computes the log pdf of a Gaussian with parameters mu and logvar,
        evaluated elementwise at x.
        :param x: (Tensor) Point at which the density is evaluated
        :param mu: (Tensor) Mean of the Gaussian
        :param logvar: (Tensor) Log variance of the Gaussian
        :return: (Tensor) Elementwise log density, same shape as x
        """
        norm = -0.5 * (math.log(2 * math.pi) + logvar)
        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
        return log_density

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the \beta-TC-VAE loss: reconstruction plus separately weighted
        mutual-information, total-correlation and dimension-wise-KL terms.
        The closed-form KL of the vanilla VAE,
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2},
        is left below as a commented-out reference.
        :param args: [recons, input, mu, log_var, z] as returned by forward()
        :param kwargs: Unused; kept for interface compatibility
        :return: (dict) The individual loss terms
        """
        if self.training:
            self.num_iter += 1

        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        recons_loss = F.mse_loss(recons, input)

        # log q(z|x), evaluated at the sampled z
        log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim=1)

        # log p(z) under the standard-normal prior
        zeros = torch.zeros_like(z)
        log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim=1)

        # Pairwise log densities: entry [i, j, d] is log q(z_d^(i) | x^(j))
        batch_size, latent_dim = z.shape
        mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
                                                mu.view(1, batch_size, latent_dim),
                                                log_var.view(1, batch_size, latent_dim))

        # Minibatch estimates of log q(z) and log \prod_j q(z_j); the additive
        # normalization constants are omitted and do not affect gradients
        log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
        log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

        kld_loss = (log_prod_q_z - log_p_z).mean()   # dimension-wise KL
        tc_loss = (log_q_z - log_prod_q_z).mean()    # total correlation
        mi_loss = (log_q_zx - log_q_z).mean()        # index-code mutual information

        # Closed-form KL (vanilla VAE), kept for reference:
        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        # Anneal the dimension-wise KL weight from 0 to 1 over anneal_steps
        anneal_rate = min(self.num_iter / self.anneal_steps, 1)
        loss = recons_loss + \
               self.alpha * mi_loss + \
               self.beta * tc_loss + \
               anneal_rate * self.gamma * kld_loss

        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'KLD': kld_loss,
                'TC_Loss': tc_loss,
                'MI_Loss': mi_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device on which to run the model
        :return: (Tensor) [num_samples x C x H x W]
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image.
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
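
For reference, the loss above implements the ELBO TC-decomposition of Chen et al. (2018, "Isolating Sources of Disentanglement in VAEs"): the single KL term of the vanilla VAE objective is split into index-code mutual information, total correlation, and dimension-wise KL, each with its own weight:

\mathcal{L} = \mathcal{L}_{\text{recon}}
  + \alpha \underbrace{I_q(x; z)}_{\texttt{mi\_loss}}
  + \beta \underbrace{\mathrm{KL}\big(q(z) \,\|\, \textstyle\prod_j q(z_j)\big)}_{\texttt{tc\_loss}}
  + \gamma_{\text{anneal}} \underbrace{\textstyle\sum_j \mathrm{KL}\big(q(z_j) \,\|\, p(z_j)\big)}_{\texttt{kld\_loss}}

Here log_density_gaussian evaluates \log \mathcal{N}(x;\, \mu,\, e^{\text{logvar}}) = -\tfrac{1}{2}\big(\log 2\pi + \text{logvar} + (x - \mu)^2 e^{-\text{logvar}}\big) elementwise. Both q(z) and \prod_j q(z_j) are intractable, so the code estimates them from the minibatch via logsumexp over the pairwise density matrix mat_log_q_z.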
@@ -7,6 +7,7 @@

class VanillaVAE(BaseVAE):


    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
@@ -0,0 +1,43 @@
import torch
import unittest
from models import BetaTCVAE
from torchsummary import summary


class TestBetaTCVAE(unittest.TestCase):

    def setUp(self) -> None:
        # self.model2 = VAE(3, 10)
        self.model = BetaTCVAE(3, 64, anneal_steps=100)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))
        # print(summary(self.model2, (3, 64, 64), device='cpu'))

    def test_forward(self):
        print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())
        # print("Model2 Output size:", self.model2(x)[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N=0.005)
        print(loss)

    def test_sample(self):
        self.model.cuda()  # requires a CUDA device
        y = self.model.sample(8, 'cuda')
        print(y.shape)

    def test_generate(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model.generate(x)
        print(y.shape)


if __name__ == '__main__':
    unittest.main()
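
For completeness, a minimal training-step sketch showing how the pieces fit together (not part of the commit; hyperparameters mirror the config above):

import torch
from models import BetaTCVAE

model = BetaTCVAE(in_channels=3, latent_dim=128, anneal_steps=100,
                  alpha=1., beta=0.5, gamma=1.)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

x = torch.randn(16, 3, 64, 64)            # stand-in for a batch of 64x64 CelebA crops
outputs = model(x)                        # [recons, input, mu, log_var, z]
losses = model.loss_function(*outputs)    # dict with loss, KLD, TC_Loss, MI_Loss

optimizer.zero_grad()
losses['loss'].backward()
optimizer.step()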