From 330681d5b01126be50ee4d64433252591e50452e Mon Sep 17 00:00:00 2001 From: Anand Date: Fri, 14 Feb 2020 18:04:11 +0900 Subject: [PATCH] updated VQ VAE --- README.md | 41 ++++---- configs/vq_vae.yaml | 10 +- experiment.py | 27 ++--- models/vq_vae.py | 236 ++++++++++++++++++++++++++++--------------- run.py | 2 +- tests/test_vq_vae.py | 5 +- 6 files changed, 200 insertions(+), 121 deletions(-) diff --git a/README.md b/README.md index 94157d22..5048b18c 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,8 @@ A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with focus on reproducibility. The aim of this project is to provide a quick and simple working example for many of the cool VAE models out there. All the models are trained on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) -for consistency and comparison. The architecture of all the models are kept as similar as possible with the same layers, except for cases where the original paper necessitates a radically different architecture. +for consistency and comparison. The architecture of all the models are kept as similar as possible with the same layers, except for cases where the original paper necessitates +a radically different architecture. Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.md#--results) of each model. ### Requirements @@ -78,23 +79,24 @@ logging_params: -| Model | Paper |Reconstruction | Samples | -|----------------------------------------------------------------------|--------------------------------------------------|---------------|---------| -| VAE ([Code][vae_code], [Config][vae_config]) |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] | -| Conditional VAE ([Code][cvae_code], [Config][cvae_config]) |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] | -| WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] | -| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] | -| Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] | -| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] | -| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] | -| MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config]) |[Link](https://arxiv.org/abs/1802.04537) | ![][30] | ![][29] | -| DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] | -| MSSIM VAE ([Code][mssimvae_code], [Config][mssimvae_config]) |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] | -| Categorical VAE ([Code][catvae_code], [Config][catvae_config]) |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] | -| Joint VAE ([Code][jointvae_code], [Config][jointvae_config]) |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] | -| Info VAE ([Code][infovae_code], [Config][infovae_config]) |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] | -| LogCosh VAE ([Code][logcoshvae_code], [Config][logcoshvae_config]) |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] | -| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] | 
+| Model | Paper |Reconstruction | Samples |
+|------------------------------------------------------------------------|--------------------------------------------------|---------------|---------|
+| VAE ([Code][vae_code], [Config][vae_config]) |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] |
+| Conditional VAE ([Code][cvae_code], [Config][cvae_config]) |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] |
+| WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] |
+| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
+| Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
+| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] |
+| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
+| MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config]) |[Link](https://arxiv.org/abs/1802.04537) | ![][30] | ![][29] |
+| DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
+| MSSIM VAE ([Code][mssimvae_code], [Config][mssimvae_config]) |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
+| Categorical VAE ([Code][catvae_code], [Config][catvae_config]) |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
+| Joint VAE ([Code][jointvae_code], [Config][jointvae_config]) |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
+| Info VAE ([Code][infovae_code], [Config][infovae_config]) |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] |
+| LogCosh VAE ([Code][logcoshvae_code], [Config][logcoshvae_config]) |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] |
+| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
+| VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937) | ![][31] | **N/A** |
 
 
 
@@ -168,6 +170,7 @@ doesn't seem to work well.
 [logcoshvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/logcosh_vae.py
 [catvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cat_vae.py
 [infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py
+[vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
 
 [vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml
 [cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
@@ -184,6 +187,7 @@ doesn't seem to work well.
 [logcoshvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/logcosh_vae.yaml
 [catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml
 [infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml
+[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml
 
 [1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
 [2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png
@@ -215,6 +219,7 @@ doesn't seem to work well.
[28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png [29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png [30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png +[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png [python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg [python-url]: https://www.python.org/ diff --git a/configs/vq_vae.yaml b/configs/vq_vae.yaml index 43ac5f1c..43163c17 100644 --- a/configs/vq_vae.yaml +++ b/configs/vq_vae.yaml @@ -1,8 +1,8 @@ model_params: name: 'VQVAE' in_channels: 3 - embedding_dim: 128 - num_embeddings: 40 + embedding_dim: 64 + num_embeddings: 512 img_size: 64 beta: 0.25 @@ -11,12 +11,12 @@ exp_params: data_path: "../../shared/momo/Data/" img_size: 64 batch_size: 144 # Better to have a square number - LR: 0.005 + LR: 0.001 weight_decay: 0.0 - scheduler_gamma: 0.95 + scheduler_gamma: 0.0 trainer_params: - gpus: 1 + gpus: [2] max_nb_epochs: 50 max_epochs: 30 diff --git a/experiment.py b/experiment.py index 46c33450..67c78037 100644 --- a/experiment.py +++ b/experiment.py @@ -72,25 +72,28 @@ def sample_images(self): f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" f"recons_{self.logger.name}_{self.current_epoch}.png", normalize=True, - nrow=int(math.sqrt(self.params['batch_size']))) + nrow=12) # vutils.save_image(test_input.data, # f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" # f"real_img_{self.logger.name}_{self.current_epoch}.png", # normalize=True, - # nrow=int(math.sqrt(self.params['batch_size']))) + # nrow=12) - samples = self.model.sample(self.params['batch_size'], - self.curr_device, - labels = test_label).cpu() - vutils.save_image(samples.data, - f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" - f"{self.logger.name}_{self.current_epoch}.png", - normalize=True, - nrow=int(math.sqrt(self.params['batch_size']))) + try: + samples = self.model.sample(144, + self.curr_device, + labels = test_label).cpu() + vutils.save_image(samples.data, + f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" + f"{self.logger.name}_{self.current_epoch}.png", + normalize=True, + nrow=12) + except: + raise RuntimeWarning('No sampler for the VAE model is proviced. 
Continuing...') - del test_input, recons, samples + del test_input, recons #, samples def configure_optimizers(self): @@ -156,7 +159,7 @@ def val_dataloader(self): split = "test", transform=transform, download=False), - batch_size= self.params['batch_size'], + batch_size= 144, shuffle = True, drop_last=True) else: diff --git a/models/vq_vae.py b/models/vq_vae.py index 7ed0fbc8..f044b9c0 100644 --- a/models/vq_vae.py +++ b/models/vq_vae.py @@ -5,39 +5,122 @@ from .types_ import * class VectorQuantizer(nn.Module): + """ + Reference: + [1] https://github.com/zalandoresearch/pytorch-vq-vae + """ def __init__(self, num_embeddings: int, - embedding_dim: int): + embedding_dim: int, + beta: float = 0.25): super(VectorQuantizer, self).__init__() self.K = num_embeddings self.D = embedding_dim + self.beta = beta self.embedding = nn.Embedding(self.K, self.D) - self.embedding.weight.data.uniform_(1./self.K, 1./self.K) + self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K) - def forward(self, input: Tensor): - input = input.permute(0, 2, 3, 1) # [B x D x H x W] -> [B x H x W x D] - input_shape = input.shape - flat_input = input.contiguous().view(-1, self.D) # [BHW x D] + def forward(self, latents: Tensor) -> Tensor: + latents = latents.permute(0, 2, 3, 1).contiguous() # [B x D x H x W] -> [B x H x W x D] + latents_shape = latents.shape + flat_latents = latents.view(-1, self.D) # [BHW x D] - # Compute L2 distance between input and embedding weights - dist = torch.sum(flat_input**2, dim = 1, keepdim=True) + \ - torch.sum(self.embedding.weight ** 2, dim = 1) - \ - 2 * torch.matmul(flat_input, self.embedding.weight.t()) # [BHW x K] + # Compute L2 distance between latents and embedding weights + dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - \ + 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHW x K] # Get the encoding that has the min distance - encoding_inds = torch.argmin(dist, dim = 1).view(-1, 1) # [BHW, 1] + encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1) # [BHW, 1] # Convert to one-hot encodings - device = input.device + device = latents.device encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device) - encoding_one_hot.scatter_(1, encoding_inds, 1.) 
# [BHW x K] + encoding_one_hot.scatter_(1, encoding_inds, 1) # [BHW x K] + + # Quantize the latents + quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D] + quantized_latents = quantized_latents.view(latents_shape) # [B x H x W x D] - # Quantize the input - quantized_input = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D] - quantized_input = quantized_input.view(input_shape) # [B x H x W x D] + # Compute the VQ Losses + commitment_loss = F.mse_loss(quantized_latents.detach(), latents) + embedding_loss = F.mse_loss(quantized_latents, latents.detach()) + + vq_loss = commitment_loss * self.beta + embedding_loss + + # Add the residue back to the latents + quantized_latents = latents + (quantized_latents - latents).detach() + + return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss # [B x D x H x W] + + +# class VectorQuantizer(nn.Module): +# """ +# Reference: +# [1] https://github.com/zalandoresearch/pytorch-vq-vae +# """ +# def __init__(self, +# num_embeddings: int, +# embedding_dim: int, +# beta: float=0.25): +# super(VectorQuantizer, self).__init__() +# +# self.D = embedding_dim +# self.K = num_embeddings +# +# self._embedding = nn.Embedding(self.K, self.D) +# self._embedding.weight.data.uniform_(-1 / self.K, 1 / self.K) +# self.beta = beta +# +# def forward(self, inputs): +# # convert inputs from BCHW -> BHWC +# inputs = inputs.permute(0, 2, 3, 1).contiguous() +# input_shape = inputs.shape +# +# # Flatten input +# flat_input = inputs.view(-1, self.D) +# +# # Calculate distances +# distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) +# + torch.sum(self._embedding.weight ** 2, dim=1) +# - 2 * torch.matmul(flat_input, self._embedding.weight.t())) +# +# # Encoding +# encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) +# encodings = torch.zeros(encoding_indices.shape[0], self.K, device=inputs.device) +# encodings.scatter_(1, encoding_indices, 1) +# +# # Quantize and unflatten +# quantized_input = torch.matmul(encodings, self._embedding.weight).view(input_shape) +# +# # Loss +# commitment_loss = F.mse_loss(quantized_input.detach(), inputs) +# embedding_loss = F.mse_loss(quantized_input, inputs.detach()) +# loss = embedding_loss + self.beta * commitment_loss +# +# # quantized = inputs + (quantized - inputs).detach() +# # avg_probs = torch.mean(encodings, dim=0) +# # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) +# +# # convert quantized from BHWC -> BCHW +# return quantized_input.permute(0, 3, 1, 2).contiguous(), loss + + +class ResidualLayer(nn.Module): - return quantized_input.permute(0, 3, 1, 2) # [B x D x H x W] + def __init__(self, + in_channels: int, + out_channels: int): + super(ResidualLayer, self).__init__() + self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels, + kernel_size=3, padding=1, bias=False), + nn.ReLU(True), + nn.Conv2d(out_channels, out_channels, + kernel_size=1, bias=False)) + + def forward(self, input: Tensor) -> Tensor: + return input + self.resblock(input) class VQVAE(BaseVAE): @@ -58,8 +141,7 @@ def __init__(self, self.beta = beta modules = [] - if hidden_dims is None: - hidden_dims = [32, 64, 128, 256, 512] + hidden_dims = [32, 64, 128, 256, 512] # Build Encoder for h_dim in hidden_dims: @@ -72,6 +154,8 @@ def __init__(self, ) in_channels = h_dim + modules.append(ResidualLayer(in_channels, in_channels)) + modules.append( nn.Sequential( nn.Conv2d(in_channels, embedding_dim, @@ -79,12 +163,16 @@ def __init__(self, nn.BatchNorm2d(embedding_dim), 
nn.LeakyReLU()) ) + # modules.append(ResidualLayer(embedding_dim, embedding_dim)) + self.encoder = nn.Sequential(*modules) self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim) # Build Decoder modules = [] + # modules.append(ResidualLayer(embedding_dim, embedding_dim)) + modules.append( nn.Sequential( nn.ConvTranspose2d(embedding_dim, @@ -97,6 +185,8 @@ def __init__(self, nn.LeakyReLU()) ) + modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1])) + hidden_dims.reverse() for i in range(len(hidden_dims) - 1): @@ -111,9 +201,8 @@ def __init__(self, nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) - self.decoder = nn.Sequential(*modules) - self.final_layer = nn.Sequential( + modules.append(nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, @@ -123,8 +212,10 @@ def __init__(self, nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, - kernel_size= 3, padding= 1), - nn.Tanh()) + kernel_size= 3, padding=1), + nn.Tanh())) + + self.decoder = nn.Sequential(*modules) def encode(self, input: Tensor) -> List[Tensor]: """ @@ -145,84 +236,63 @@ def decode(self, z: Tensor) -> Tensor: """ result = self.decoder(z) - result = self.final_layer(result) return result - def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: - """ - Reparameterization trick to sample from N(mu, var) from - N(0,1). - :param mu: (Tensor) Mean of the latent Gaussian [B x D] - :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] - :return: (Tensor) [B x D] - """ - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return eps * std + mu - def forward(self, input: Tensor, **kwargs) -> List[Tensor]: encoding = self.encode(input)[0] - quantized_inputs = self.vq_layer(encoding) - return [self.decode(quantized_inputs), input, quantized_inputs, encoding] + quantized_inputs, vq_loss = self.vq_layer(encoding) + return [self.decode(quantized_inputs), input, vq_loss] def loss_function(self, *args, **kwargs) -> dict: """ - Computes the VAE loss function. - KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} :param args: :param kwargs: :return: """ recons = args[0] input = args[1] - e = args[2] - Z_e = args[3] + vq_loss = args[2] - recons_loss =F.mse_loss(recons, input) + recons_loss = F.mse_loss(recons, input) - # Compute the VQ Losses - commitment_loss = F.mse_loss(e.detach(), Z_e) - embedding_loss = F.mse_loss(e, Z_e.detach()) - - loss = recons_loss + embedding_loss + self.beta * commitment_loss + loss = recons_loss + vq_loss return {'loss': loss, - 'Reconstruction_Loss':recons_loss, - 'Embedding_Loss':embedding_loss, - 'Commitment_Loss':commitment_loss} - - def sample(self, - num_samples:int, - current_device: Union[int, str], **kwargs) -> Tensor: - """ - Samples from the latent space and return the corresponding - image space map. - :param num_samples: (Int) Number of samples - :param current_device: (Int)/(Str) Device to run the model - :return: (Tensor) - """ - # Get random encoding indices - sample_inds = torch.randint(self.embedding_dim, - (num_samples * self.img_size ** 2, 1), - device = current_device) # [SHW, 1] - - # Convert to corresponding one-hot encodings - sample_one_hot = torch.zeros(sample_inds.size(0), self.num_embeddings).to(current_device) - sample_one_hot.scatter_(1, sample_inds, 1.) 
# [BHW x K] - - # Quantize the input based on the learned embeddings - quantized_input = torch.matmul(sample_one_hot, - self.vq_layer.embedding.weight) # [BHW, D] - quantized_input = quantized_input.view(num_samples, - self.img_size, - self.img_size, - self.embedding_dim) # [B x H x W x D] - - quantized_input = quantized_input.permute(0, 3, 1, 2) # [B x D x H x W] - - samples = self.decode(quantized_input) - return samples + 'Reconstruction_Loss': recons_loss, + 'VQ_Loss':vq_loss} + + # def sample(self, + # num_samples: int, + # current_device: Union[int, str], **kwargs) -> Tensor: + # """ + # Samples from the latent space and return the corresponding + # image space map. + # :param num_samples: (Int) Number of samples + # :param current_device: (Int)/(Str) Device to run the model + # :return: (Tensor) + # """ + # # Get random encoding indices + # sample_inds = torch.randint(self.embedding_dim, + # (num_samples * self.img_size ** 2, 1)) # [SHW, 1] + # + # # Convert to corresponding one-hot encodings + # sample_one_hot = torch.zeros(sample_inds.size(0), self.num_embeddings) + # sample_one_hot.scatter_(1, sample_inds, 1.) # [BHW x K] + # + # # Quantize the input based on the learned embeddings + # quantized_input = torch.matmul(sample_one_hot, + # self.vq_layer.embedding.weight.detach().cpu()) # [BHW, D] + # quantized_input = quantized_input.view(num_samples, + # self.img_size, + # self.img_size, + # self.embedding_dim) # [B x H x W x D] + # + # quantized_input = input + (quantized_input - input).detach() + # quantized_input = quantized_input.permute(0, 3, 1, 2) # [B x D x H x W] + # + # samples = self.decode(quantized_input.to(current_device)) + # return samples def generate(self, x: Tensor, **kwargs) -> Tensor: """ diff --git a/run.py b/run.py index 848da2ce..d95dd657 100644 --- a/run.py +++ b/run.py @@ -47,7 +47,7 @@ log_save_interval=100, train_percent_check=1., val_percent_check=1., - num_sanity_val_steps=0, + num_sanity_val_steps=5, early_stop_callback = False, **config['trainer_params']) diff --git a/tests/test_vq_vae.py b/tests/test_vq_vae.py index 9cffbca2..58dded1d 100644 --- a/tests/test_vq_vae.py +++ b/tests/test_vq_vae.py @@ -8,13 +8,14 @@ class TestMIWAE(unittest.TestCase): def setUp(self) -> None: # self.model2 = VAE(3, 10) - self.model = VQVAE(3, 10, 10) + self.model = VQVAE(3, 64, 512) def test_summary(self): print(summary(self.model, (3, 64, 64), device='cpu')) # print(summary(self.model2, (3, 64, 64), device='cpu')) def test_forward(self): + print(sum(p.numel() for p in self.model.parameters() if p.requires_grad)) x = torch.randn(16, 3, 64, 64) y = self.model(x) print("Model Output size:", y[0].size()) @@ -29,7 +30,7 @@ def test_loss(self): def test_sample(self): self.model.cuda() - y = self.model.sample(144, 'cuda') + y = self.model.sample(8, 'cuda') print(y.shape) def test_generate(self):
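
A minimal, self-contained sketch of the vector-quantization step this patch introduces (nearest-codebook lookup, codebook/embedding and commitment losses, straight-through gradient). It mirrors the math in `models/vq_vae.py` but is illustrative only, not the repository module: the class name `VectorQuantizerSketch`, the 16x16 latent grid, and the smoke test at the bottom are assumptions for demonstration, while `K = 512`, `D = 64`, and `beta = 0.25` follow `configs/vq_vae.yaml`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizerSketch(nn.Module):
    """Illustrative stand-in for the VectorQuantizer layer added in this patch."""

    def __init__(self, num_embeddings: int = 512, embedding_dim: int = 64, beta: float = 0.25):
        super().__init__()
        self.K, self.D, self.beta = num_embeddings, embedding_dim, beta
        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents: torch.Tensor):
        # [B x D x H x W] -> [B x H x W x D] -> [BHW x D]
        latents = latents.permute(0, 2, 3, 1).contiguous()
        flat = latents.view(-1, self.D)

        # Squared L2 distance to every codebook vector, then nearest-neighbour lookup
        dist = (flat ** 2).sum(dim=1, keepdim=True) \
             + (self.embedding.weight ** 2).sum(dim=1) \
             - 2 * flat @ self.embedding.weight.t()                 # [BHW x K]
        inds = dist.argmin(dim=1)                                   # [BHW]
        quantized = self.embedding(inds).view(latents.shape)        # [B x H x W x D]

        # Embedding loss moves the codebook; commitment loss (scaled by beta) moves the encoder
        commitment_loss = F.mse_loss(quantized.detach(), latents)
        embedding_loss = F.mse_loss(quantized, latents.detach())
        vq_loss = embedding_loss + self.beta * commitment_loss

        # Straight-through estimator: copy decoder gradients past the non-differentiable argmin
        quantized = latents + (quantized - latents).detach()
        return quantized.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]


if __name__ == "__main__":
    vq = VectorQuantizerSketch()
    z_e = torch.randn(2, 64, 16, 16, requires_grad=True)  # stand-in for an encoder output
    z_q, vq_loss = vq(z_e)
    (F.mse_loss(z_q, torch.zeros_like(z_q)) + vq_loss).backward()
    # Gradient reaches the encoder side despite the discrete lookup
    print(z_q.shape, float(vq_loss), z_e.grad.abs().sum().item() > 0)
```

With this layer in place, the updated model's `forward()` returns `[reconstruction, input, vq_loss]`, and `loss_function()` simply adds `vq_loss` to the reconstruction MSE, as exercised by `tests/test_vq_vae.py` above.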