Skip to content

Commit

Permalink
Added Joint VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Jan 27, 2020
1 parent 7f9df52 commit d282f67
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 14 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ logging_params:
| IWAE (5 Samples) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
| DFCVAE |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
| MSSIM VAE |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE (CIFAR10)|[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Categorical VAE |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Joint VAE |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
<!--| Disentangled Beta-VAE |[Link](https://arxiv.org/abs/1804.03599) | ![][10] | ![][9] |-->
Expand All @@ -104,6 +105,7 @@ logging_params:
- [x] WAE-MMD
- [x] Conditional VAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [ ] Gamma VAE (in progress)
- [ ] Beta TC-VAE (in progress)
- [ ] Vamp VAE (in progress)
Expand All @@ -117,7 +119,7 @@ logging_params:
- [ ] VQVAE
- [ ] StyleVAE
- [ ] Sequential VAE
- [ ] Joint VAE
### Contributing
If you have trained a better model using these implementations by fine-tuning the hyper-parameters in the config file,
Expand Down Expand Up @@ -154,7 +156,8 @@ I would be happy to include your result (along with your config file) in this re
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png
[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_20.png
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_20.png
[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_20.png
[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_20.png
[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
Expand Down
11 changes: 6 additions & 5 deletions configs/cat_vae.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
model_params:
name: 'CategoricalVAE'
in_channels: 3
latent_dim: 256
categorical_dim: 10 # Equal to Num classes
latent_dim: 512
categorical_dim: 40
temperature: 0.5
anneal_rate: 0.00003
anneal_interval: 100
alpha: 8.0

exp_params:
dataset: cifar10
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -17,9 +18,9 @@ exp_params:
scheduler_gamma: 0.95

trainer_params:
gpus: 1
gpus: [1]
max_nb_epochs: 50
max_epochs: 250
max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
36 changes: 36 additions & 0 deletions configs/joint_vae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
model_params:
name: 'JointVAE'
in_channels: 3
latent_dim: 512
categorical_dim: 40
latent_min_capacity: 0.0
latent_max_capacity: 20.0
latent_gamma: 30.
latent_num_iter: 25000
categorical_min_capacity: 0.0
categorical_max_capacity: 20.0
categorical_gamma: 30.
categorical_num_iter: 25000
temperature: 0.5
anneal_rate: 0.00003
anneal_interval: 100
alpha: 20.0

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: [1]
max_nb_epochs: 50
max_epochs: 50

logging_params:
save_dir: "logs/"
name: "JointVAE"
manual_seed: 1265
2 changes: 1 addition & 1 deletion experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def data_transforms(self):
transforms.ToTensor(),
SetRange])

if self.params['dataset'] == 'cifar10':
elif self.params['dataset'] == 'cifar10':
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Lambda(lambda img:
Expand Down
6 changes: 4 additions & 2 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from .mssim_vae import MSSIMVAE
from .fvae import *
from .cat_vae import *
from .joint_vae import *

# Aliases
VAE = VanillaVAE
GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
GUMBELVAE = CategoricalVAE
GumbelVAE = CategoricalVAE

vae_models = {'VanillaVAE':VanillaVAE,
'WAE_MMD':WAE_MMD,
Expand All @@ -29,4 +30,5 @@
'DFCVAE':DFCVAE,
'MSSIMVAE':MSSIMVAE,
'FactorVAE':FactorVAE,
'CategoricalVAE':CategoricalVAE}
'CategoricalVAE':CategoricalVAE,
'JointVAE':JointVAE}
Binary file modified models/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
4 changes: 3 additions & 1 deletion models/cat_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self,
temperature: float = 0.5,
anneal_rate: float = 3e-5,
anneal_interval: int = 100, # every 100 batches
alpha: float = 30.,
**kwargs) -> None:
super(CategoricalVAE, self).__init__()

Expand All @@ -25,6 +26,7 @@ def __init__(self,
self.min_temp = temperature
self.anneal_rate = anneal_rate
self.anneal_interval = anneal_interval
self.alpha = alpha

modules = []
if hidden_dims is None:
Expand Down Expand Up @@ -171,7 +173,7 @@ def loss_function(self,
kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0)

# kld_weight = 1.2
loss = 30. * recons_loss + kld_weight * kld_loss
loss = self.alpha * recons_loss + kld_weight * kld_loss
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

def sample(self,
Expand Down
Loading

0 comments on commit d282f67

Please sign in to comment.