[NEW!] The PyTorch implementation of a general conditional GAN Compression framework is released.
We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely-used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.
GAN Compression: Efficient Architectures for Interactive Conditional GANs
Muyang Li, Ji Lin, Yaoyao Ding, Zhijian Liu, Jun-Yan Zhu, and Song Han
MIT, Adobe Research, SJTU
In CVPR 2020.
GAN Compression framework: ① Given a pre-trained teacher generator G', we distill a smaller “once-for-all” student generator G that contains all possible channel numbers through weight sharing. We choose different channel numbers for the student generator G at each training step. ② We then extract many sub-generators from the “once-for-all” generator and evaluate their performance. No retraining is needed, which is the advantage of the “once-for-all” generator. ③ Finally, we choose the best sub-generator given the compression ratio target and performance target (FID or mAP), perform fine-tuning, and obtain the final compressed model.
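Steps ① and ③ above can be sketched in a few lines of plain Python. This is only an illustrative sketch, not the repository's implementation: the channel choices, the layer count, and the MACs/FID numbers are made-up placeholders, and `sample_config` and `select_best` are hypothetical helper names. It assumes a lower-is-better metric such as FID; for mAP the comparison would flip.

```python
import random

# Hypothetical search space: candidate channel numbers for each layer.
CHANNEL_CHOICES = [16, 24, 32, 48, 64]
NUM_LAYERS = 4

def sample_config(rng):
    """Step 1: pick one channel number per layer for the current training step."""
    return tuple(rng.choice(CHANNEL_CHOICES) for _ in range(NUM_LAYERS))

def select_best(evaluated, macs_budget):
    """Step 3: among sub-generators within the budget, return the lowest-FID one."""
    feasible = [entry for entry in evaluated if entry["macs"] <= macs_budget]
    if not feasible:
        return None  # no sub-generator meets the compression target
    return min(feasible, key=lambda entry: entry["fid"])

# Step 2 would fill this list by evaluating extracted sub-generators.
# The numbers below are placeholders, not measured results.
EVALUATED = [
    {"config": (16, 16, 16, 16), "macs": 1.0e9, "fid": 80.0},
    {"config": (32, 24, 16, 16), "macs": 2.5e9, "fid": 66.0},
    {"config": (48, 32, 32, 24), "macs": 4.0e9, "fid": 64.5},
]

if __name__ == "__main__":
    print("sampled config:", sample_config(random.Random(0)))
    print("selected:", select_best(EVALUATED, macs_budget=3.0e9))
```

Because the sub-generators share weights with the "once-for-all" generator, the evaluation loop that produces a list like `EVALUATED` needs no retraining; only the selected configuration is fine-tuned.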
PyTorch Colab notebook: CycleGAN and pix2pix.
- Linux
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
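The prerequisites above can be checked with a short script before installing anything. This is a minimal sketch; `check_prerequisites` is a hypothetical helper that is not part of the repository, and it only reports whether PyTorch is importable and whether CUDA is visible to it.

```python
import importlib.util
import platform
import sys

def check_prerequisites():
    """Report which of the listed prerequisites are met on this machine."""
    report = {
        "linux": platform.system() == "Linux",
        "python3": sys.version_info.major == 3,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": None,  # unknown unless PyTorch is installed
    }
    if report["torch_installed"]:
        import torch
        report["cuda_available"] = torch.cuda.is_available()
    return report

if __name__ == "__main__":
    for name, value in check_prerequisites().items():
        print(f"{name}: {value}")
```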
- Clone this repo:
git clone git@github.com:mit-han-lab/gan-compression.git
cd gan-compression
- Install PyTorch 1.4 and other dependencies (e.g., torchvision).
  - For pip users, please type the command
    pip install -r requirements.txt
  - For Conda users, we provide an installation script
    ./scripts/conda_deps.sh
    Alternatively, you can create a new Conda environment using
    conda env create -f environment.yml
- Install torchprofile.
pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git
- Download the CycleGAN dataset (e.g., horse2zebra).
bash datasets/download_cyclegan_dataset.sh horse2zebra
- Get the statistical information for the ground-truth images for your dataset to compute FID. We provide pre-prepared real statistics for several datasets. For example,
bash ./datasets/download_real_stat.sh horse2zebra A
bash ./datasets/download_real_stat.sh horse2zebra B
- Download the pre-trained models.
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
- Test the original full model.
bash scripts/cycle_gan/horse2zebra/test_full.sh
- Test the compressed model.
bash scripts/cycle_gan/horse2zebra/test_compressed.sh
- Download the pix2pix dataset (e.g., edges2shoes).
bash ./datasets/download_pix2pix_dataset.sh edges2shoes-r
- Get the statistical information for the ground-truth images for your dataset to compute FID. We provide pre-prepared real statistics for several datasets. For example,
bash datasets/download_real_stat.sh edges2shoes-r B
- Download the pre-trained models.
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
- Test the original full model.
bash scripts/pix2pix/edges2shoes-r/test_full.sh
- Test the compressed model.
bash scripts/pix2pix/edges2shoes-r/test_compressed.sh
We cannot provide the Cityscapes dataset due to licensing issues. Please download the dataset from https://cityscapes-dataset.com and use the script datasets/prepare_cityscapes_dataset.py to preprocess it. You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them in the same folder. For example, you may put gtFine and leftImg8bit in database/cityscapes-origin. You can then prepare the dataset with the following command:
python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--table_path ./dataset/table.txt
You will get a preprocessed dataset in database/cityscapes and a mapping table (used to compute mAP) in dataset/table.txt.
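Before running the preparation script, it can help to confirm that both unzipped folders sit side by side as described above. The following is a minimal sketch; `missing_cityscapes_dirs` is a hypothetical helper, not part of the repository, and it only checks that the two top-level folders exist.

```python
from pathlib import Path

# Folder names produced by unzipping the two Cityscapes archives.
REQUIRED_DIRS = ["gtFine", "leftImg8bit"]

def missing_cityscapes_dirs(root):
    """Return the required folders that are absent under `root`."""
    root = Path(root)
    return [name for name in REQUIRED_DIRS if not (root / name).is_dir()]

if __name__ == "__main__":
    missing = missing_cityscapes_dirs("database/cityscapes-origin")
    if missing:
        print("Missing folders:", ", ".join(missing))
    else:
        print("Layout looks ready for prepare_cityscapes_dataset.py")
```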
To compute the FID score, you need some statistical information from the ground-truth images of your dataset. We provide a script, get_real_stat.py, to extract it. For example, for the edges2shoes dataset, you could run the following command:
python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB
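To illustrate why these statistics are needed: FID is the Fréchet distance between two Gaussians fitted to image features, so it only requires a feature mean and covariance for the real images and for the generated ones. The sketch below computes that distance with NumPy. It is not the repository's code, `fid_from_stats` is a hypothetical name, and it assumes the stored statistics are a mean vector and a covariance matrix; how they are laid out inside the .npz file is not shown here.

```python
import numpy as np

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Tr((sigma1 sigma2)^(1/2)) is the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2, which are real and non-negative for
    # positive semi-definite inputs; clip tiny negative numerical noise.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)

if __name__ == "__main__":
    mu = np.zeros(3)
    sigma = np.eye(3)
    # Identical statistics give a distance of (numerically) zero.
    print(fid_from_stats(mu, sigma, mu, sigma))
```

Precomputing the real-image statistics once means each evaluation only needs the features of the generated images.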
If you use this code for your research, please cite our paper.
@inproceedings{li2020gan,
title={GAN Compression: Efficient Architectures for Interactive Conditional GANs},
author={Li, Muyang and Lin, Ji and Ding, Yaoyao and Liu, Zhijian and Zhu, Jun-Yan and Han, Song},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2020}
}
Our code is developed based on pytorch-CycleGAN-and-pix2pix and SPADE.
We also thank pytorch-fid for FID computation and drn for mAP computation.