Updated results
AntixK committed Jan 28, 2020
1 parent f658659 commit db41da7
Showing 15 changed files with 38 additions and 16 deletions.
4 changes: 2 additions & 2 deletions LICENSE.md
@@ -2,8 +2,8 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright Anand Krishnamoorthy Subramanian
[email protected]
Copyright Anand Krishnamoorthy Subramanian 2020
[email protected]

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

27 changes: 23 additions & 4 deletions README.md
@@ -85,15 +85,16 @@ logging_params:
| WAE - MMD (RBF Kernel)|[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] |
| WAE - MMD (IMQ Kernel)|[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
| Disentangled Beta-VAE |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] |
| IWAE (5 Samples)      |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
| DFCVAE |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
| MSSIM VAE |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Joint VAE |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
| Info VAE |[Link](https://arxiv.org/abs/1706.02262) | ![][22] | ![][21] |
| Info VAE |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
<!--| Disentangled Beta-VAE |[Link](https://arxiv.org/abs/1804.03599) | ![][10] | ![][9] |-->
@@ -107,6 +108,7 @@ logging_params:
- [x] Conditional VAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [x] Disentangled beta-VAE
- [ ] InfoVAE (in progress)
- [ ] Gamma VAE (in progress)
- [ ] Beta TC-VAE (in progress)
@@ -137,6 +139,19 @@ I would be happy to include your result (along with your config file) in this re
| ✔️ Patent use | | |
| ✔️ Private use | | |
### Citation
If you wish to cite this repository in your work, use the following BibTeX entry:
```
@misc{Subramanian2020,
  author = {Subramanian, A. K.},
title = {PyTorch-VAE},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}}
}
```
-----------

[1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
@@ -145,8 +160,8 @@
[4]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_RBF_19.png
[5]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_IMQ_15.png
[6]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_IMQ_15.png
[7]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_B_20.png
[8]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_20.png
[7]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_H_20.png
[8]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_H_20.png
[9]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/IWAE_19.png
[10]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_IWAE_19.png
[11]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DFCVAE_49.png
@@ -159,6 +174,10 @@
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_49.png
[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_49.png
[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_49.png
[21]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_B_11.png
[22]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_11.png
[23]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/InfoVAE_7.png
[24]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_InfoVAE_7.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
Binary file added assets/BetaVAE_B_11.png
File renamed without changes
Binary file added assets/InfoVAE_7.png
Binary file added assets/recons_BetaVAE_B_11.png
File renamed without changes
Binary file added assets/recons_InfoVAE_7.png
4 changes: 2 additions & 2 deletions configs/bhvae.yaml → configs/bbvae.yaml
@@ -2,7 +2,7 @@ model_params:
name: 'BetaVAE'
in_channels: 3
latent_dim: 128
loss_type: 'H'
loss_type: 'B'
gamma: 10.0
max_capacity: 25
Capacity_max_iter: 10000
@@ -23,5 +23,5 @@ trainer_params:

logging_params:
save_dir: "logs/"
name: "BetaVAE_H"
name: "BetaVAE_B"
manual_seed: 1265
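For reference, the `B` loss variant selected by this config anneals a KL capacity linearly from 0 up to `max_capacity` over `Capacity_max_iter` steps. A minimal sketch under those config values; the function name `capacity_at` is illustrative and not part of the repo:

```python
# Sketch of the KL-capacity annealing schedule used by the 'B' loss
# variant (Burgess et al., arXiv:1804.03599). Defaults mirror
# configs/bbvae.yaml; `capacity_at` is an illustrative name, not repo code.

def capacity_at(num_iter, max_capacity=25.0, capacity_max_iter=10000):
    """Linearly anneal the capacity C from 0 up to max_capacity."""
    c = max_capacity / capacity_max_iter * num_iter
    return max(0.0, min(c, max_capacity))
```

Halfway through the schedule (`num_iter=5000`) this gives C = 12.5; past `Capacity_max_iter` it saturates at `max_capacity`.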
4 changes: 2 additions & 2 deletions configs/bvae.yaml
@@ -3,7 +3,7 @@ model_params:
in_channels: 3
latent_dim: 128
beta: 10
loss_type: 'B'
loss_type: 'H'

exp_params:
dataset: celeba
@@ -21,5 +21,5 @@ trainer_params:

logging_params:
save_dir: "logs/"
name: "BetaVAE_B"
name: "BetaVAE_H"
manual_seed: 1265
5 changes: 3 additions & 2 deletions configs/infovae.yaml
@@ -2,9 +2,10 @@ model_params:
name: 'InfoVAE'
in_channels: 3
latent_dim: 128
reg_weight: 110 # Lambda factor
reg_weight: 110 # MMD weight
kernel_type: 'imq'
alpha: -9.0
alpha: -9.0 # KLD weight
beta: 10.5 # Reconstruction weight

exp_params:
dataset: celeba
2 changes: 1 addition & 1 deletion configs/wae_mmd_rbf.yaml
@@ -2,7 +2,7 @@ model_params:
name: 'WAE_MMD'
in_channels: 3
latent_dim: 128
reg_weight: 1000
reg_weight: 5000
kernel_type: 'rbf'

exp_params:
Binary file modified models/__pycache__/beta_vae.cpython-36.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions models/beta_vae.py
@@ -140,9 +140,9 @@ def loss_function(self,

kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

if self.loss_type == 'B':
if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
loss = recons_loss + self.beta * kld_weight * kld_loss
elif self.loss_type == 'H':
elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
self.C_max = self.C_max.to(input.device)
C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs()
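The swap in this hunk only relabels which branch carries each paper's objective; the math is unchanged. A pure-Python scalar sketch of the two variants (stand-ins for the tensor losses; the function name is illustrative):

```python
# Scalar sketch of the two BetaVAE objectives distinguished above:
# 'H' (Higgins et al.) scales the KLD by beta, while 'B' (Burgess et al.)
# penalizes deviation of the KLD from an annealed capacity C.
# Pure-Python stand-in for the torch code; names are illustrative.

def beta_vae_loss(recons_loss, kld_loss, loss_type='H',
                  beta=10.0, gamma=10.0, kld_weight=1.0, C=0.0):
    if loss_type == 'H':    # https://openreview.net/forum?id=Sy2fzU9gl
        return recons_loss + beta * kld_weight * kld_loss
    elif loss_type == 'B':  # https://arxiv.org/abs/1804.03599
        return recons_loss + gamma * kld_weight * abs(kld_loss - C)
    raise ValueError('Undefined loss type.')
```

With the defaults above, `beta_vae_loss(1.0, 0.5)` evaluates the 'H' objective as 1.0 + 10 * 0.5 = 6.0.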
4 changes: 3 additions & 1 deletion models/info_vae.py
@@ -12,6 +12,7 @@ def __init__(self,
latent_dim: int,
hidden_dims: List = None,
alpha: float = -0.5,
beta: float = 5.0,
reg_weight: int = 100,
kernel_type: str = 'imq',
latent_var: float = 2.,
@@ -26,6 +27,7 @@ def __init__(self,
assert alpha <= 0, 'alpha must be negative or zero.'

self.alpha = alpha
self.beta = beta

modules = []
if hidden_dims is None:
@@ -140,7 +142,7 @@ def loss_function(self,
mmd_loss = self.compute_mmd(z)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

loss = recons_loss + \
loss = self.beta * recons_loss + \
(1. - self.alpha) * kld_weight * kld_loss + \
(self.alpha + self.reg_weight - 1.)/bias_corr * mmd_loss
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss, 'KLD':-kld_loss}
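The commit's net effect on this objective is to add a `beta` weight on the reconstruction term alongside the existing KLD and MMD weights. A scalar sketch of the weighted sum, using the values from `configs/infovae.yaml` as defaults (the function name is illustrative, not repo code):

```python
# Scalar sketch of the InfoVAE objective assembled above: beta weights
# the reconstruction term, (1 - alpha) the KLD term, and
# (alpha + reg_weight - 1) / bias_corr the MMD term.
# Pure-Python stand-in for the torch code; names are illustrative.

def info_vae_loss(recons_loss, kld_loss, mmd_loss,
                  alpha=-9.0, beta=10.5, reg_weight=110.0,
                  kld_weight=1.0, bias_corr=1.0):
    assert alpha <= 0, 'alpha must be negative or zero.'
    return (beta * recons_loss
            + (1.0 - alpha) * kld_weight * kld_loss
            + (alpha + reg_weight - 1.0) / bias_corr * mmd_loss)
```

With the config defaults, the KLD picks up a coefficient of 1 - (-9) = 10 and the MMD a coefficient of (-9 + 110 - 1) = 100 before bias correction.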
