PyTorch implementation of our DropGrad approach. With the proposed regularization method, we can:
- alleviate the overfitting problem in existing gradient-based meta-learning models
- improve performance under the cross-domain few-shot classification setting
Contact: Hung-Yu Tseng ([email protected]), Yi-Wen Chen ([email protected])
Please cite our paper if you find the code or dataset useful for your research.
Regularizing Meta-Learning via Gradient Dropout
Hung-Yu Tseng*, Yi-Wen Chen*, Yi-Hsuan Tsai, Sifei Liu, Yen-Yu Lin, Ming-Hsuan Yang
ArXiv pre-print, 2020 (* equal contribution)
@article{dropgrad,
author = {Tseng, Hung-Yu and Chen, Yi-Wen and Tsai, Yi-Hsuan and Liu, Sifei and Lin, Yen-Yu and Yang, Ming-Hsuan},
title = {Regularizing Meta-Learning via Gradient Dropout},
journal = {arXiv preprint arXiv:2004.05859},
year = {2020}
}
- Python >= 3.5
- PyTorch >= 1.3 and torchvision (https://pytorch.org/)
- You can use the requirements.txt file we provide to set up the environment via Anaconda.
conda create --name py36 python=3.6
conda install pytorch torchvision -c pytorch
pip3 install -r requirements.txt
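After installation, a quick sanity check against the minimum versions listed above can catch a mismatched environment early. The following is a minimal sketch, not part of this repository: the helper names `parse_version`, `meets_minimum`, and `check_environment` are hypothetical, and the script only assumes that an installed PyTorch exposes `torch.__version__`.

```python
import importlib
import re
import sys


def parse_version(text):
    """Return the leading numeric components of a version string as a tuple of ints."""
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError("unrecognized version string: {!r}".format(text))
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(version_text, minimum):
    """True if version_text is at least `minimum`, e.g. (1, 3)."""
    return parse_version(version_text) >= tuple(minimum)


def check_environment():
    """Collect readable problems; an empty list means the stated minimums are met."""
    problems = []
    if sys.version_info[:2] < (3, 5):
        problems.append("Python >= 3.5 is required")
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        problems.append("PyTorch is not installed")
    else:
        if not meets_minimum(torch.__version__, (1, 3)):
            problems.append("PyTorch >= 1.3 is required, found " + torch.__version__)
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks fine"]:
        print(line)
```

Local build suffixes such as `+cu117` are ignored by the parser, so a version string like `1.13.1+cu117` is compared as `(1, 13, 1)`.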
Clone this repository:
git clone https://github.com/hytseng0509/DropGrad.git
cd DropGrad
Download the two datasets separately with the following commands.
- Set DATASET_NAME to cub or miniImagenet.
cd filelists
python3 process.py DATASET_NAME
cd ..
- Refer to the instructions here for constructing your own dataset.
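If you prefer to prepare both datasets in one go, the commands above can be wrapped in a small script. This is only a sketch under stated assumptions: `process_command` and `prepare_all` are hypothetical helpers, not part of the repository, and the script assumes it is run from the repository root so that the `filelists` directory is found. By default it performs a dry run and only returns the commands it would execute.

```python
import subprocess
from pathlib import Path

# The two dataset names accepted by process.py, as listed above.
DATASETS = ("cub", "miniImagenet")


def process_command(dataset_name):
    """Build the command line for one dataset, rejecting unknown names."""
    if dataset_name not in DATASETS:
        raise ValueError("unknown dataset: {!r}".format(dataset_name))
    return ["python3", "process.py", dataset_name]


def prepare_all(repo_root=".", dry_run=True):
    """Run process.py inside filelists/ for every dataset (or just list the commands)."""
    filelists = Path(repo_root) / "filelists"
    commands = [process_command(name) for name in DATASETS]
    for command in commands:
        if not dry_run:
            # check=True stops at the first failing download.
            subprocess.run(command, cwd=str(filelists), check=True)
    return commands


if __name__ == "__main__":
    for command in prepare_all(dry_run=True):
        print(" ".join(command))
```

Call `prepare_all(dry_run=False)` to actually execute the downloads.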
Train the gradient-based model on the mini-ImageNet dataset.
- DPMETHOD: dropout method, one of none, binary, or gaussian.
- DPRATE: dropout rate; we suggest 0.1.
python3 train.py --dropout_method DPMETHOD --dropout_rate DPRATE --name MAML_DPMETHOD_DPRATE --train_aug
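To give an intuition for what the three dropout methods do to the inner-loop gradients, here is a framework-free sketch. It is not the code used in `train.py`. It assumes the standard dropout conventions: a Bernoulli keep-mask rescaled by 1/(1 - rate) for binary, and multiplicative Gaussian noise with mean 1 and variance rate/(1 - rate) for gaussian. The function names `drop_grad` and `inner_step`, the plain-list representation of gradients, and the default learning rate are hypothetical choices for illustration.

```python
import math
import random


def drop_grad(grads, method="binary", rate=0.1, rng=random):
    """Apply multiplicative noise to a list of gradient values."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if method == "none" or rate == 0.0:
        return list(grads)
    keep = 1.0 - rate
    if method == "binary":
        # Keep each entry with probability `keep`; rescale survivors so the
        # expected value of the gradient is unchanged.
        return [g / keep if rng.random() < keep else 0.0 for g in grads]
    if method == "gaussian":
        # Noise with mean 1 and variance rate / (1 - rate) (assumed convention).
        std = math.sqrt(rate / keep)
        return [g * rng.gauss(1.0, std) for g in grads]
    raise ValueError("unknown dropout method: {!r}".format(method))


def inner_step(weights, grads, lr=0.01, method="binary", rate=0.1, rng=random):
    """One MAML-style inner-loop update that uses the perturbed gradients."""
    noisy = drop_grad(grads, method, rate, rng)
    return [w - lr * g for w, g in zip(weights, noisy)]
```

Both noise types leave the gradient unchanged in expectation, so the regularization comes from the added variance during the inner-loop updates rather than from a systematic shift.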
Test the model on the mini-ImageNet or CUB (cross-domain) dataset.
- Set TESTSET (the --dataset argument) to miniImagenet or cub.
python3 test.py --name MAML_DPMETHOD_DPRATE --dataset TESTSET
- This code is built upon the implementation from CloserLookFewShot.
- The dataset, model, and code are for non-commercial research purposes only.
- You can change the number of shots (e.g., 1 or 5) using the --n_shot argument.
- Please refer to output/checkpoints/download_models.py for the example model file trained with the DropGrad approach.