
Commit 7d7d22c

add eval scripts and readme

1 parent 1f18428 commit 7d7d22c

6 files changed: +267 −1 lines

README.md

+7 −1

@@ -17,7 +17,7 @@ The following packages are required to run the scripts:

- Dataset: please download the dataset and put the images into the folder data/[name of dataset, miniimagenet or cub]/images

- Pre-trained weights: please download the [pre-trained weights](https://drive.google.com/open?id=14Jn1t9JxH-CxjfWy4JmVpCxkC9cDqqfE) of the encoder if needed. The pre-trained weights can also be downloaded with the script download_weight.sh

### Dataset

@@ -80,6 +80,12 @@ to train the 1-shot 5-way FEAT model with ResNet backbone on MiniImageNet:

    $ python train_feat.py --lr 0.0001 --temperature 128 --max_epoch 100 --model_type ResNet --dataset MiniImageNet --init_weights ./saves/initialization/miniimagenet/res-pre.pth --shot 1 --way 5 --gpu 0 --balance 10 --step_size 10 --gamma 0.5 --lr_mul 10

### Model Evaluation
The train_xxx.py scripts evaluate the model with the best validation accuracy at the end of training. A saved model can also be evaluated with the corresponding eval_xxx.py script, whose options are similar to those of the training scripts. For example, a ConvNet model at "./saves/feat-model/xx.pth" can be evaluated on 1-shot 5-way tasks by:

    $ python eval_feat.py --model_type ConvNet --dataset MiniImageNet --model_path ./saves/FEAT-Models/MiniImageNet-Conv-1-Shot-5-Way.pth --shot 1 --way 5 --gpu 0
We assume the input model is stored as a GPU model.
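Because the checkpoints are saved from GPU models, loading one on a CPU-only machine requires torch.load's map_location argument. A minimal sketch, assuming the repo's checkpoint layout (a dict with a 'params' key, as the eval scripts read it); the temp-file round trip is only for illustration:

```python
import os
import tempfile

import torch

# Sketch only: create a checkpoint shaped like the repo's (a 'params' key
# holding the state dict), then load it with map_location so tensors that
# were stored on a GPU are remapped to the CPU.
path = os.path.join(tempfile.gettempdir(), 'demo-model.pth')
torch.save({'params': {'weight': torch.zeros(3)}}, path)

checkpoint = torch.load(path, map_location=torch.device('cpu'))
state_dict = checkpoint['params']
print(state_dict['weight'].device)  # cpu
```

Without map_location, torch.load would try to restore the tensors to their original CUDA device and fail on a machine with no GPU.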
## .bib citation
If this repo helps in your work, please cite the following paper:

data/README.md

+11

@@ -1 +1,12 @@
# Dataset Pre-processing

## General Pre-processing
After downloading the dataset, please create a new folder named "images" under the folder "miniimagenet" or "cub", and put all the images into it. The provided data loader reads images from the "images" folder by default. It is also fine to change the read path: for example, for the MiniImageNet dataset, change line 10 of "./feat/dataloader/mini_imagenet.py" to the path of the downloaded images.
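The path change described above can be sketched as follows; the variable name is hypothetical, so check line 10 of the actual file:

```python
import os.path as osp

# Hypothetical sketch of the image-path setting in
# ./feat/dataloader/mini_imagenet.py: the loader reads from
# data/miniimagenet/images by default; point it at your own
# directory if the downloaded images live elsewhere.
IMAGE_PATH = osp.join('data', 'miniimagenet', 'images')
# IMAGE_PATH = '/absolute/path/to/downloaded/images'
print(IMAGE_PATH)
```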
We assume all the images in the folder are the original ones (except a crop based on bounding boxes for CUB; see details below), and the data loader performs transformations on those raw images, such as resizing and normalization. All images are resized to 84x84 for the ConvNet backbone and to 80x80 for the ResNet backbone.
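The transform step can be sketched in numpy as below. The real loader uses torchvision transforms, and the mean/std values here are the commonly used ImageNet statistics, which may differ from the repo's exact numbers:

```python
import numpy as np

# Sketch of the loader-side transforms: nearest-neighbor resize to 84x84
# (ConvNet; 80x80 for ResNet) followed by channel-wise normalization.
def resize_nearest(img, size):
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows][:, cols]

def normalize(img, mean, std):
    return (img / 255.0 - mean) / std

raw = np.random.randint(0, 256, (600, 500, 3)).astype(np.float64)
out = normalize(resize_nearest(raw, 84),
                mean=np.array([0.485, 0.456, 0.406]),   # assumed ImageNet stats
                std=np.array([0.229, 0.224, 0.225]))
print(out.shape)  # (84, 84, 3)
```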
### MiniImageNet
The MiniImageNet dataset is a subset of ImageNet that contains 100 classes with 600 examples per class. We follow [Ravi's split](https://github.com/twitter/meta-learning-lstm) and use 64 classes as SEEN categories, with 16 and 20 classes as two sets of UNSEEN categories for model validation and evaluation, respectively. To download this dataset, please email Sachin Ravi for the link.
### CUB
The [Caltech-UCSD Birds (CUB) 200-2011 dataset](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) was originally designed for fine-grained classification. It contains 11,788 images of birds from 200 species. On CUB, we follow the [previous setting](https://arxiv.org/abs/1707.02610) and randomly sample 100 species as SEEN classes; two further sets of 50 species each serve as the UNSEEN validation and test sets. Since there is no public class split for CUB, we use our own split, saved in the "CUB" folder. We crop all images with the given bounding boxes before training. We only test CUB with the ConvNet backbone in our work.
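The bounding-box crop can be sketched as below, assuming the CUB annotation convention of (x, y, width, height) per image; the repo's actual crop code may differ:

```python
import numpy as np

# Sketch only: crop an image array to its annotated bounding box,
# assuming CUB's (x, y, width, height) convention.
def crop_bbox(img, x, y, w, h):
    return img[int(y):int(y) + int(h), int(x):int(x) + int(w)]

img = np.zeros((300, 400, 3), dtype=np.uint8)   # dummy H x W x C image
patch = crop_bbox(img, x=50, y=30, w=120, h=90)
print(patch.shape)  # (90, 120, 3)
```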

download_weight.sh

+26

@@ -0,0 +1,26 @@
#!/usr/bin/env bash

function download() {
    printf "\033[32mstart to download %s into %s/saves/%s\033[0m\n" "$2" "$current" "$1"
    # back up an existing directory before the new archive overwrites it
    if [[ -e $1 ]]; then
        backup=$1.$(date '+%Y%m%d%H%M%S')
        echo "backup current $1 directory first"
        cp -rf "$1" "$backup"
    fi
    [[ -e $1.zip ]] && rm -f "$1.zip"
    wget -c "https://doc-00-7o-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/jautjthvgoh3idbpvifflcpu1uo72846/1545379200000/09560182245773775633/*/1DFYbAta5mcMDtu1uW8f-8PFKsrHrTwJ0?e=download" -O "$1.zip"
    unzip -o "$1.zip"
    rm -f "$1.zip"
}

# We assume this will run under the FEAT folder.
current=$(pwd)
[[ ! -e saves ]] && mkdir -p saves
cd saves

# Download pre-trained weights into "./saves/initialization".
download "initialization" "pre-trained weights"
echo ""
# Download learned models into "./saves/FEAT-Models".
download "FEAT-Models" "learned models"

eval_feat.py

+71

@@ -0,0 +1,71 @@
import argparse
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from feat.models.feat import FEAT
from feat.dataloader.samplers import CategoriesSampler
from feat.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, compute_confidence_interval
from tensorboardX import SummaryWriter


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--model_type', type=str, default='ConvNet', choices=['ConvNet', 'ResNet'])
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['MiniImageNet', 'CUB'])
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--gpu', default='0')
    args = parser.parse_args()
    args.temperature = 1  # temperature does not influence the test results
    pprint(vars(args))

    set_gpu(args.gpu)
    if args.dataset == 'MiniImageNet':
        from feat.dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from feat.dataloader.cub import CUB as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    model = FEAT(args, dropout=0.5)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        model = model.cuda()

    test_set = Dataset('test', args)
    sampler = CategoriesSampler(test_set.label, 10000, args.way, args.shot + args.query)
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)
    test_acc_record = np.zeros((10000,))

    model.load_state_dict(torch.load(args.model_path)['params'])
    model.eval()

    ave_acc = Averager()
    label = torch.arange(args.way).repeat(args.query)
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)

    for i, batch in enumerate(loader, 1):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        k = args.way * args.shot
        data_shot, data_query = data[:k], data[k:]
        logits, _ = model(data_shot, data_query)
        acc = count_acc(logits, label)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc
        print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
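compute_confidence_interval presumably reports the mean accuracy and the half-width of a 95% confidence interval over the 10,000 test episodes (1.96 standard errors); the repo's exact implementation may differ. A numpy sketch:

```python
import numpy as np

# Sketch of a 95% confidence interval over per-episode accuracies:
# mean plus/minus z * std / sqrt(n), with z = 1.96.
def confidence_interval(acc, z=1.96):
    acc = np.asarray(acc)
    m = acc.mean()
    pm = z * acc.std() / np.sqrt(len(acc))
    return m, pm

m, pm = confidence_interval([0.5, 0.6, 0.4, 0.5])
print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
```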

eval_matchnet.py

+81

@@ -0,0 +1,81 @@
import argparse
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from feat.dataloader.samplers import CategoriesSampler
from feat.models.matchnet import MatchNet
from feat.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, compute_confidence_interval
from tensorboardX import SummaryWriter

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=15)
    # a store_true flag; type=bool would treat any non-empty string as True
    parser.add_argument('--use_bilstm', action='store_true')
    parser.add_argument('--model_type', type=str, default='ConvNet', choices=['ConvNet', 'ResNet'])
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['MiniImageNet', 'CUB'])
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--gpu', default='0')
    args = parser.parse_args()
    args.temperature = 1  # temperature does not influence the test results
    pprint(vars(args))

    set_gpu(args.gpu)

    if args.dataset == 'MiniImageNet':
        from feat.dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from feat.dataloader.cub import CUB as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    model = MatchNet(args)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        model = model.cuda()
    test_set = Dataset('test', args)
    sampler = CategoriesSampler(test_set.label, 10000, args.way, args.shot + args.query)
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)
    test_acc_record = np.zeros((10000,))

    model.load_state_dict(torch.load(args.model_path)['params'])
    model.eval()

    ave_acc = Averager()
    label = torch.arange(args.way).repeat(args.query)
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)

    label_support = torch.arange(args.way).repeat(args.shot)
    label_support = label_support.type(torch.LongTensor)
    # transform to one-hot form
    label_support_onehot = torch.zeros(args.way * args.shot, args.way)
    label_support_onehot.scatter_(1, label_support.unsqueeze(1), 1)
    if torch.cuda.is_available():
        label_support_onehot = label_support_onehot.cuda()  # KN x N

    for i, batch in enumerate(loader, 1):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        k = args.way * args.shot
        data_shot, data_query = data[:k], data[k:]
        logits = model(data_shot, data_query)  # KqN x KN x 1
        # use the logits to weight the one-hot support labels, giving KqN x N
        prediction = torch.sum(torch.mul(logits, label_support_onehot.unsqueeze(0)), 1)  # KqN x N
        acc = count_acc(prediction, label)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc
        print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
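The label-weighting step in the loop above can be illustrated on toy data: support-set logits (KqN x KN) are combined with one-hot support labels (KN x N) so each query gets one score per class. A numpy sketch with illustrative shapes:

```python
import numpy as np

# Toy sketch of the matching-network prediction step: scores over support
# examples are pooled into per-class scores via the one-hot support labels.
way, shot, query = 3, 2, 1
label_support = np.tile(np.arange(way), shot)        # KN support labels: [0,1,2,0,1,2]
onehot = np.eye(way)[label_support]                  # KN x N one-hot labels
logits = np.random.rand(way * query, way * shot)     # KqN x KN similarity scores
prediction = logits @ onehot                         # KqN x N class scores
print(prediction.shape)  # (3, 3)
```

The matrix product is equivalent to the elementwise multiply-and-sum used in the script.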

eval_protonet.py

+71

@@ -0,0 +1,71 @@
import argparse
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from feat.dataloader.samplers import CategoriesSampler
from feat.models.protonet import ProtoNet
from feat.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, compute_confidence_interval
from tensorboardX import SummaryWriter


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--model_type', type=str, default='ConvNet', choices=['ConvNet', 'ResNet'])
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['MiniImageNet', 'CUB'])
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--gpu', default='0')
    args = parser.parse_args()
    args.temperature = 1  # we set temperature = 1 during test since it does not influence the results
    pprint(vars(args))

    set_gpu(args.gpu)

    if args.dataset == 'MiniImageNet':
        from feat.dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from feat.dataloader.cub import CUB as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    model = ProtoNet(args)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        model = model.cuda()
    test_set = Dataset('test', args)
    sampler = CategoriesSampler(test_set.label, 10000, args.way, args.shot + args.query)
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)
    test_acc_record = np.zeros((10000,))

    model.load_state_dict(torch.load(args.model_path)['params'])
    model.eval()

    ave_acc = Averager()
    label = torch.arange(args.way).repeat(args.query)
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)

    for i, batch in enumerate(loader, 1):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        k = args.way * args.shot
        data_shot, data_query = data[:k], data[k:]

        logits = model(data_shot, data_query)
        acc = count_acc(logits, label)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc
        print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
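For reference, the decision rule a ProtoNet evaluates can be sketched on toy data: class prototypes are the means of the support embeddings, and logits are negative squared Euclidean distances from each query to each prototype. This is a generic sketch, not the repo's implementation:

```python
import numpy as np

# Toy sketch of prototypical-network classification in embedding space.
way, shot, dim = 5, 1, 4
support = np.random.randn(way * shot, dim)           # support embeddings, way-major order
proto = support.reshape(shot, way, dim).mean(axis=0) # way x dim class prototypes
query = np.random.randn(3, dim)                      # 3 query embeddings

# negative squared Euclidean distance to each prototype
logits = -((query[:, None, :] - proto[None]) ** 2).sum(-1)   # n_query x way
pred = logits.argmax(axis=1)
print(pred.shape)  # (3,)
```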
