Skip to content

Commit

Permalink
Added CIFAR10 dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Jan 27, 2020
1 parent f71e454 commit 7f9df52
Show file tree
Hide file tree
Showing 19 changed files with 138 additions and 49 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ logging_params:
| IWAE (5 Samples) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
| DFCVAE |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
| MSSIM VAE |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE (CIFAR10)|[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
<!--| Disentangled Beta-VAE |[Link](https://arxiv.org/abs/1804.03599) | ![][10] | ![][9] |-->
Expand All @@ -102,19 +103,21 @@ logging_params:
- [x] IWAE
- [x] WAE-MMD
- [x] Conditional VAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [ ] Gamma VAE (in progress)
- [ ] Beta TC-VAE (in progress)
- [ ] Vamp VAE (in progress)
- [ ] HVAE (VAE with Vamp Prior) (in progress)
- [ ] FactorVAE (in progress)
- [ ] Categorical VAE (Gumbel-Softmax VAE)
- [ ] InfoVAE
- [ ] TwoStageVAE
- [ ] VAE-GAN
- [ ] VLAE
- [ ] PixelVAE
- [ ] VQVAE
- [ ] StyleVAE
- [ ] Sequential VAE
- [ ] Joint VAE
### Contributing
If you have trained a better model, using these implementations, by fine-tuning the hyper-params in the config file,
Expand Down Expand Up @@ -149,6 +152,8 @@ I would be happy to include your result (along with your config file) in this re
[14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png
[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png
[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_20.png
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_20.png
[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
Expand Down
3 changes: 2 additions & 1 deletion configs/bhvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ model_params:
Capacity_max_iter: 10000

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -18,7 +19,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/bvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ model_params:
loss_type: 'B'

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -16,7 +17,7 @@ exp_params:
trainer_params:
gpus: [2]
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
12 changes: 6 additions & 6 deletions configs/cat_vae.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
model_params:
name: 'CategoricalVAE'
in_channels: 3
num_classes: 40
latent_dim: 128
categorical_dim: 40 # Equal to Num classes
latent_dim: 256
categorical_dim: 10 # Equal to Num classes
temperature: 0.5
anneal_rate: 3e-5
annela_interval: 100
anneal_rate: 0.00003
anneal_interval: 100

exp_params:
dataset: cifar10
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -19,7 +19,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 250

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/cvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ model_params:
latent_dim: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -15,7 +16,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/dfc_vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ model_params:
latent_dim: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -14,7 +15,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/factorvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ model_params:
gamma: 40

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
submodel: 'discriminator'
require_secondary_input: True
Expand All @@ -19,7 +20,7 @@ exp_params:
trainer_params:
gpus: [2]
max_nb_epochs: 30

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/gammavae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ model_params:
prior_rate: 1.

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -18,7 +19,7 @@ trainer_params:
gpus: 1
max_nb_epochs: 50
gradient_clip_val: 0.8

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/hvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ model_params:
pseudo_input_size: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -16,7 +17,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/iwae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ model_params:
latent_dim: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -14,7 +15,7 @@ exp_params:
trainer_params:
gpus: [3]
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/mssim_vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ model_params:
latent_dim: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -14,7 +15,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 30

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ model_params:
latent_dim: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -14,7 +15,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/vampvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ model_params:
latent_dim: 128

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -14,7 +15,7 @@ exp_params:
trainer_params:
gpus: 1
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/wae_mmd_imq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ model_params:
kernel_type: 'imq'

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -16,7 +17,7 @@ exp_params:
trainer_params:
gpus: [2]
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
3 changes: 2 additions & 1 deletion configs/wae_mmd_rbf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ model_params:
kernel_type: 'rbf'

exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
Expand All @@ -16,7 +17,7 @@ exp_params:
trainer_params:
gpus: [1]
max_nb_epochs: 50

max_epochs: 50

logging_params:
save_dir: "logs/"
Expand Down
Loading

0 comments on commit 7f9df52

Please sign in to comment.