Course Project for IFT 6135 - Representation Learning
Project Report link: final_project.pdf
- To train the VQ-VAE with the default arguments discussed in the report, execute:
source pytorch.venv/bin/activate
python vqvae.py --data-folder /tmp/miniimagenet --output-folder models/vqvae --dataset mnist
- To train the PixelCNN prior on the latents, execute:
python pixelcnn_prior.py --data-folder /tmp/miniimagenet --model models/vqvae --output-folder models/pixelcnn_prior
- MNIST
- FashionMNIST
- CIFAR10
- Mini-ImageNet
- Atari 2600 - Boxing (OpenAI Gym) code
Top 4 rows are Original Images. Bottom 4 rows are Reconstructions.
- We noticed that implementing our own VectorQuantization PyTorch function sped up training of the VQ-VAE by nearly 3x. The slower but simpler code is in this commit.
- We added some basic tests for the vector quantization functions (based on pytest). To run these tests:
py.test . -vv
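To make the note about a custom vector quantization function more concrete, here is a minimal sketch of what such a `torch.autograd.Function` with a straight-through gradient could look like. It is an illustration under stated assumptions, not the code in this repository: the class name `QuantizeStraightThrough`, the flat `(N, D)` input shape, and the toy codebook are all hypothetical.

```python
import torch
from torch.autograd import Function


class QuantizeStraightThrough(Function):
    """Hypothetical sketch: nearest-codebook lookup with a straight-through gradient."""

    @staticmethod
    def forward(ctx, inputs, codebook):
        # inputs: (N, D) vectors to quantize; codebook: (K, D) embedding vectors.
        # Squared distances via ||x||^2 - 2 x.e + ||e||^2, which avoids
        # materializing an (N, K, D) tensor.
        x_sq = (inputs ** 2).sum(dim=1, keepdim=True)            # (N, 1)
        e_sq = (codebook ** 2).sum(dim=1)                        # (K,)
        distances = x_sq - 2.0 * (inputs @ codebook.t()) + e_sq  # (N, K)
        indices = distances.argmin(dim=1)                        # (N,)

        ctx.save_for_backward(indices, codebook)
        ctx.mark_non_differentiable(indices)
        codes = codebook.index_select(0, indices)                # (N, D)
        return codes, indices

    @staticmethod
    def backward(ctx, grad_codes, grad_indices):
        indices, codebook = ctx.saved_tensors
        grad_codes = grad_codes.contiguous()
        grad_inputs = grad_codebook = None
        if ctx.needs_input_grad[0]:
            # Straight-through: copy the gradient of the codes to the inputs.
            grad_inputs = grad_codes.clone()
        if ctx.needs_input_grad[1]:
            # Each codebook row accumulates the gradients of the inputs it was assigned.
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_codes)
        return grad_inputs, grad_codebook


# Toy usage with made-up values (assumption: 2-D vectors, 2 codebook entries).
codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True)
inputs = torch.tensor([[0.1, -0.1], [0.9, 1.2], [0.8, 0.7]], requires_grad=True)
codes, indices = QuantizeStraightThrough.apply(inputs, codebook)
codes.sum().backward()
```

In this toy example the first input is assigned to codebook entry 0 and the other two to entry 1, so the second codebook row receives twice the gradient of the first. A check of this kind is also the sort of thing a pytest-based test for a quantization function might assert.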
- Rithesh Kumar
- Tristan Deleu
- Evan Racah