The PyTorch code repository for "Learning Embedding Adaptation for Few-Shot Learning"
Few-shot learning methods address the challenge of learning from limited labeled data by learning an instance embedding function from seen classes and applying that function to instances from unseen classes with limited labels. This style of transfer learning is task-agnostic: the embedding function is not learned to be optimally discriminative with respect to the unseen classes, whereas discerning among them is the target task. In this work, we propose a novel approach that adapts the embedding model to the target classification task, yielding embeddings that are task-specific and discriminative. To this end, we employ a self-attention mechanism, the Transformer, to transform the embeddings from task-agnostic to task-specific by relating the test instances to the training instances in both seen and unseen classes.
The following are required to run the scripts:
- The tensorboardX package
- Dataset: please download the dataset and put the images into the folder data/[name of dataset, miniimagenet or cub]/images
- Pre-trained weights: please download the pre-trained weights of the encoder if needed
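Before training, it can help to confirm that the images sit where the scripts expect them. The following is a minimal sketch, assuming only the data/[name of dataset]/images layout stated above; `missing_image_dirs` is a hypothetical helper and not part of the repo.

```python
from pathlib import Path

# Hypothetical helper: the expected layout data/<dataset>/images is taken from
# the Dataset bullet above; the repo itself may not ship such a check.
EXPECTED_DATASETS = ("miniimagenet", "cub")


def missing_image_dirs(root="data"):
    """Return the dataset names whose <root>/<name>/images folder is missing or empty."""
    missing = []
    for name in EXPECTED_DATASETS:
        image_dir = Path(root) / name / "images"
        # any() over iterdir() stops at the first entry, so large folders stay cheap
        if not image_dir.is_dir() or not any(image_dir.iterdir()):
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("missing:", missing_image_dirs())
```

Only the dataset you actually train on needs to be present, so a non-empty result is not necessarily an error.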
The MiniImageNet dataset is a subset of ImageNet that contains 100 classes with 600 examples per class. Following the setup of previous work, we use 64 classes as SEEN categories, and 16 and 20 classes as two sets of UNSEEN categories for model validation and evaluation, respectively.
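The class counts above fix the size of each split. A small worked check, using only the numbers stated in the text (64/16/20 classes, 600 images per class):

```python
# Class counts from the text: 64 SEEN (meta-train), 16 UNSEEN for validation,
# 20 UNSEEN for evaluation, with 600 images per class.
IMAGES_PER_CLASS = 600
SPLITS = {"train": 64, "val": 16, "test": 20}


def split_sizes(splits=SPLITS, per_class=IMAGES_PER_CLASS):
    """Return the number of images in each split (classes x images per class)."""
    return {name: n_cls * per_class for name, n_cls in splits.items()}


# The three splits are disjoint and together cover all 100 classes.
assert sum(SPLITS.values()) == 100
print(split_sizes())  # {'train': 38400, 'val': 9600, 'test': 12000}
```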
The Caltech-UCSD Birds (CUB) 200-2011 dataset was originally designed for fine-grained classification. It contains 11,788 images of birds from 200 species. On CUB, we randomly sampled 100 species as SEEN classes, and the remaining 100 species form two UNSEEN sets of 50 species each. We crop all images with the given bounding boxes before training. We only test CUB with the ConvNet backbone in our work.
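The two CUB preprocessing steps, the 100/50/50 species split and the bounding-box crop, can be sketched as follows. `split_species` and `parse_bounding_boxes` are hypothetical helpers. The seeded shuffle does not reproduce the specific split used in the paper, and the parser assumes the CUB-200-2011 `bounding_boxes.txt` layout of `<image_id> <x> <y> <width> <height>` per line.

```python
import random


def split_species(species, seed=0):
    """Randomly partition 200 species into 100 SEEN and two disjoint sets of 50 UNSEEN."""
    shuffled = list(species)
    if len(shuffled) != 200:
        raise ValueError(f"expected 200 species, got {len(shuffled)}")
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    return shuffled[:100], shuffled[100:150], shuffled[150:]


def parse_bounding_boxes(lines):
    """Parse '<image_id> <x> <y> <width> <height>' lines into crop boxes.

    Each box is returned as (left, upper, right, lower), truncated to integer
    pixels, which is the 4-tuple that Pillow's Image.crop expects.
    """
    boxes = {}
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        image_id, x, y, w, h = line.split()
        x, y, w, h = (float(v) for v in (x, y, w, h))
        boxes[int(image_id)] = (int(x), int(y), int(x + w), int(y + h))
    return boxes
```

With Pillow installed, a box from `parse_bounding_boxes` can be passed directly to `Image.crop` for the corresponding image.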
We implement two baseline approaches in this repo, namely the Matching Network and the Prototypical Network. To train them on this task, cd into the repo's root folder and execute:
$ python train_matchnet.py (or python train_protonet.py)
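For reference, the scoring rules of the two baselines can be sketched in pure Python on precomputed embeddings. This is a minimal sketch of the standard formulations, under two assumptions. First, Prototypical Network logits are negative squared Euclidean distances to class-mean prototypes. Second, Matching Network probabilities are a softmax-weighted vote over cosine similarities to the support instances, without full-context embeddings. Both divide their scores by `temperature`, mirroring the temperature option of train_matchnet.py. The function names are hypothetical, and the repo's own implementations operate on PyTorch tensors.

```python
import math


def _sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + 1e-8)  # epsilon guards against zero vectors


def protonet_logits(support, support_labels, query, way, temperature=1.0):
    """Prototypical Network: negative squared distance to each class mean."""
    dim = len(query[0])
    protos = []
    for c in range(way):
        members = [s for s, y in zip(support, support_labels) if y == c]
        protos.append([sum(m[d] for m in members) / len(members) for d in range(dim)])
    return [[-_sq_dist(q, p) / temperature for p in protos] for q in query]


def matchnet_probs(support, support_labels, query, way, temperature=1.0):
    """Matching Network: attention over support instances, summed per class."""
    probs = []
    for q in query:
        sims = [_cosine(q, s) / temperature for s in support]
        top = max(sims)
        weights = [math.exp(v - top) for v in sims]  # numerically stable softmax
        total = sum(weights)
        row = [0.0] * way
        for w, y in zip(weights, support_labels):
            row[y] += w / total
        probs.append(row)
    return probs
```

In both cases the query is assigned to the class with the highest score. A smaller `temperature` sharpens the resulting softmax distribution.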
train_matchnet.py takes the following command-line options:
- max_epoch: The maximum number of training epochs. Defaults to 200.
- way: The number of classes in a few-shot task. Defaults to 5.
- shot: The number of instances per class in a few-shot task. Defaults to 1.
- query: The number of instances per class used to evaluate performance in both the meta-training and meta-test stages. Defaults to 15.
- lr: The learning rate of the model. Defaults to 0.0001 with the pre-trained model.
- step_size: The step size of the StepLR learning-rate scheduler. Defaults to 20.
- gamma: The decay ratio of the StepLR learning-rate scheduler. Defaults to 0.2.
- temperature: The temperature applied to the logits; the logits are divided by this value. Defaults to 1.
- model_type: The type of encoder, either the convolutional network (ConvNet) or ResNet. Defaults to ConvNet.
- dataset: The dataset to use (MiniImageNet or CUB). Defaults to MiniImageNet.
- init_weights: The path to the initial weights. Defaults to None.
- gpu: The index of the GPU to use. Defaults to 0.
- use_bilstm: Specific to the Matching Network. If true, a bi-LSTM is used for embedding adaptation. Defaults to False.
- lr_mul: Specific to the Matching Network with bi-LSTM and to FEAT. The learning rate of the top layer is multiplied by this value (usually so that it learns faster). Defaults to 10.
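The way, shot, and query options together define each training episode. The following is a minimal sampler sketch, assuming the class labels are given as a flat list with one entry per image. `sample_episode` is a hypothetical helper, and the repo's actual sampler (including how many episodes make up an epoch) may differ.

```python
import random


def sample_episode(labels, way=5, shot=1, query=15, rng=random):
    """Sample one `way`-way `shot`-shot task from a flat list of per-image labels.

    Returns (support_idx, query_idx, support_y, query_y). Episode labels are
    re-indexed to 0..way-1 in the order the classes were drawn.
    """
    by_class = {}
    for idx, c in enumerate(labels):
        by_class.setdefault(c, []).append(idx)
    classes = rng.sample(sorted(by_class), way)  # `way` distinct classes
    support, queries, support_y, query_y = [], [], [], []
    for new_label, c in enumerate(classes):
        # Draw shot + query distinct images so support and query never overlap.
        picked = rng.sample(by_class[c], shot + query)
        support += picked[:shot]
        queries += picked[shot:]
        support_y += [new_label] * shot
        query_y += [new_label] * query
    return support, queries, support_y, query_y
```

With the defaults, each episode holds 5 × 1 = 5 support images and 5 × 15 = 75 query images.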
Running the command without arguments trains the models with the default hyperparameter values. Loss curves are recorded as TensorBoard files in the ./runs folder.
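With the default lr, step_size, gamma, and lr_mul values, the learning-rate schedule can be worked out in closed form. The sketch below assumes that the scheduler behaves like PyTorch's `torch.optim.lr_scheduler.StepLR` stepped once per epoch, and that lr_mul is applied through a separate optimizer parameter group for the top layer. Both helpers are hypothetical and illustrate the arithmetic only.

```python
# Defaults from the option list: lr=0.0001, step_size=20, gamma=0.2,
# lr_mul=10, max_epoch=200. Assumes the scheduler steps once per epoch.
def step_lr(base_lr, epoch, step_size=20, gamma=0.2):
    """Closed form of a StepLR schedule at a 0-indexed epoch:
    the rate is multiplied by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def param_group_lrs(epoch, lr=1e-4, lr_mul=10, step_size=20, gamma=0.2):
    """Hypothetical two-group setup: the encoder trains at the scheduled rate and
    the top layer (bi-LSTM or Transformer) at lr_mul times that rate."""
    base = step_lr(lr, epoch, step_size, gamma)
    return {"encoder": base, "top": base * lr_mul}


if __name__ == "__main__":
    for epoch in (0, 19, 20, 40, 199):
        print(epoch, param_group_lrs(epoch))
```

Under these assumptions, the rate is multiplied by 0.2 nine times over 200 epochs (at epochs 20, 40, ..., 180). The encoder therefore ends at 1e-4 × 0.2^9 ≈ 5.1e-11, which suggests that most learning happens in the first few steps.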
In FEAT, the embeddings of all instances in a task are adapted by a Transformer. The learned models on MiniImageNet and CUB can be found at this link.
train_feat.py also takes the following option:
- balance: The weight of the FEAT regularizer. Defaults to 10.
To train FEAT, execute:
$ python train_feat.py
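The Transformer-based adaptation can be illustrated with a heavily simplified pure-Python sketch. Its assumptions are a single attention head with identity query/key/value projections and a residual connection. The repo's actual module uses learned projections and possibly other components, and which instances it adapts follows the paper rather than this sketch. The hypothetical `feat_objective` helper assumes that balance scales the regularizer in a weighted sum with the classification loss.

```python
import math


def self_attention_adapt(embeddings):
    """Adapt a set of embeddings with single-head scaled dot-product self-attention.

    Simplification: queries, keys, and values are the embeddings themselves
    (identity projections), and output = input + attention(input).
    """
    dim = len(embeddings[0])
    scale = math.sqrt(dim)
    adapted = []
    for q in embeddings:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in embeddings]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]  # stable softmax over the set
        total = sum(weights)
        context = [sum(w * v[d] for w, v in zip(weights, embeddings)) / total
                   for d in range(dim)]
        adapted.append([x + c for x, c in zip(q, context)])  # residual connection
    return adapted


def feat_objective(cls_loss, reg_loss, balance=10.0):
    """Hypothetical: combine the main loss with the FEAT regularizer, weighted by balance."""
    return cls_loss + balance * reg_loss
```

Because no positional encoding is used, permuting the input set simply permutes the output. This permutation equivariance is what makes self-attention a natural fit for unordered few-shot support sets.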
If this repo helps in your work, please cite the following paper:
@article{DBLP:YeHZS2018Learning,
author = {Han-Jia Ye and
Hexiang Hu and
De-Chuan Zhan and
Fei Sha},
title = {Learning Embedding Adaptation for Few-Shot Learning},
journal = {CoRR},
volume = {abs/1812.03664},
year = {2018}
}
We thank the following repos for providing helpful components/functions used in our work.