forked from AntixK/PyTorch-VAE
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
261 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
model_params: | ||
name: 'LogCoshVAE' | ||
in_channels: 3 | ||
latent_dim: 128 | ||
alpha: 10.0 | ||
beta: 2.0 | ||
|
||
exp_params: | ||
dataset: celeba | ||
data_path: "../../shared/momo/Data/" | ||
img_size: 64 | ||
batch_size: 144 # Better to have a square number | ||
LR: 0.005 | ||
weight_decay: 0.0 | ||
scheduler_gamma: 0.95 | ||
|
||
trainer_params: | ||
gpus: 1 | ||
max_nb_epochs: 50 | ||
max_epochs: 50 | ||
|
||
logging_params: | ||
save_dir: "logs/" | ||
name: "LogCoshVAE" | ||
manual_seed: 1265 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from models import BaseVAE | ||
from torch import nn | ||
from .types_ import * | ||
|
||
|
||
class LogCoshVAE(BaseVAE): | ||
|
||
def __init__(self, | ||
in_channels: int, | ||
latent_dim: int, | ||
hidden_dims: List = None, | ||
alpha: float = 100., | ||
beta: float = 10., | ||
**kwargs) -> None: | ||
super(LogCoshVAE, self).__init__() | ||
|
||
self.latent_dim = latent_dim | ||
self.alpha = alpha | ||
self.beta = beta | ||
|
||
modules = [] | ||
if hidden_dims is None: | ||
hidden_dims = [32, 64, 128, 256, 512] | ||
|
||
# Build Encoder | ||
for h_dim in hidden_dims: | ||
modules.append( | ||
nn.Sequential( | ||
nn.Conv2d(in_channels, out_channels=h_dim, | ||
kernel_size= 3, stride= 2, padding = 1), | ||
nn.BatchNorm2d(h_dim), | ||
nn.LeakyReLU()) | ||
) | ||
in_channels = h_dim | ||
|
||
self.encoder = nn.Sequential(*modules) | ||
self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) | ||
self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) | ||
|
||
|
||
# Build Decoder | ||
modules = [] | ||
|
||
self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) | ||
|
||
hidden_dims.reverse() | ||
|
||
for i in range(len(hidden_dims) - 1): | ||
modules.append( | ||
nn.Sequential( | ||
nn.ConvTranspose2d(hidden_dims[i], | ||
hidden_dims[i + 1], | ||
kernel_size=3, | ||
stride = 2, | ||
padding=1, | ||
output_padding=1), | ||
nn.BatchNorm2d(hidden_dims[i + 1]), | ||
nn.LeakyReLU()) | ||
) | ||
|
||
self.decoder = nn.Sequential(*modules) | ||
|
||
self.final_layer = nn.Sequential( | ||
nn.ConvTranspose2d(hidden_dims[-1], | ||
hidden_dims[-1], | ||
kernel_size=3, | ||
stride=2, | ||
padding=1, | ||
output_padding=1), | ||
nn.BatchNorm2d(hidden_dims[-1]), | ||
nn.LeakyReLU(), | ||
nn.Conv2d(hidden_dims[-1], out_channels= 3, | ||
kernel_size= 3, padding= 1), | ||
nn.Tanh()) | ||
|
||
def encode(self, input: Tensor) -> List[Tensor]: | ||
""" | ||
Encodes the input by passing through the encoder network | ||
and returns the latent codes. | ||
:param input: (Tensor) Input tensor to encoder [N x C x H x W] | ||
:return: (Tensor) List of latent codes | ||
""" | ||
result = self.encoder(input) | ||
result = torch.flatten(result, start_dim=1) | ||
|
||
# Split the result into mu and var components | ||
# of the latent Gaussian distribution | ||
mu = self.fc_mu(result) | ||
log_var = self.fc_var(result) | ||
|
||
return [mu, log_var] | ||
|
||
def decode(self, z: Tensor) -> Tensor: | ||
""" | ||
Maps the given latent codes | ||
onto the image space. | ||
:param z: (Tensor) [B x D] | ||
:return: (Tensor) [B x C x H x W] | ||
""" | ||
result = self.decoder_input(z) | ||
result = result.view(-1, 512, 2, 2) | ||
result = self.decoder(result) | ||
result = self.final_layer(result) | ||
return result | ||
|
||
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: | ||
""" | ||
Reparameterization trick to sample from N(mu, var) from | ||
N(0,1). | ||
:param mu: (Tensor) Mean of the latent Gaussian [B x D] | ||
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] | ||
:return: (Tensor) [B x D] | ||
""" | ||
std = torch.exp(0.5 * logvar) | ||
eps = torch.randn_like(std) | ||
return eps * std + mu | ||
|
||
def forward(self, input: Tensor, **kwargs) -> List[Tensor]: | ||
mu, log_var = self.encode(input) | ||
z = self.reparameterize(mu, log_var) | ||
return [self.decode(z), input, mu, log_var] | ||
|
||
def loss_function(self, | ||
*args, | ||
**kwargs) -> dict: | ||
""" | ||
Computes the VAE loss function. | ||
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} | ||
:param args: | ||
:param kwargs: | ||
:return: | ||
""" | ||
recons = args[0] | ||
input = args[1] | ||
mu = args[2] | ||
log_var = args[3] | ||
|
||
kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset | ||
t = recons - input | ||
# recons_loss = F.mse_loss(recons, input) | ||
# cosh = torch.cosh(self.alpha * t) | ||
# recons_loss = (1./self.alpha * torch.log(cosh)).mean() | ||
|
||
recons_loss = self.alpha * t + \ | ||
torch.log(1. + torch.exp(- 2 * self.alpha * t)) - \ | ||
torch.log(torch.tensor(2.0)) | ||
# print(self.alpha* t.max(), self.alpha*t.min()) | ||
recons_loss = (1. / self.alpha) * recons_loss.mean() | ||
|
||
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) | ||
|
||
loss = recons_loss + self.beta * kld_weight * kld_loss | ||
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} | ||
|
||
def sample(self, | ||
num_samples:int, | ||
current_device: int, **kwargs) -> Tensor: | ||
""" | ||
Samples from the latent space and return the corresponding | ||
image space map. | ||
:param num_samples: (Int) Number of samples | ||
:param current_device: (Int) Device to run the model | ||
:return: (Tensor) | ||
""" | ||
z = torch.randn(num_samples, | ||
self.latent_dim) | ||
|
||
z = z.to(current_device) | ||
|
||
samples = self.decode(z) | ||
return samples | ||
|
||
def generate(self, x: Tensor, **kwargs) -> Tensor: | ||
""" | ||
Given an input image x, returns the reconstructed image | ||
:param x: (Tensor) [B x C x H x W] | ||
:return: (Tensor) [B x C x H x W] | ||
""" | ||
|
||
return self.forward(x)[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
import unittest | ||
from models import LogCoshVAE | ||
from torchsummary import summary | ||
|
||
|
||
class TestVAE(unittest.TestCase): | ||
|
||
def setUp(self) -> None: | ||
# self.model2 = VAE(3, 10) | ||
self.model = LogCoshVAE(3, 10, alpha=10) | ||
|
||
def test_summary(self): | ||
print(summary(self.model, (3, 64, 64), device='cpu')) | ||
# print(summary(self.model2, (3, 64, 64), device='cpu')) | ||
|
||
def test_forward(self): | ||
x = torch.randn(16, 3, 64, 64) | ||
y = self.model(x) | ||
print("Model Output size:", y[0].size()) | ||
# print("Model2 Output size:", self.model2(x)[0].size()) | ||
|
||
def test_loss(self): | ||
x = torch.rand(16, 3, 64, 64) | ||
|
||
result = self.model(x) | ||
loss = self.model.loss_function(*result, M_N = 0.005) | ||
print(loss) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |