
Commit bb1f447: Changed VQVAE model
AntixK committed Feb 14, 2020 (1 parent: 330681d)
Showing 4 changed files with 46 additions and 125 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -17,7 +17,7 @@
A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with focus on reproducibility. The aim of this project is to provide
a quick and simple working example for many of the cool VAE models out there. All the models are trained on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
for consistency and comparison. The architectures of all the models are kept as similar as possible with the same layers, except for cases where the original paper necessitates
-a radically different architecture.
+a radically different architecture (e.g. VQ VAE uses residual layers and no Batch-Norm, unlike the other models).
Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.md#--results) of each model.

### Requirements
@@ -116,7 +116,7 @@ logging_params:
- [x] InfoVAE
- [x] LogCosh VAE
- [x] SWAE
-- [ ] VQVAE (in progress)
+- [x] VQVAE
- [ ] Ladder VAE (Doesn't work well)
- [ ] Gamma VAE (Doesn't work well)
- [ ] Vamp VAE (Doesn't work well)
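The "Residual layers" mentioned in the README note above are the building block of the reworked VQVAE further down. A minimal sketch of such a block, assuming the conv3x3 -> ReLU -> conv1x1 structure with an identity shortcut that the `ResidualLayer(in_channels, in_channels)` calls in `models/vq_vae.py` suggest (the exact definition lives in that file and is not part of this diff):

```python
import torch.nn as nn
from torch import Tensor

class ResidualLayer(nn.Module):
    """Sketch of a BatchNorm-free residual block with an identity shortcut."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.resblock = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False))

    def forward(self, input: Tensor) -> Tensor:
        # Shape-preserving, so blocks can be stacked freely in encoder/decoder.
        return input + self.resblock(input)
```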
6 changes: 3 additions & 3 deletions experiment.py
@@ -83,14 +83,14 @@ def sample_images(self):
try:
samples = self.model.sample(144,
self.curr_device,
-labels = test_label).cpu()
-vutils.save_image(samples.data,
+labels = test_label)
+vutils.save_image(samples.cpu().data,
f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
f"{self.logger.name}_{self.current_epoch}.png",
normalize=True,
nrow=12)
except:
-raise RuntimeWarning('No sampler for the VAE model is proviced. Continuing...')
+pass


del test_input, recons #, samples
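Two things change in `sample_images`: the host copy moves out of the `sample` call (samples stay on `self.curr_device` and are moved with `samples.cpu().data` only when saving), and a failed sampling attempt is now skipped rather than re-raised. A sketch of the resulting control flow, with a hypothetical free-standing `save_samples` helper standing in for the method and its logger-derived output path:

```python
import torchvision.utils as vutils

def save_samples(model, device, test_label, out_path="samples.png"):
    # Sketch of the post-commit behaviour of sample_images.
    try:
        samples = model.sample(144, device, labels=test_label)  # stays on `device`
        vutils.save_image(samples.cpu().data,  # copy to CPU only for writing
                          out_path, normalize=True, nrow=12)
    except Exception:
        pass  # models without a sampler raise a Warning subclass; skip silently
```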
2 changes: 1 addition & 1 deletion models/base.py
@@ -14,7 +14,7 @@ def decode(self, input: Tensor) -> Any:
raise NotImplementedError

def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
-raise NotImplementedError
+raise RuntimeWarning()

def generate(self, x: Tensor, **kwargs) -> Tensor:
raise NotImplementedError
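Raising `RuntimeWarning()` here is a deliberate pairing with the bare `except:` in `experiment.py` above: `RuntimeWarning` is an `Exception` subclass, so a model that does not override `sample` aborts the sampling step without crashing training, whereas `NotImplementedError` used to read as a programming error. A self-contained sketch of the pattern, assuming this is the intended contract:

```python
import torch.nn as nn
from torch import Tensor

class BaseVAE(nn.Module):
    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        # "No sampler" is a recoverable condition, not a bug, so raise a
        # Warning subclass that the experiment loop catches and ignores.
        raise RuntimeWarning()

class NoSamplerVAE(BaseVAE):
    pass  # inherits the warning-raising stub, as VQVAE does below

try:
    NoSamplerVAE().sample(144, 0)
except RuntimeWarning:
    pass  # mirrors the `except: pass` in experiment.py
```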
159 changes: 40 additions & 119 deletions models/vq_vae.py
@@ -54,59 +54,6 @@ def forward(self, latents: Tensor) -> Tensor:

return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss # [B x D x H x W]


-# class VectorQuantizer(nn.Module):
-# """
-# Reference:
-# [1] https://github.com/zalandoresearch/pytorch-vq-vae
-# """
-# def __init__(self,
-# num_embeddings: int,
-# embedding_dim: int,
-# beta: float=0.25):
-# super(VectorQuantizer, self).__init__()
-#
-# self.D = embedding_dim
-# self.K = num_embeddings
-#
-# self._embedding = nn.Embedding(self.K, self.D)
-# self._embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
-# self.beta = beta
-#
-# def forward(self, inputs):
-# # convert inputs from BCHW -> BHWC
-# inputs = inputs.permute(0, 2, 3, 1).contiguous()
-# input_shape = inputs.shape
-#
-# # Flatten input
-# flat_input = inputs.view(-1, self.D)
-#
-# # Calculate distances
-# distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
-# + torch.sum(self._embedding.weight ** 2, dim=1)
-# - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
-#
-# # Encoding
-# encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
-# encodings = torch.zeros(encoding_indices.shape[0], self.K, device=inputs.device)
-# encodings.scatter_(1, encoding_indices, 1)
-#
-# # Quantize and unflatten
-# quantized_input = torch.matmul(encodings, self._embedding.weight).view(input_shape)
-#
-# # Loss
-# commitment_loss = F.mse_loss(quantized_input.detach(), inputs)
-# embedding_loss = F.mse_loss(quantized_input, inputs.detach())
-# loss = embedding_loss + self.beta * commitment_loss
-#
-# # quantized = inputs + (quantized - inputs).detach()
-# # avg_probs = torch.mean(encodings, dim=0)
-# # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-#
-# # convert quantized from BHWC -> BCHW
-# return quantized_input.permute(0, 3, 1, 2).contiguous(), loss
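The `VectorQuantizer` that remains at the top of the file computes the same two codebook terms as this deleted reference implementation (`loss = embedding_loss + self.beta * commitment_loss`). For orientation, these are the non-reconstruction terms of the VQ-VAE objective of van den Oord et al., where sg[.] is the stop-gradient operator, z_e(x) the encoder output, and e the nearest codebook embedding:

```latex
\mathcal{L}
  = \log p\big(x \mid z_q(x)\big)
  + \big\lVert \mathrm{sg}\!\left[z_e(x)\right] - e \big\rVert_2^2
  + \beta \, \big\lVert z_e(x) - \mathrm{sg}\!\left[e\right] \big\rVert_2^2
```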


class ResidualLayer(nn.Module):

def __init__(self,
@@ -141,51 +88,59 @@ def __init__(self,
self.beta = beta

modules = []
-hidden_dims = [32, 64, 128, 256, 512]
+if hidden_dims is None:
+hidden_dims = [128, 256]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
-kernel_size= 3, stride= 2, padding = 1),
-nn.BatchNorm2d(h_dim),
+kernel_size=4, stride=2, padding=1),
nn.LeakyReLU())
)
in_channels = h_dim

-modules.append(ResidualLayer(in_channels, in_channels))
+modules.append(
+nn.Sequential(
+nn.Conv2d(in_channels, in_channels,
+kernel_size=3, stride=1, padding=1),
+nn.LeakyReLU())
+)

+for _ in range(6):
+modules.append(ResidualLayer(in_channels, in_channels))
+modules.append(nn.LeakyReLU())

modules.append(
nn.Sequential(
nn.Conv2d(in_channels, embedding_dim,
-kernel_size=3, stride=2, padding=1),
-nn.BatchNorm2d(embedding_dim),
+kernel_size=1, stride=1),
nn.LeakyReLU())
)
-# modules.append(ResidualLayer(embedding_dim, embedding_dim))

self.encoder = nn.Sequential(*modules)

-self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim)
+self.vq_layer = VectorQuantizer(num_embeddings,
+embedding_dim,
+self.beta)

# Build Decoder
modules = []
-# modules.append(ResidualLayer(embedding_dim, embedding_dim))

modules.append(
nn.Sequential(
-nn.ConvTranspose2d(embedding_dim,
-hidden_dims[-1],
-kernel_size=3,
-stride=2,
-padding=1,
-output_padding=1),
-nn.BatchNorm2d(hidden_dims[-1]),
+nn.Conv2d(embedding_dim,
+hidden_dims[-1],
+kernel_size=3,
+stride=1,
+padding=1),
nn.LeakyReLU())
)

-modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))
+for _ in range(6):
+modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))

+modules.append(nn.LeakyReLU())

hidden_dims.reverse()

@@ -194,26 +149,19 @@ def __init__(self,
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
-kernel_size=3,
-stride = 2,
-padding=1,
-output_padding=1),
-nn.BatchNorm2d(hidden_dims[i + 1]),
+kernel_size=4,
+stride=2,
+padding=1),
nn.LeakyReLU())
)

-modules.append(nn.Sequential(
-nn.ConvTranspose2d(hidden_dims[-1],
-hidden_dims[-1],
-kernel_size=3,
-stride=2,
-padding=1,
-output_padding=1),
-nn.BatchNorm2d(hidden_dims[-1]),
-nn.LeakyReLU(),
-nn.Conv2d(hidden_dims[-1], out_channels= 3,
-kernel_size= 3, padding=1),
-nn.Tanh()))
+modules.append(
+nn.Sequential(
+nn.ConvTranspose2d(hidden_dims[-1],
+out_channels=3,
+kernel_size=4,
+stride=2, padding=1),
+nn.Tanh()))

self.decoder = nn.Sequential(*modules)

@@ -262,37 +210,10 @@ def loss_function(self,
'Reconstruction_Loss': recons_loss,
'VQ_Loss':vq_loss}

-# def sample(self,
-# num_samples: int,
-# current_device: Union[int, str], **kwargs) -> Tensor:
-# """
-# Samples from the latent space and return the corresponding
-# image space map.
-# :param num_samples: (Int) Number of samples
-# :param current_device: (Int)/(Str) Device to run the model
-# :return: (Tensor)
-# """
-# # Get random encoding indices
-# sample_inds = torch.randint(self.embedding_dim,
-# (num_samples * self.img_size ** 2, 1)) # [SHW, 1]
-#
-# # Convert to corresponding one-hot encodings
-# sample_one_hot = torch.zeros(sample_inds.size(0), self.num_embeddings)
-# sample_one_hot.scatter_(1, sample_inds, 1.) # [BHW x K]
-#
-# # Quantize the input based on the learned embeddings
-# quantized_input = torch.matmul(sample_one_hot,
-# self.vq_layer.embedding.weight.detach().cpu()) # [BHW, D]
-# quantized_input = quantized_input.view(num_samples,
-# self.img_size,
-# self.img_size,
-# self.embedding_dim) # [B x H x W x D]
-#
-# quantized_input = input + (quantized_input - input).detach()
-# quantized_input = quantized_input.permute(0, 3, 1, 2) # [B x D x H x W]
-#
-# samples = self.decode(quantized_input.to(current_device))
-# return samples
+def sample(self,
+num_samples: int,
+current_device: Union[int, str], **kwargs) -> Tensor:
+raise Warning('VQVAE sampler is not implemented.')

def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
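Net effect of the rework: BatchNorm is gone, the five-stage stride-2 pyramid is replaced by two kernel-4/stride-2 downsampling convolutions plus stacks of residual layers, a 1x1 convolution projects to the embedding dimension, and the decoder mirrors this with kernel-4 transposed convolutions. A shape check of that data path, assuming 64x64 CelebA crops, the new default `hidden_dims = [128, 256]`, and a hypothetical `embedding_dim = 64` (the shape-preserving residual stacks are omitted):

```python
import torch
import torch.nn as nn

encode = nn.Sequential(
    nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
    nn.Conv2d(256, 64, kernel_size=1, stride=1))              # 1x1 projection to D
decode = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 16 -> 32
    nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),    # 32 -> 64
    nn.Tanh())

z = encode(torch.randn(1, 3, 64, 64))
print(z.shape)          # torch.Size([1, 64, 16, 16]) -- latent grid fed to the VQ layer
print(decode(z).shape)  # torch.Size([1, 3, 64, 64])
```

The 16x16 latent grid is what the `VectorQuantizer` discretizes. `sample` now raises instead of decoding random codebook indices (the approach the deleted block above attempted): drawing coherent images from a VQ-VAE requires a learned prior over the discrete codes, e.g. a PixelCNN, which this commit does not add.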
