This directory contains code for vector-quantized variational autoencoder (VQ-VAE) experiments on MNIST.
mnist_experiments.py compares the following models and training procedures (note that these are not yet necessarily tuned):
- Baseline VQ-VAE.
- Categorical variational distribution.
- Using loss from Roy et al. (2018).
- Adding entropy term to loss.
- Correctly scaling entropy and prior terms in loss.
- Not using straight-through on the prior.
- Adding IAFs.
- Gumbel-Softmax variational distribution.
- Gumbel-Softmax VQ-VAE with IAFs.
- Joint training.
- Two-stage training.
- Still training other parameters after introducing IAF.
- Not training other parameters after introducing IAF.
To train a model locally:
# Run with standard hyperparameters, from google-research/:
$ python -m probabilistic_vqvae.mnist_experiments
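The flags used in the commands below can be mixed and matched. As a rough sketch of the options (assuming the script declares them with absl.flags; the names mirror the commands in this file, while the defaults and help strings here are only illustrative):

```python
# Hypothetical flag declarations (assuming absl.flags); names follow the
# commands in this README, but defaults and help strings are illustrative.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_enum("bottleneck_type", "deterministic",
                  ["deterministic", "categorical", "gumbel_softmax"],
                  "Which discrete bottleneck to use.")
flags.DEFINE_bool("sum_over_latents", False,
                  "Sum (rather than average) entropy/prior terms over latents.")
flags.DEFINE_float("entropy_scale", 0.0, "Weight on the entropy term.")
flags.DEFINE_float("beta", 0.0, "Weight on the prior term.")
flags.DEFINE_integer("num_samples", 1, "Posterior samples per input.")
flags.DEFINE_bool("stop_gradient_for_prior", True,
                  "Stop gradients flowing from the prior term to the encoder.")
flags.DEFINE_integer("num_iaf_flows", 0, "Number of IAF flows to add.")
flags.DEFINE_integer("iaf_startup_steps", 0,
                     "Training steps before the IAFs are introduced.")
flags.DEFINE_bool("stop_training_encoder_after_startup", False,
                  "Freeze non-IAF parameters once the IAFs are introduced.")


def main(_):
  pass  # The real training loop lives in mnist_experiments.py.


if __name__ == "__main__":
  app.run(main)
```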
The following commands train specific models and hyperparameter settings:
1. Baseline VQ-VAE.
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=deterministic
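For reference, a minimal sketch of the deterministic bottleneck this setting selects: standard VQ-VAE nearest-neighbour quantization with a straight-through gradient (van den Oord et al., 2017). This is an illustration, not the code in this directory.

```python
# Minimal VQ-VAE quantization sketch: nearest codebook entry plus
# straight-through gradients and codebook/commitment losses.
import tensorflow as tf

def quantize(z_e, codebook, commitment_cost=0.25):
  """z_e: [batch, latent_dim] encoder outputs; codebook: [K, latent_dim]."""
  # Squared distances to every codebook entry, shape [batch, K].
  distances = (
      tf.reduce_sum(z_e ** 2, axis=1, keepdims=True)
      - 2.0 * tf.matmul(z_e, codebook, transpose_b=True)
      + tf.reduce_sum(codebook ** 2, axis=1))
  codes = tf.argmin(distances, axis=1)
  z_q = tf.gather(codebook, codes)
  # Straight-through estimator: forward pass uses z_q, gradients reach z_e.
  z_q_st = z_e + tf.stop_gradient(z_q - z_e)
  # Codebook loss pulls embeddings to the encoder; commitment loss does the reverse.
  vq_loss = tf.reduce_mean((tf.stop_gradient(z_e) - z_q) ** 2)
  commit_loss = commitment_cost * tf.reduce_mean((z_e - tf.stop_gradient(z_q)) ** 2)
  return z_q_st, codes, vq_loss + commit_loss
```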
2.1. Categorical variational distribution with loss from Roy et al. (2018).
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=categorical --sum_over_latents=False \
--entropy_scale=0.0 --num_samples=10 --stop_gradient_for_prior=True
2.2. Loss from Roy et al. (2018) plus entropy term.
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=categorical --sum_over_latents=False \
--entropy_scale=1.0 --num_samples=10 --stop_gradient_for_prior=True \
--beta=0.05
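Heuristically, --beta and --entropy_scale might weight the rate terms of an ELBO-style objective as sketched below; the exact weighting in mnist_experiments.py may differ.

```python
# Hedged sketch of how --beta and --entropy_scale could combine the loss
# terms; an illustration, not the objective in mnist_experiments.py.
import numpy as np

def loss(recon_nll, q_probs, prior_log_probs, entropy_scale=1.0, beta=0.05):
  """recon_nll: scalar reconstruction negative log-likelihood.
  q_probs: [num_latents, K] categorical posterior over codebook entries.
  prior_log_probs: [num_latents] log-prior of the assigned codes."""
  entropy = -np.sum(q_probs * np.log(q_probs + 1e-8), axis=-1)  # [num_latents]
  # "Rate": cross-entropy to the prior, minus a scaled posterior-entropy bonus.
  rate = -np.sum(prior_log_probs) - entropy_scale * np.sum(entropy)
  return recon_nll + beta * rate
```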
2.3. Same as above, but correctly scaling the entropy and prior terms in the loss (summing over latents).
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=categorical --sum_over_latents=True \
--entropy_scale=1.0 --num_samples=10 --stop_gradient_for_prior=True \
--beta=0.05
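The only change from 2.2 is --sum_over_latents=True. A small sketch of what that switch is assumed to toggle:

```python
# Assumed behaviour of --sum_over_latents: whether per-latent rate terms are
# summed (matching the ELBO) or averaged (downweighted by num_latents).
import numpy as np

def reduce_over_latents(per_latent_terms, sum_over_latents=True):
  per_latent_terms = np.asarray(per_latent_terms)  # shape [num_latents]
  if sum_over_latents:
    return per_latent_terms.sum()
  return per_latent_terms.mean()
```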
2.4. No longer stopping the gradient for the prior.
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=categorical --sum_over_latents=True \
--entropy_scale=1.0 --num_samples=10 --stop_gradient_for_prior=False \
--beta=0.05
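A sketch of the assumed effect of --stop_gradient_for_prior, i.e. whether the prior term can backpropagate into the encoder through the codes:

```python
# Assumed effect of --stop_gradient_for_prior: when True, the prior term is
# computed on a stopped-gradient copy of the codes, so it cannot shape the
# encoder/posterior; when False, its gradients flow back through the codes.
import tensorflow as tf

def prior_term(codes_one_hot, prior_logits, stop_gradient_for_prior=True):
  """codes_one_hot: [num_latents, K] (relaxed) one-hot posterior codes."""
  if stop_gradient_for_prior:
    codes_one_hot = tf.stop_gradient(codes_one_hot)
  log_prior = tf.nn.log_softmax(prior_logits, axis=-1)  # [num_latents, K]
  return -tf.reduce_sum(codes_one_hot * log_prior)       # negative log-prior
```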
2.5. Categorical bottleneck with IAFs.
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=categorical --sum_over_latents=False --entropy_scale=1.0 \
--num_samples=10 --beta=0.05 --num_iaf_flows=1 --stop_gradient_for_prior=True \
--stop_training_encoder_after_startup=True --iaf_startup_steps=5000
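The IAF flags first appear here: --num_iaf_flows stacks inverse autoregressive flows on the bottleneck and --iaf_startup_steps delays when they are introduced. A hedged sketch of one such flow using TensorFlow Probability bijectors (the flows in this directory may be implemented differently, e.g. over relaxed codes or their logits):

```python
# One inverse autoregressive flow (IAF) built from TFP bijectors; an
# illustrative stand-in for whatever --num_iaf_flows adds in this repo.
import tensorflow_probability as tfp

tfb = tfp.bijectors

def make_iaf(event_size, hidden_units=(64, 64)):
  made = tfb.AutoregressiveNetwork(
      params=2, event_shape=[event_size], hidden_units=list(hidden_units))
  # An IAF is the inverse of a masked autoregressive flow: sampling is one
  # parallel pass, while log-density evaluation is sequential.
  return tfb.Invert(tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=made))
```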
3. Gumbel-Softmax variational distribution.
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=gumbel_softmax --sum_over_latents=True \
--entropy_scale=1.0 --num_samples=10 --beta=0.05
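For reference, a minimal sketch of the Gumbel-Softmax (Concrete) relaxation this bottleneck is based on (Jang et al., 2017; Maddison et al., 2017), not the code in this directory:

```python
# Gumbel-Softmax sampling: a differentiable, temperature-controlled
# approximation to drawing a one-hot code from categorical logits.
import numpy as np

def gumbel_softmax_sample(logits, temperature=1.0, rng=np.random):
  logits = np.asarray(logits)
  u = rng.uniform(low=1e-9, high=1.0, size=logits.shape)
  gumbel = -np.log(-np.log(u))
  y = (logits + gumbel) / temperature
  y = y - y.max(axis=-1, keepdims=True)               # numerical stability
  exp_y = np.exp(y)
  return exp_y / exp_y.sum(axis=-1, keepdims=True)    # soft one-hot code
```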
4.1. Gumbel-Softmax variational distribution with joint training and IAFs.
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=gumbel_softmax --sum_over_latents=True \
--entropy_scale=1.0 --num_samples=10 --beta=0.05 --num_iaf_flows=1 \
--iaf_startup_steps=0
4.2.1. Gumbel-Softmax variational distribution with two-stage training (still training encoder after introducing IAFs).
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=gumbel_softmax --sum_over_latents=True \
--entropy_scale=1.0 --num_samples=10 --beta=0.05 --num_iaf_flows=1 \
--iaf_startup_steps=5000 --stop_training_encoder_after_startup=False
4.2.2. Gumbel-Softmax variational distribution with two-stage training (stop training encoder after startup steps).
$ python -m probabilistic_vqvae.mnist_experiments \
--bottleneck_type=gumbel_softmax --sum_over_latents=True \
--entropy_scale=1.0 --num_samples=10 --beta=0.05 --num_iaf_flows=1 \
--iaf_startup_steps=5000 --stop_training_encoder_after_startup=True
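The difference between 4.1, 4.2.1, and 4.2.2 is only the training schedule. A small sketch of the assumed schedule logic (not the repository's training loop):

```python
# Assumed two-stage schedule: train without the IAFs for --iaf_startup_steps,
# then either keep training everything or freeze the non-IAF parameters,
# depending on --stop_training_encoder_after_startup.
def variables_to_train(step, all_vars, iaf_vars, iaf_startup_steps,
                       stop_training_encoder_after_startup):
  """all_vars / iaf_vars are lists of model variables (compared by identity)."""
  iaf_ids = {id(v) for v in iaf_vars}
  if step < iaf_startup_steps:
    # Stage 1: the IAFs are not yet active, so train everything else.
    return [v for v in all_vars if id(v) not in iaf_ids]
  if stop_training_encoder_after_startup:
    # Stage 2 (frozen): only the IAF parameters receive updates.
    return list(iaf_vars)
  # Stage 2 (joint): keep training all parameters alongside the IAFs.
  return list(all_vars)
```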