Added stratified sampling to Beta-TC-VAE
AntixK committed Feb 25, 2020
1 parent 215bb79 commit 3c8fd61
Showing 3 changed files with 39 additions and 15 deletions.
12 changes: 11 additions & 1 deletion README.md
@@ -72,6 +72,11 @@ logging_params:
manual_seed:
```
**View TensorBoard Logs**
```
$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir tf
```

----
<h2 align="center">
@@ -87,6 +92,7 @@ logging_params:
| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] |
| Beta-TC-VAE ([Code][btcvae_code], [Config][btcvae_config]) |[Link](https://arxiv.org/abs/1802.04942) | ![][34] | ![][33] |
| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
| MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config]) |[Link](https://arxiv.org/abs/1802.04537) | ![][30] | ![][29] |
| DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
@@ -160,6 +166,7 @@ doesn't seem to work well.
[vae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
[cvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cvae.py
[bvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
[btcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/betatc_vae.py
[wae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/wae_mmd.py
[iwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/iwae.py
[miwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/miwae.py
@@ -176,6 +183,7 @@ doesn't seem to work well.
[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
[bbvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bbvae.yaml
[bhvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bhvae.yaml
[btcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/betatc_vae.yaml
[wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml
[wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml
[iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml
@@ -219,7 +227,9 @@ doesn't seem to work well.
[28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png
[29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png
[30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png
[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_1.png
[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png
[33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_20.png
[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_20.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
10 changes: 5 additions & 5 deletions configs/betatc_vae.yaml
@@ -2,24 +2,24 @@ model_params:
name: 'BetaTCVAE'
in_channels: 3
latent_dim: 128
anneal_steps: 100
anneal_steps: 10000
alpha: 1.
beta: 0.5
beta: 6.
gamma: 1.

exp_params:
dataset: celeba
data_path: "../../shared/momo/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
LR: 0.001
weight_decay: 0.0
scheduler_gamma: 0.97
scheduler_gamma: 0.99

trainer_params:
gpus: 1
max_nb_epochs: 50
max_epochs: 30
max_epochs: 50

logging_params:
save_dir: "logs/"
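For reference, the `alpha`, `beta`, and `gamma` weights above correspond to the three terms of the β-TC-VAE decomposition of the KL term (Chen et al., arXiv:1802.04942): index-code mutual information, total correlation, and dimension-wise KL. A sketch of the objective in the paper's notation (not code from this repository):

```
\mathcal{L}
  = \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right]
  - \alpha \, I_q(z;x)                                                             % index-code mutual information
  - \beta  \, D_{\mathrm{KL}}\!\left(q(z)\,\|\,{\textstyle\prod_j} q(z_j)\right)   % total correlation
  - \gamma \sum_j D_{\mathrm{KL}}\!\left(q(z_j)\,\|\,p(z_j)\right)                 % dimension-wise KL
```

In `loss_function` below, `anneal_steps` linearly ramps the weight on the γ (dimension-wise KL) term from 0 to 1 over the first `anneal_steps` training iterations.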
32 changes: 23 additions & 9 deletions models/betatc_vae.py
@@ -36,8 +36,7 @@ def __init__(self,
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
kernel_size= 4, stride= 2, padding = 1),
nn.LeakyReLU())
)
in_channels = h_dim
@@ -63,7 +62,6 @@ def __init__(self,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)

@@ -78,7 +76,6 @@ def __init__(self,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
@@ -153,15 +150,15 @@ def loss_function(self,
:param kwargs:
:return:
"""
if self.training:
self.num_iter += 1

recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
z = args[4]

weight = 1 #kwargs['M_N'] # Account for the minibatch samples from the dataset

recons_loss =F.mse_loss(recons, input)

log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)
@@ -174,6 +171,18 @@
mu.view(1, batch_size, latent_dim),
log_var.view(1, batch_size, latent_dim))

# Reference
# [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54
dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size
strat_weight = (dataset_size - batch_size - 1) / (dataset_size * (batch_size - 1))
importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)
importance_weights.view(-1)[::batch_size] = 1 / dataset_size
importance_weights.view(-1)[1::batch_size] = strat_weight
importance_weights[batch_size - 2, 0] = strat_weight
log_importance_weights = importance_weights.log()

mat_log_q_z += log_importance_weights.unsqueeze(2)

log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

@@ -183,11 +192,16 @@

# kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

anneal_rate = min(0 + 1 * self.num_iter / self.anneal_steps, 1)
if self.training:
self.num_iter += 1
anneal_rate = min(0 + 1 * self.num_iter / self.anneal_steps, 1)
else:
anneal_rate = 1.

loss = recons_loss + \
self.alpha * mi_loss + \
self.beta * tc_loss + \
anneal_rate * self.gamma * kld_loss
weight * (self.beta * tc_loss +
anneal_rate * self.gamma * kld_loss)

return {'loss': loss,
'Reconstruction_Loss':recons_loss,
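A note on the new block in `loss_function` above: the β-TC-VAE paper's Minibatch Stratified Sampling (MSS) estimator approximates log q(z) and log ∏_j q(z_j) from a single minibatch by reweighting the pairwise densities q(z_i | x_j) with a B × B importance-weight matrix. The sketch below factors that matrix into a standalone helper, following the reference implementation cited in the diff ([1]), which uses N − B + 1 in the numerator of the stratified weight; the function name, toy sizes, and the random stand-in for the log-densities are illustrative only.

```python
import torch


def log_importance_weight_matrix(batch_size: int, dataset_size: int) -> torch.Tensor:
    """Log importance weights for the MSS estimator of Chen et al.
    (arXiv:1802.04942, App. C.1), following
    https://github.com/YannDubs/disentangling-vae (disvae/utils/math.py)."""
    N, M = dataset_size, batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.full((batch_size, batch_size), 1.0 / M)  # base weight 1/(B-1)
    W.view(-1)[::M + 1] = 1.0 / N        # first column: the 1/N term
    W.view(-1)[1::M + 1] = strat_weight  # second column: the stratified term
    W[M - 1, 0] = strat_weight           # special-cased entry, as in the reference
    return W.log()


# Toy usage with illustrative sizes, mirroring the reductions in the diff:
# mat_log_q_z[i, j, d] stands in for log q(z_i[d] | x_j).
B, D, N = 8, 16, 10_000
mat_log_q_z = torch.randn(B, B, D)
log_iw = log_importance_weight_matrix(B, N)
log_q_z = torch.logsumexp(log_iw + mat_log_q_z.sum(2), dim=1)
log_prod_q_z = torch.logsumexp(log_iw.unsqueeze(2) + mat_log_q_z, dim=1).sum(1)
print(log_q_z.shape, log_prod_q_z.shape)  # torch.Size([8]) torch.Size([8])
```

In the diff above, the same weights are built inline and added to `mat_log_q_z` before the two `logsumexp` reductions.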
