Added VQ VAE
AntixK committed Feb 12, 2020
1 parent 4cac7e7 commit 99a26ef
Showing 5 changed files with 315 additions and 8 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -79,15 +79,15 @@ logging_params:
| Model | Paper |Reconstruction | Samples |
- |----------------------------------------------------------------------|-------------------------------------------------|---------------|---------|
+ |----------------------------------------------------------------------|--------------------------------------------------|---------------|---------|
| VAE ([Code][vae_code], [Config][vae_config]) |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] |
| Conditional VAE ([Code][cvae_code], [Config][cvae_config]) |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] |
| WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] |
| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] |
| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
- | MIWAE (*K = 5, M = 5*) ([Code][miwae_code], [Config][miwae_config])  |[Link](https://arxiv.org/abs/1802.04537)          | ![][30]       | ![][29] |
+ | MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config])  |[Link](https://arxiv.org/abs/1802.04537)          | ![][30]       | ![][29] |
| DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
| MSSIM VAE ([Code][mssimvae_code], [Config][mssimvae_config]) |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE ([Code][catvae_code], [Config][catvae_config]) |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
@@ -114,19 +114,22 @@ logging_params:
- [x] InfoVAE
- [x] LogCosh VAE
- [x] SWAE
- - [ ] Ladder VAE (in progress)
- - [ ] Gamma VAE (in progress)
- - [ ] Vamp VAE (in progress)
- - [ ] HVAE (VAE with Vamp Prior) (in progress)
- - [ ] VQVAE (in progress)
+ - [ ] Ladder VAE (Doesn't work well)
+ - [ ] Gamma VAE (Doesn't work well)
+ - [ ] Vamp VAE (Doesn't work well)
+ - [ ] Beta TC-VAE
+ - [ ] PixelVAE
+ - [ ] VQVAE
### Contributing
If you have trained a better model using these implementations by fine-tuning the hyper-params in the config file,
I would be happy to include your result (along with your config file) in this repo, citing your name 😊.
+ Additionally, if you would like to contribute some models, check out the **TODO** list for models that are pending or
+ don't seem to work well.
### License
**Apache License 2.0**
26 changes: 26 additions & 0 deletions configs/vq_vae.yaml
@@ -0,0 +1,26 @@
model_params:
name: 'VQVAE'
in_channels: 3
embedding_dim: 128
num_embeddings: 40
img_size: 64
beta: 0.25

exp_params:
dataset: celeba
data_path: "../../shared/momo/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: 1
  max_nb_epochs: 50 # older PyTorch-Lightning spelling of max_epochs; keeping both is redundant
  max_epochs: 30

logging_params:
save_dir: "logs/"
name: "VQVAE"
manual_seed: 1265
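A minimal sketch of how this config could drive the model (the yaml loading and **model_params unpacking follow the pattern used elsewhere in the repo; treat the exact entry point as an assumption):

import yaml
from models import VQVAE

with open('configs/vq_vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

# 'name' is only used for the alias lookup in models/__init__.py;
# the remaining keys match VQVAE.__init__ (and **kwargs absorbs extras).
model = VQVAE(**config['model_params'])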
4 changes: 3 additions & 1 deletion models/__init__.py
@@ -18,6 +18,7 @@
from .logcosh_vae import *
from .swae import *
from .miwae import *
+ from .vq_vae import *


# Aliases
@@ -43,4 +44,5 @@
'LVAE':LVAE,
'LogCoshVAE':LogCoshVAE,
'SWAE':SWAE,
- 'MIWAE':MIWAE}
+ 'MIWAE':MIWAE,
+ 'VQVAE':VQVAE}
234 changes: 234 additions & 0 deletions models/vq_vae.py
@@ -0,0 +1,234 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *

class VectorQuantizer(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int):
super(VectorQuantizer, self).__init__()
self.K = num_embeddings
self.D = embedding_dim

self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1. / self.K, 1. / self.K)

def forward(self, input: Tensor):
input = input.permute(0, 2, 3, 1) # [B x D x H x W] -> [B x H x W x D]
input_shape = input.shape
flat_input = input.contiguous().view(-1, self.D) # [BHW x D]

        # Squared L2 distance between each input vector and each embedding,
        # using the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 <x, e>
dist = torch.sum(flat_input**2, dim = 1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim = 1) - \
2 * torch.matmul(flat_input, self.embedding.weight.t()) # [BHW x K]

# Get the encoding that has the min distance
encoding_inds = torch.argmin(dist, dim = 1).view(-1, 1) # [BHW, 1]

# Convert to one-hot encodings
device = input.device
encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
encoding_one_hot.scatter_(1, encoding_inds, 1.) # [BHW x K]

# Quantize the input
quantized_input = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D]
quantized_input = quantized_input.view(input_shape) # [B x H x W x D]
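        # Note: the argmin + one-hot lookup above is non-differentiable, so
        # reconstruction gradients do not reach the encoder through this layer.
        # The VQ-VAE paper handles this with a straight-through estimator; a
        # minimal sketch (an addition, not part of this commit) would be, before
        # the final permute:
        #     quantized_input = input + (quantized_input - input).detach()
        # with the two VQ losses computed on the pre-straight-through tensors.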

return quantized_input.permute(0, 3, 1, 2) # [B x D x H x W]


class VQVAE(BaseVAE):

def __init__(self,
in_channels: int,
embedding_dim: int,
num_embeddings: int,
hidden_dims: List = None,
beta: float = 0.25,
img_size: int = 64,
**kwargs) -> None:
super(VQVAE, self).__init__()

self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.img_size = img_size
self.beta = beta

modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim

modules.append(
nn.Sequential(
nn.Conv2d(in_channels, embedding_dim,
kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(embedding_dim),
nn.LeakyReLU())
)
self.encoder = nn.Sequential(*modules)

self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim)

# Build Decoder
modules = []
modules.append(
nn.Sequential(
nn.ConvTranspose2d(embedding_dim,
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU())
)

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)
self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
nn.Tanh())
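        # With the default hidden_dims, the encoder stacks six stride-2 convs
        # (five hidden layers plus the embedding projection), reducing a 64x64
        # input to a very small latent grid; the decoder mirrors this with six
        # stride-2 transposed convs back up to 64x64.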

def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
return [result]

def decode(self, z: Tensor) -> Tensor:
"""
Maps the given latent codes
onto the image space.
:param z: (Tensor) [B x D x H x W]
:return: (Tensor) [B x C x H x W]
"""

result = self.decoder(z)
result = self.final_layer(result)
return result

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
Reparameterization trick to sample from N(mu, var) from
N(0,1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
encoding = self.encode(input)[0]
quantized_inputs = self.vq_layer(encoding)
return [self.decode(quantized_inputs), input, quantized_inputs, encoding]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VQ-VAE loss function:
        L = ||x - \hat{x}||^2 + ||sg[z_e] - e||^2 + \beta ||z_e - sg[e]||^2
        where sg[.] denotes the stop-gradient operator.
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        e = args[2]    # quantized latents
        Z_e = args[3]  # encoder output (pre-quantization)

        recons_loss = F.mse_loss(recons, input)

        # Compute the VQ losses: the commitment term pulls the encoder output
        # towards its chosen codebook vector, while the embedding term moves
        # the codebook vector towards the encoder output.
        commitment_loss = F.mse_loss(e.detach(), Z_e)
        embedding_loss = F.mse_loss(e, Z_e.detach())

        loss = recons_loss + embedding_loss + self.beta * commitment_loss
        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'Embedding_Loss': embedding_loss,
                'Commitment_Loss': commitment_loss}

def sample(self,
num_samples:int,
current_device: Union[int, str], **kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int)/(Str) Device to run the model
:return: (Tensor)
"""
# Get random encoding indices
sample_inds = torch.randint(self.embedding_dim,
(num_samples * self.img_size ** 2, 1),
device = current_device) # [SHW, 1]

# Convert to corresponding one-hot encodings
sample_one_hot = torch.zeros(sample_inds.size(0), self.num_embeddings).to(current_device)
sample_one_hot.scatter_(1, sample_inds, 1.) # [BHW x K]

# Quantize the input based on the learned embeddings
quantized_input = torch.matmul(sample_one_hot,
self.vq_layer.embedding.weight) # [BHW, D]
quantized_input = quantized_input.view(num_samples,
self.img_size,
self.img_size,
self.embedding_dim) # [B x H x W x D]

quantized_input = quantized_input.permute(0, 3, 1, 2) # [B x D x H x W]

samples = self.decode(quantized_input)
return samples

def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""

return self.forward(x)[0]
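For reference, a minimal sketch of a forward/loss round trip with this model (constructor values taken from configs/vq_vae.yaml; the call pattern mirrors the unit test below):

import torch
from models import VQVAE

model = VQVAE(in_channels=3, embedding_dim=128, num_embeddings=40)
x = torch.randn(4, 3, 64, 64)  # dummy batch of 64x64 RGB images
recons, inp, quantized, encoding = model(x)
loss = model.loss_function(recons, inp, quantized, encoding)
print(loss['loss'].item(), recons.shape)  # scalar total loss, torch.Size([4, 3, 64, 64])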
42 changes: 42 additions & 0 deletions tests/test_vq_vae.py
@@ -0,0 +1,42 @@
import torch
import unittest
from models import VQVAE
from torchsummary import summary


class TestVQVAE(unittest.TestCase):

def setUp(self) -> None:
# self.model2 = VAE(3, 10)
self.model = VQVAE(3, 10, 10)

def test_summary(self):
print(summary(self.model, (3, 64, 64), device='cpu'))
# print(summary(self.model2, (3, 64, 64), device='cpu'))

def test_forward(self):
x = torch.randn(16, 3, 64, 64)
y = self.model(x)
print("Model Output size:", y[0].size())
# print("Model2 Output size:", self.model2(x)[0].size())

def test_loss(self):
x = torch.randn(16, 3, 64, 64)

result = self.model(x)
loss = self.model.loss_function(*result, M_N = 0.005)
print(loss)

def test_sample(self):
self.model.cuda()
y = self.model.sample(144, 'cuda')
print(y.shape)

def test_generate(self):
x = torch.randn(16, 3, 64, 64)
y = self.model.generate(x)
print(y.shape)


if __name__ == '__main__':
unittest.main()
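These tests can be run from the repository root with the standard unittest runner (assuming the models package is importable from there; note that test_sample requires a CUDA device):

python -m unittest tests.test_vq_vae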
