
A Collection of Variational Autoencoders (VAE) in PyTorch.

genhao3/PyTorch-VAE


PyTorch VAE

A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with a focus on reproducibility. The aim of this project is to provide a quick and simple working example for many of the cool VAE models out there. All the models are trained on the CelebA dataset for consistency and comparison. The architectures of all the models are kept as similar as possible, using the same layers, except where the original paper necessitates a radically different architecture. The results for each model are listed under Results.
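Every model in the collection builds on the same core VAE objective. As a rough illustration (not code from this repository), the loss for a single example combines a reconstruction term with the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior, and latent samples are drawn with the reparameterization trick. The sketch below uses plain Python lists; the function names and the `kld_weight` knob are hypothetical.

```python
import math
import random

def reparameterize(mu, log_var, rng=random):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1); z stays a smooth function of mu and log_var."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))

def vae_loss(x, x_recon, mu, log_var, kld_weight=1.0):
    """Mean squared reconstruction error plus a weighted KL term, for one example."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon)) / len(x)
    return recon + kld_weight * kl_to_standard_normal(mu, log_var)
```

Variants such as Beta-VAE mainly change how the KL term is weighted (a beta greater than 1), which is why a single weight parameter is exposed here.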

Requirements

  • Python >= 3.5
  • PyTorch >= 1.3
  • PyTorch Lightning >= 0.6.0
  • CUDA-enabled computing device
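Before installing, it can help to confirm that an existing environment already meets these minimums. The following is a hypothetical helper, not part of the repository; it assumes the PyPI distribution names `torch` and `pytorch-lightning`, and it uses `importlib.metadata`, so the checker itself needs Python 3.8+ (which already satisfies the Python >= 3.5 requirement).

```python
from importlib import metadata  # standard library since Python 3.8

# Minimum versions taken from the Requirements list (PyPI distribution names).
MINIMUMS = {"torch": "1.3", "pytorch-lightning": "0.6.0"}

def version_tuple(version):
    """'1.13.1+cu117' -> (1, 13, 1): keep the leading digits of each dot-separated part."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def meets_minimum(installed, minimum):
    """Compare versions after zero-padding the shorter tuple, so '1.3' equals '1.3.0'."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b

def check_environment():
    """Return a list of unmet package requirements (empty if everything is satisfied)."""
    problems = []
    for dist, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist} is not installed")
            continue
        if not meets_minimum(installed, minimum):
            problems.append(f"{dist} {installed} is older than {minimum}")
    return problems

def cuda_available():
    """True if PyTorch is importable and can see a CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()
```

Running `print(check_environment(), cuda_available())` gives a quick summary; the CUDA check is kept separate because importing PyTorch is comparatively slow.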

Installation

$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt

Usage

$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>
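The command passes a config path with `-c`. How `run.py` handles it is not described here, so the following is only a hypothetical sketch of such an entry point using the standard `argparse` module and PyYAML's `yaml.safe_load` (an assumption: that the YAML configs are read with PyYAML).

```python
import argparse

def build_parser():
    """Hypothetical CLI: a single -c/--config option pointing at a YAML file."""
    parser = argparse.ArgumentParser(description="Train a VAE model from a YAML config")
    parser.add_argument("-c", "--config", dest="filename", metavar="FILE",
                        required=True, help="path to the YAML config file")
    return parser

def load_config(path):
    """Read the YAML file into nested dicts (requires PyYAML)."""
    import yaml  # imported here so the parser itself has no third-party dependency
    with open(path) as handle:
        return yaml.safe_load(handle)

def main(argv=None):
    """Parse arguments (sys.argv when argv is None) and return the loaded config."""
    args = build_parser().parse_args(argv)
    # Each top-level section becomes a dict, e.g. config["model_params"]["name"].
    return load_config(args.filename)
```

Calling `main(["-c", "configs/<config-file-name.yaml>"])` with a real file path would return the `model_params`, `exp_params`, `trainer_params` and `logging_params` sections as dictionaries.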

Config file template

model_params:
  name: "<name of VAE model>"
  in_channels: 3
  latent_dim:
  # ... other parameters required by the model

exp_params:
  data_path: "<path to the CelebA dataset>"
  img_size: 64    # Models are designed to work for this size
  batch_size: 64  # Better to have a square number
  LR: 0.005
  weight_decay:
  # ... other arguments required for training, such as the scheduler

trainer_params:
  gpus: 1
  max_nb_epochs: 50
  gradient_clip_val: 0.005
  # ...

logging_params:
  save_dir: "logs/"
  name: "<experiment name>"
  manual_seed: 
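The template leaves several values blank and recommends a square batch size. One plausible reason (an assumption, not stated by the author) is that a batch of reconstructed or sampled images can then be tiled into a square grid, for example 64 images as 8 × 8. A small, hypothetical validation helper for a config that has already been parsed into a dict might look like this:

```python
import math

REQUIRED_SECTIONS = ("model_params", "exp_params", "trainer_params", "logging_params")

def check_config(config):
    """Return human-readable problems found in a parsed config dict (hypothetical checks)."""
    problems = [f"missing section: {name}" for name in REQUIRED_SECTIONS if name not in config]
    exp = config.get("exp_params") or {}  # an empty YAML section parses as None
    batch_size = exp.get("batch_size")
    if isinstance(batch_size, int):
        side = math.isqrt(batch_size)
        if side * side != batch_size:
            problems.append(f"batch_size={batch_size} is not a square number")
    if exp.get("img_size") not in (None, 64):
        problems.append("models are designed for img_size=64")
    return problems

def grid_shape(batch_size):
    """Rows and columns for tiling a square batch of images, e.g. 64 -> (8, 8)."""
    side = math.isqrt(batch_size)
    if side * side != batch_size:
        raise ValueError("batch_size must be a square number")
    return side, side
```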

Results

Model | Paper
VAE | Link
Conditional VAE | Link
WAE - MMD (RBF Kernel) | Link
WAE - MMD (IMQ Kernel) | Link
Beta-VAE | Link
Disentangled Beta-VAE | Link
IWAE (5 Samples) | Link
DFCVAE | Link
MSSIM VAE | Link
Categorical VAE | Link
Joint VAE | Link
Info VAE | Link

[Reconstruction and Samples image columns not reproduced]
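The two WAE-MMD entries differ only in the kernel used for the maximum mean discrepancy (MMD) penalty between encoded latents and prior samples. The sketch below is an illustration in plain Python, not this repository's code: it uses the standard unbiased MMD² estimator, and the bandwidth `sigma` and IMQ constant `c` are illustrative defaults rather than the values used in the configs.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def rbf_kernel(a, b, sigma=1.0):
    """Gaussian (RBF) kernel exp(-||a - b||^2 / (2 sigma^2))."""
    return math.exp(-sq_dist(a, b) / (2.0 * sigma ** 2))

def imq_kernel(a, b, c=1.0):
    """Inverse multiquadratic kernel c / (c + ||a - b||^2)."""
    return c / (c + sq_dist(a, b))

def mmd_squared(zs, ps, kernel):
    """Unbiased MMD^2 estimate between samples zs and ps (each needs at least 2 points)."""
    n, m = len(zs), len(ps)
    k_zz = sum(kernel(zs[i], zs[j]) for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    k_pp = sum(kernel(ps[i], ps[j]) for i in range(m) for j in range(m) if i != j) / (m * (m - 1))
    k_zp = sum(kernel(z, p) for z in zs for p in ps) / (n * m)
    return k_zz + k_pp - 2.0 * k_zp
```

Because the estimator is unbiased, it can dip slightly below zero when the two sample sets come from the same distribution; it grows as the encoded latents drift away from the prior.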

TODO

  • VanillaVAE
  • Beta VAE
  • DFC VAE
  • MSSIM VAE
  • IWAE
  • WAE-MMD
  • Conditional VAE
  • Categorical VAE (Gumbel-Softmax VAE)
  • Joint VAE
  • Disentangled beta-VAE
  • InfoVAE (in progress)
  • Gamma VAE (in progress)
  • Beta TC-VAE (in progress)
  • Vamp VAE (in progress)
  • HVAE (VAE with Vamp Prior) (in progress)
  • FactorVAE (in progress)
  • TwoStageVAE
  • VAE-GAN
  • VLAE
  • PixelVAE
  • VQVAE
  • StyleVAE
  • Sequential VAE
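The Categorical VAE in the list replaces the Gaussian latent with discrete codes relaxed through the Gumbel-Softmax trick. A plain-Python sketch of drawing one relaxed sample follows; it is not the repository's implementation, and the function name and default temperature are illustrative.

```python
import math
import random

def gumbel_softmax_sample(logits, temperature=0.5, rng=random):
    """Relaxed one-hot sample: softmax((logits + g) / temperature) with g ~ Gumbel(0, 1)."""
    noisy = []
    for logit in logits:
        u = rng.random()
        u = min(max(u, 1e-10), 1.0 - 1e-10)  # keep both logarithms finite
        gumbel = -math.log(-math.log(u))     # inverse-transform sample of Gumbel(0, 1)
        noisy.append((logit + gumbel) / temperature)
    top = max(noisy)                         # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]
```

Lower temperatures push the output toward a one-hot vector, and the index of its largest entry is distributed as a categorical draw from `softmax(logits)` (the Gumbel-max trick), while the soft vector remains differentiable with respect to the logits.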

Contributing

If you have trained a better model using these implementations by fine-tuning the hyperparameters in the config file, I would be happy to include your result (along with your config file) in this repo, citing your name 😊.

License

Apache License 2.0

Permissions:
  • Commercial use
  • Modification
  • Distribution
  • Patent use
  • Private use

Limitations:
  • Trademark use
  • Liability
  • Warranty

Conditions:
  • License and copyright notice
  • State changes

Citation

If you wish to cite this repository in your work, use the following BibTeX entry:

@misc{Subramanian2020,
  author = {Subramanian, A.K},
  title = {PyTorch-VAE},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}}
}
