This repository is the official implementation of the paper "FAKD: Feature-Affinity Based Knowledge Distillation for Efficient Image Super-Resolution" from ICIP 2020. In this work, we propose a novel and efficient SR model, named Feature Affinity-based Knowledge Distillation (FAKD), which transfers the structural knowledge of a heavy teacher model to a lightweight student model. To transfer this structural knowledge effectively, FAKD distills second-order statistical information from feature maps and trains a lightweight student network with low computational and memory cost. Experimental results demonstrate the efficacy of our method and its superiority over other knowledge-distillation-based methods in terms of both quantitative and visual metrics.
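To illustrate what "second-order statistical information" can look like in practice, here is a minimal plain-Python sketch. It assumes that the affinity is the matrix of cosine similarities between the channel vectors at every pair of spatial positions, and that student and teacher affinities are compared with a mean absolute (L1) difference. The names `spatial_affinity` and `affinity_distillation_loss` are hypothetical; this is an illustrative reading of the idea, not the repository's implementation.

```python
import math
from typing import List

Matrix = List[List[float]]  # row-major: Matrix[row][col]


def spatial_affinity(feat: Matrix, eps: float = 1e-7) -> Matrix:
    """Return the N x N affinity matrix of a C x N feature map.

    feat[c][i] is channel c at flattened spatial position i (N = H * W).
    Each position's C-dim feature vector is L2-normalized, so A[i][j]
    is (up to eps) the cosine similarity between positions i and j.
    """
    channels, positions = len(feat), len(feat[0])
    norms = [
        math.sqrt(sum(feat[c][i] ** 2 for c in range(channels))) + eps
        for i in range(positions)
    ]
    unit = [[feat[c][i] / norms[i] for i in range(positions)] for c in range(channels)]
    return [
        [sum(unit[c][i] * unit[c][j] for c in range(channels)) for j in range(positions)]
        for i in range(positions)
    ]


def affinity_distillation_loss(student: Matrix, teacher: Matrix) -> float:
    """Mean absolute difference between student and teacher affinity matrices.

    Only the number of spatial positions must match; channel counts may differ.
    """
    a_s, a_t = spatial_affinity(student), spatial_affinity(teacher)
    n = len(a_s)
    return sum(abs(a_s[i][j] - a_t[i][j]) for i in range(n) for j in range(n)) / (n * n)


# Hypothetical toy features: 3 spatial positions, teacher has 4 channels, student 2.
teacher_feat = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.5, 0.5, 0.0], [0.0, 0.0, 2.0]]
student_feat = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
print(affinity_distillation_loss(student_feat, teacher_feat))
```

Because the affinity matrix is N x N, where N is the number of spatial positions, it does not depend on the channel count. A narrower student can therefore be compared with a wider teacher as long as the chosen feature maps share the same spatial size.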
Here are the quantitative results (PSNR and SSIM) of RCAN and SAN with and without FAKD. The Teacher Network (TN) and Student Network (SN) share the same network architecture but differ in network depth or width.
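For reference, PSNR compares a super-resolved image with its ground truth through the mean squared error, PSNR = 10 log10(MAX^2 / MSE). The sketch below computes it on flattened pixel values, assuming 8-bit images (MAX = 255). The EDSR-derived evaluation code may additionally crop image borders and convert to luminance before computing it, and SSIM is not covered here.

```python
import math
from typing import Sequence


def psnr(sr: Sequence[float], hr: Sequence[float], max_val: float = 255.0) -> float:
    """PSNR in dB between a super-resolved and a ground-truth image (flattened pixels)."""
    if len(sr) != len(hr):
        raise ValueError("images must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(sr, hr)) / len(hr)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


# MSE = 255^2 / 2, so PSNR = 10 * log10(2), roughly 3.01 dB.
print(psnr([255.0, 0.0], [0.0, 0.0]))
```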
Note:
- RCAN is from the paper "Image Super-Resolution Using Very Deep Residual Channel Attention Networks".
- SAN is from the paper "Second-order Attention Network for Single Image Super-resolution".
Dependencies:
- Python 3.6.9
- PyTorch 1.1.0
- skimage 0.15.0
- numpy 1.16.4
- imageio 2.6.1
- matplotlib
- tqdm
We use the DIV2K dataset as the training set, which you can download from here, and four benchmark datasets (Set5, Set14, B100, Urban100) as the testing set, which you can download from here.
Unpack the tar files and arrange the data directory as follows. Then set the dir_data argument in code/option.py to ${DATA_ROOT}.
${DATA_ROOT}
|-- DIV2K
|-- benchmark
    |-- Set5
    |-- Set14
    |-- B100
    |-- Urban100
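A quick sanity check of this layout can save a failed run. The helper below is hypothetical (it is not part of the repository) and only verifies that the expected directories exist under the data root; it does not inspect their contents.

```python
from pathlib import Path
from typing import List

# Expected sub-directories relative to the data root (hypothetical check list).
EXPECTED_DIRS = [
    "DIV2K",
    "benchmark/Set5",
    "benchmark/Set14",
    "benchmark/B100",
    "benchmark/Urban100",
]


def missing_dirs(data_root: str) -> List[str]:
    """Return the expected sub-directories that are absent under data_root."""
    root = Path(data_root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


# Example: missing_dirs("/path/to/DATA_ROOT") returns [] when the layout is complete.
```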
Download the teacher model from here and place it in the teacher_checkpoint folder.
python train.py --ckp_dir overall_distilation/rcan/SA_x4/ --scale 4 --teacher [RCAN] --model RCAN --alpha 0.5 --feature_loss_used 1 --feature_distilation_type 10*SA --features [1,2,3] --epochs 200 --save_results --chop --patch_size 192
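The --feature_distilation_type 10*SA and --features [1,2,3] arguments read as a weight-and-type pair and a list of layer indices. The sketch below shows one way to interpret them, assuming the "weight*TYPE" convention of EDSR's --loss option (with SA presumably standing for spatial affinity). The function names are hypothetical, and the repository's own parsing in code/option.py may differ.

```python
import json
from typing import Dict, List, Tuple


def parse_weighted_spec(spec: str) -> List[Tuple[float, str]]:
    """Split a 'w1*TYPE1+w2*TYPE2' string into (weight, type) pairs."""
    terms = []
    for part in spec.split("+"):
        weight, name = part.split("*")
        terms.append((float(weight), name.strip()))
    return terms


def parse_feature_indices(arg: str) -> List[int]:
    """Turn a list literal such as '[1,2,3]' into integer feature indices."""
    return [int(i) for i in json.loads(arg)]


def weighted_feature_loss(per_type: Dict[str, float], spec: str) -> float:
    """Weighted sum of already-computed per-type feature losses."""
    return sum(w * per_type[name] for w, name in parse_weighted_spec(spec))


print(parse_weighted_spec("10*SA"))                  # [(10.0, 'SA')]
print(parse_feature_indices("[1,2,3]"))              # [1, 2, 3]
print(weighted_feature_loss({"SA": 0.5}, "10*SA"))   # 5.0
```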
More training scripts can be found in code/scripts.
Download the distilled RCANx4 model from here and test it with:
python test.py --ckp_path <checkpoint path> --TS S --scale 4 --model RCAN --n_resgroups 10 --n_resblocks 6
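To get a feel for how much lighter the student is, the sketch below counts trainable convolution parameters for an RCAN-style network. It assumes the original RCAN layout (64 feature channels, 3x3 convolutions, channel attention with reduction 16, a pixel-shuffle upsampler) and ignores the fixed mean-shift layers; the repository's model definition may differ, so treat the numbers as estimates. Under these assumptions, the RCAN paper's configuration of 10 residual groups with 20 blocks each comes to about 15.6M parameters, while the student tested with --n_resgroups 10 --n_resblocks 6 comes to about 5.2M.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus biases of a k x k convolution with bias."""
    return c_in * c_out * k * k + c_out


def rcan_params(n_resgroups: int, n_resblocks: int, n_feats: int = 64,
                reduction: int = 16, scale: int = 4, n_colors: int = 3) -> int:
    """Estimate trainable parameters of an RCAN-style network (assumed layout)."""
    if scale < 2 or scale & (scale - 1):
        raise ValueError("this sketch only handles power-of-two scales")
    squeeze = n_feats // reduction
    # One RCAB: two 3x3 convs plus channel attention (two 1x1 convs).
    rcab = (2 * conv_params(n_feats, n_feats, 3)
            + conv_params(n_feats, squeeze, 1)
            + conv_params(squeeze, n_feats, 1))
    # Each residual group ends with one extra 3x3 conv.
    group = n_resblocks * rcab + conv_params(n_feats, n_feats, 3)
    # The trunk ends with one more 3x3 conv after the last group.
    body = n_resgroups * group + conv_params(n_feats, n_feats, 3)
    head = conv_params(n_colors, n_feats, 3)
    # Each x2 upsampling stage: 3x3 conv to 4 * n_feats channels + PixelShuffle(2).
    n_stages = scale.bit_length() - 1
    upsampler = n_stages * conv_params(n_feats, 4 * n_feats, 3)
    tail = conv_params(n_feats, n_colors, 3)
    return head + body + upsampler + tail


teacher = rcan_params(n_resgroups=10, n_resblocks=20)  # 15592355
student = rcan_params(n_resgroups=10, n_resblocks=6)   # 5171315
print(f"teacher ~{teacher / 1e6:.1f}M, student ~{student / 1e6:.1f}M "
      f"({student / teacher:.0%} of the teacher)")
```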
The code is built on EDSR (PyTorch). We thank the authors for sharing their code.