
Commit 7a16a1f

add pretrain code
1 parent 1c1e3c2 commit 7a16a1f

File tree

4 files changed: +334 -4 lines

feat/dataloader/mini_imagenet_pre.py

+90

import os.path as osp
import PIL
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np

# used for miniImageNet pre-training
THIS_PATH = osp.dirname(__file__)
ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..', '..'))
IMAGE_PATH = osp.join(ROOT_PATH, 'data/miniimagenet/images')
SPLIT_PATH = osp.join(ROOT_PATH, 'data/miniimagenet/split')


class MiniImageNet(Dataset):

    def __init__(self, setname, args):
        # each CSV row is "filename,wnid"; skip the header line
        csv_path = osp.join(SPLIT_PATH, setname + '.csv')
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]

        data = []
        label = []
        lb = -1

        self.wnids = []

        for line in lines:
            name, wnid = line.split(',')
            path = osp.join(IMAGE_PATH, name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            data.append(path)
            label.append(lb)

        self.data = data
        self.label = label
        self.num_class = len(set(label))

        if args.model_type == 'ConvNet':  # matches the --model_type choices in pretrain.py ('conv' would never trigger)
            image_size = 84
            if setname == 'train':
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(image_size),
                    # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    # Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']),
                    transforms.Normalize(np.array([0.485, 0.456, 0.406]),
                                         np.array([0.229, 0.224, 0.225])),
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.Resize(92),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize(np.array([0.485, 0.456, 0.406]),
                                         np.array([0.229, 0.224, 0.225]))
                ])
        else:
            # for ResNet
            image_size = 80
            mean = [x / 255 for x in [125.3, 123.0, 113.9]]
            std = [x / 255 for x in [63.0, 62.1, 66.7]]
            if setname == 'train':
                self.transform = transforms.Compose([
                    # transforms.Resize(92, interpolation=PIL.Image.BICUBIC),
                    transforms.RandomResizedCrop(image_size),
                    # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    # Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']),
                    transforms.Normalize(mean, std)])
            else:
                self.transform = transforms.Compose([
                    transforms.Resize(92),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label
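
For context, a minimal usage sketch of this dataset; the args namespace below is a hypothetical stand-in for the parsed command-line arguments that pretrain.py (further down in this commit) passes in:

# hypothetical usage; only args.model_type is read by the dataset constructor
from types import SimpleNamespace
from torch.utils.data import DataLoader

args = SimpleNamespace(model_type='ConvNet')
trainset = MiniImageNet('train', args)
loader = DataLoader(trainset, batch_size=128, shuffle=True)
images, labels = next(iter(loader))  # images: [128, 3, 84, 84], labels: [128]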

feat/models/classifier.py

+39

import torch
import torch.nn as nn
import numpy as np
from feat.utils import euclidean_metric
import torch.nn.functional as F


class Classifier(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.args = args
        if args.model_type == 'ConvNet':
            hdim = 64
            from feat.networks.convnet import ConvNet
            self.encoder = ConvNet()
        elif args.model_type == 'ResNet':
            hdim = 640
            from feat.networks.resnet import ResNet
            self.encoder = ResNet()
        else:
            raise ValueError('Unsupported model_type: {}'.format(args.model_type))

        self.fc = nn.Linear(hdim, args.num_class)

    def forward(self, data, is_emb=False):
        out = self.encoder(data)
        if not is_emb:
            out = self.fc(out)
        return out

    def forward_proto(self, data_shot, data_query, way=None):
        if way is None:
            way = self.args.num_class
        # class prototypes: mean of the shot embeddings per class
        proto = self.encoder(data_shot)
        proto = proto.reshape(self.args.shot, way, -1).mean(dim=0)

        query = self.encoder(data_query)
        logits = euclidean_metric(query, proto)
        return logits
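
euclidean_metric comes from feat/utils.py and is not part of this diff; a plausible sketch, assuming it scores each query against each prototype by negative squared Euclidean distance (the standard prototypical-network logit):

# assumed sketch of feat.utils.euclidean_metric (not in this diff):
# logits[i, j] = -||query_i - proto_j||^2, so closer prototypes score higher
def euclidean_metric(a, b):
    n, m = a.shape[0], b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)  # [n, m, d]
    b = b.unsqueeze(0).expand(n, m, -1)  # [n, m, d]
    return -((a - b) ** 2).sum(dim=2)    # [n, m]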

feat/utils.py

+5 -4

@@ -10,11 +10,12 @@ def set_gpu(x):
     print('using gpu:', x)


-def ensure_path(path):
+def ensure_path(path, remove=True):
     if os.path.exists(path):
-        if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
-            shutil.rmtree(path)
-            os.mkdir(path)
+        if remove:
+            if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
+                shutil.rmtree(path)
+                os.mkdir(path)
     else:
         os.mkdir(path)
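
The new remove flag lets pretrain.py create the shared parent directory without prompting while keeping the interactive remove-and-recreate behavior for the run-specific subdirectory; the paths below mirror how pretrain.py composes its save path from the default arguments:

# mirrors the two ensure_path calls in pretrain.py below
ensure_path('MiniImageNet-ResNet-Pre', remove=False)   # shared across runs, never prompt
ensure_path('MiniImageNet-ResNet-Pre/0.001_0.1')       # per-run dir, prompt if it exists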

pretrain.py

+200

import argparse
import os.path as osp
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from feat.models.classifier import Classifier
from feat.dataloader.samplers import CategoriesSampler
from feat.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric
from tensorboardX import SummaryWriter
from tqdm import tqdm

# pre-train the backbone with standard classification
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['MiniImageNet', 'CUB', 'TieredImagenet'])
    parser.add_argument('--model_type', type=str, default='ResNet', choices=['ConvNet', 'ResNet'])
    parser.add_argument('--schedule', type=int, nargs='+', default=[30, 50, 80], help='Decrease learning rate at these epochs.')
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--resume', action='store_true')  # type=bool would treat any non-empty string as True
    args = parser.parse_args()
    pprint(vars(args))

    save_path1 = '-'.join([args.dataset, args.model_type, 'Pre'])
    save_path2 = '_'.join([str(args.lr), str(args.gamma)])
    args.save_path = osp.join(save_path1, save_path2)
    ensure_path(save_path1, remove=False)
    ensure_path(args.save_path)

    if args.dataset == 'MiniImageNet':
        from feat.dataloader.mini_imagenet_pre import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from feat.dataloader.cub import CUB as Dataset
    elif args.dataset == 'TieredImagenet':
        from feat.dataloader.tiered_imagenet import tieredImageNet as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    trainset = Dataset('train', args)
    train_loader = DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
    args.num_class = trainset.num_class
    valset = Dataset('val', args)
    val_sampler = CategoriesSampler(valset.label, 200, valset.num_class, 1 + 15)  # 200 episodes of 16-way 1-shot with 15 queries
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, num_workers=8, pin_memory=True)
    args.way = valset.num_class
    args.shot = 1

    # construct model
    model = Classifier(args)
    if args.model_type == 'ConvNet':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
    elif args.model_type == 'ResNet':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0005)
    else:
        raise ValueError('No Such Encoder')
    criterion = torch.nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        if args.ngpu > 1:
            model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))

        model = model.cuda()
        criterion = criterion.cuda()

    def save_model(name):
        torch.save(dict(params=model.state_dict()), osp.join(args.save_path, name + '.pth'))

    def save_checkpoint(is_best, filename='checkpoint.pth.tar'):
        state = {'epoch': epoch + 1,
                 'args': args,
                 'state_dict': model.state_dict(),
                 'trlog': trlog,
                 'val_acc': trlog['max_acc'],
                 'optimizer': optimizer.state_dict(),
                 'global_count': global_count}

        torch.save(state, osp.join(args.save_path, filename))
        if is_best:
            shutil.copyfile(osp.join(args.save_path, filename), osp.join(args.save_path, 'model_best.pth.tar'))

    if args.resume:
        # load the best checkpoint and continue from its epoch
        state = torch.load(osp.join(args.save_path, 'model_best.pth.tar'))
        init_epoch = state['epoch']
        resumed_state = state['state_dict']
        # resumed_state = {'module.'+k: v for k, v in resumed_state.items()}
        model.load_state_dict(resumed_state)
        trlog = state['trlog']
        optimizer.load_state_dict(state['optimizer'])
        initial_lr = optimizer.param_groups[0]['lr']
        global_count = state['global_count']
    else:
        init_epoch = 1
        trlog = {}
        trlog['args'] = vars(args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0
        initial_lr = args.lr
        global_count = 0

    timer = Timer()
    writer = SummaryWriter(logdir=args.save_path)  # older tensorboardX versions use log_dir instead of logdir
    for epoch in range(init_epoch, args.max_epoch + 1):
        # step-wise learning-rate decay at the scheduled epochs
        if epoch in args.schedule:
            initial_lr *= args.gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = initial_lr

        model.train()
        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, label = [_.cuda() for _ in batch]
                label = label.type(torch.cuda.LongTensor)
            else:
                data, label = batch
                label = label.type(torch.LongTensor)
            logits = model(data)
            loss = criterion(logits, label)
            acc = count_acc(logits, label)
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tl = tl.item()
        ta = ta.item()

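        # Note: each val episode from CategoriesSampler (not part of this diff) is
        # assumed shot-major: data[:num_class] is one support image per class and
        # data[num_class:] cycles through the classes 15 times with query images,
        # which is why the query labels below are torch.arange(num_class).repeat(15).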
        # validate every epoch once past epoch 30; before that, only every 5th epoch
        if epoch > 30 or epoch % 5 == 0:
            model.eval()
            vl = Averager()
            va = Averager()
            print('best epoch {}, current best val acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
            # measure few-shot performance with prototype-based classification
            label = torch.arange(valset.num_class).repeat(15)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
            with torch.no_grad():
                for i, batch in tqdm(enumerate(val_loader, 1)):
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data, _ = batch
                    data_shot, data_query = data[:valset.num_class], data[valset.num_class:]  # 16-way test
                    if args.ngpu > 1:
                        logits = model.module.forward_proto(data_shot, data_query, valset.num_class)
                    else:
                        logits = model.forward_proto(data_shot, data_query, valset.num_class)
                    loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    vl.add(loss.item())
                    va.add(acc)

            vl = vl.item()
            va = va.item()
            writer.add_scalar('data/val_loss', float(vl), epoch)
            writer.add_scalar('data/val_acc', float(va), epoch)
            print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

            if va > trlog['max_acc']:
                trlog['max_acc'] = va
                trlog['max_acc_epoch'] = epoch
                save_model('max_acc')
                save_checkpoint(True)

            trlog['train_loss'].append(tl)
            trlog['train_acc'].append(ta)
            trlog['val_loss'].append(vl)
            trlog['val_acc'].append(va)
            save_model('epoch-last')

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    writer.close()
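
A hypothetical launch sequence, assuming the images and split CSVs sit under data/miniimagenet/ as the dataloader expects; flag values are illustrative:

# fresh pre-training run, then resuming from the best checkpoint it saved
python pretrain.py --dataset MiniImageNet --model_type ResNet --batch_size 128
python pretrain.py --dataset MiniImageNet --model_type ResNet --resume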
