Ready-to-use code and tutorial notebooks to boost your way into few-shot image classification. This repository is made for you if:
- you're new to few-shot learning and want to learn;
- or you're looking for reliable, clear, and easy-to-use code for your few-shot learning projects.
Don't get lost in large repositories with hundreds of methods and no explanation on how to use them. Here, we want each line of code to be covered by a tutorial.
Want to learn few-shot learning but don't know where to start? Start with our tutorial.
Models:
- AbstractMetaLearner: an abstract class exposing methods that can be reused by any meta-trainable algorithm (see the sketch after this list)
- Prototypical Networks
- Matching Networks
- Relation Networks
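As an illustration of how AbstractMetaLearner is meant to be extended, here is a minimal sketch of a custom method built on top of it. The method names process_support_set() and forward(), as well as the backbone attribute, are assumptions about the abstract class's interface: check the source of AbstractMetaLearner for the exact names and signatures.

```python
import torch

from easyfsl.methods import AbstractMetaLearner


class SimplePrototypeClassifier(AbstractMetaLearner):
    """
    Hypothetical sketch of a custom few-shot method built on AbstractMetaLearner.
    Attribute and method names are assumptions about the abstract interface.
    """

    def process_support_set(self, support_images: torch.Tensor, support_labels: torch.Tensor):
        # Compute and store one prototype (mean feature vector) per class of the support set.
        support_features = self.backbone(support_images)
        self.prototypes = torch.stack(
            [
                support_features[support_labels == label].mean(dim=0)
                for label in torch.unique(support_labels)
            ]
        )

    def forward(self, query_images: torch.Tensor) -> torch.Tensor:
        # Classification scores for query images: negative euclidean distance to each prototype.
        query_features = self.backbone(query_images)
        return -torch.cdist(query_features, self.prototypes)
```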
Tools for data loading:
- EasySet: a ready-to-use Dataset object to handle datasets of images with a class-wise directory split, configured by a JSON specs file (see the example after this list)
- TaskSampler: samples batches in the shape of few-shot classification tasks
- CU-Birds: we provide a script to download and extract the dataset, along with a train/val/test split over classes. The dataset is ready to use with EasySet.
- tieredImageNet: we provide the train, val and test specification files to be used by EasySet. To use it, you need the ILSVRC2015 dataset. Once you have downloaded and extracted it, make sure its location on disk is consistent with the class paths specified in the specification files.
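EasySet reads the list of classes and their image directories from a JSON specs file. As an illustration, here is a sketch of how you could write a specs file for a custom split; the class_names / class_roots schema is an assumption, so check the provided train.json for the authoritative format.

```python
import json

# Assumed specs schema: one entry in class_names per class, and class_roots
# pointing to the matching image directories. Verify against the provided JSON files.
specs = {
    "class_names": ["001.Black_footed_Albatross", "002.Laysan_Albatross"],
    "class_roots": [
        "./data/CUB/images/001.Black_footed_Albatross",
        "./data/CUB/images/002.Laysan_Albatross",
    ],
}

with open("./data/CUB/custom_split.json", "w") as file:
    json.dump(specs, file, indent=4)
```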
- Install the package with pip:
```bash
pip install easyfsl
```
Note: alternatively, you can clone the repository so that you can modify the code as you wish.
- Download CU-Birds and the few-shot train/val/test split:
```bash
mkdir -p data/CUB && cd data/CUB
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx" -O images.tgz
rm -rf /tmp/cookies.txt
tar --exclude='._*' -zxvf images.tgz
wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/CUB/train.json
wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/CUB/val.json
wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/CUB/test.json
cd ../..
```
- Check that you have a 680.9MB `images` folder in `./data/CUB`, along with three JSON files.
- From the training subset of CUB, create a dataloader that yields few-shot classification tasks:
```python
from torch.utils.data import DataLoader

from easyfsl.data_tools import EasySet, TaskSampler

train_set = EasySet(specs_file="./data/CUB/train.json", training=True)

# Each batch is a 5-way 5-shot classification task with 10 query images per class,
# for a total of 40 000 training tasks.
train_sampler = TaskSampler(
    train_set, n_way=5, n_shot=5, n_query=10, n_tasks=40000
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
```
- Create and train a model:
```python
from easyfsl.methods import PrototypicalNetworks
from torch import nn
from torch.optim import Adam
from torchvision.models import resnet18

# ResNet18 backbone with its classification head replaced by a flatten layer,
# so that the network outputs feature vectors instead of class logits.
convolutional_network = resnet18(pretrained=False)
convolutional_network.fc = nn.Flatten()

model = PrototypicalNetworks(convolutional_network).cuda()
optimizer = Adam(params=model.parameters())

model.fit(train_loader, optimizer)
```
Note: you can also define a validation data loader and pass it as an additional argument to fit() in order to monitor validation performance during training.
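A minimal sketch of that option, assuming fit() accepts the validation loader as a keyword argument (the exact parameter name is an assumption; check the signature of fit() in AbstractMetaLearner):

```python
val_set = EasySet(specs_file="./data/CUB/val.json", training=False)
val_sampler = TaskSampler(val_set, n_way=5, n_shot=5, n_query=10, n_tasks=100)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

# Hypothetical keyword: verify the actual argument name against the fit() signature.
model.fit(train_loader, optimizer, val_loader=val_loader)
```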
Troubleshooting: with the task settings above, a ResNet18 processes batches of 5 * (5 + 10) = 75 images and uses about 4.2GB of GPU memory. If you don't have that much available, switch to CPU, choose a smaller model, or reduce the task size in the TaskSampler above (see the sketch below).
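For example, here is a sketch with hypothetical, smaller task settings (5-way 1-shot with 5 queries per class, i.e. 30 images per batch), reusing only objects defined above; adjust the numbers to your hardware:

```python
# Smaller tasks: 5 * (1 + 5) = 30 images per batch instead of 75.
small_task_sampler = TaskSampler(
    train_set, n_way=5, n_shot=1, n_query=5, n_tasks=40000
)
small_train_loader = DataLoader(
    train_set,
    batch_sampler=small_task_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=small_task_sampler.episodic_collate_fn,
)

# Or keep the model on CPU by simply dropping the .cuda() call.
cpu_model = PrototypicalNetworks(convolutional_network)
```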
- Evaluate your model on the test set:
```python
test_set = EasySet(specs_file="./data/CUB/test.json", training=False)

# Evaluation tasks follow the same 5-way 5-shot format as training tasks.
test_sampler = TaskSampler(
    test_set, n_way=5, n_shot=5, n_query=10, n_tasks=100
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

accuracy = model.evaluate(test_loader)
print(f"Average accuracy : {(100 * accuracy):.2f}")
```
- Integrate more methods:
- Matching Networks
- Relation Networks
- MAML
- Transductive Propagation Network
- TADAM
- Integrate non-episodic training
- Integrate more benchmarks:
- tieredImageNet
- miniImageNet
- Meta-Dataset
This project is very open to contributions! You can help in various ways:
- raise issues
- resolve issues already opened
- tackle new features from the roadmap
- fix typos, improve code quality