import argparse
import os.path as osp
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from feat.models.classifier import Classifier
from feat.dataloader.samplers import CategoriesSampler
from feat.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric
from tensorboardX import SummaryWriter
from tqdm import tqdm

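# Example invocation (a sketch; assumes this file is saved as pretrain.py,
# flags match the argparse defaults below):
#   python pretrain.py --dataset MiniImageNet --model_type ResNet --lr 0.001 --ngpu 1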
# pre-train backbone
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['MiniImageNet', 'CUB', 'TieredImagenet'])
    parser.add_argument('--model_type', type=str, default='ResNet', choices=['ConvNet', 'ResNet'])
    parser.add_argument('--schedule', type=int, nargs='+', default=[30, 50, 80], help='Decrease learning rate at these epochs.')
    parser.add_argument('--gamma', type=float, default=0.1)
    # type=bool would treat any non-empty string (even 'False') as True, so use a flag instead
    parser.add_argument('--resume', action='store_true')
    args = parser.parse_args()
    pprint(vars(args))

    save_path1 = '-'.join([args.dataset, args.model_type, 'Pre'])
    save_path2 = '_'.join([str(args.lr), str(args.gamma)])
    args.save_path = osp.join(save_path1, save_path2)
    ensure_path(save_path1, remove=False)
    ensure_path(args.save_path)
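    # with the defaults above this resolves to 'MiniImageNet-ResNet-Pre/0.001_0.1'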

    if args.dataset == 'MiniImageNet':
        # MiniImageNet uses a dedicated pre-training loader
        from feat.dataloader.mini_imagenet_pre import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from feat.dataloader.cub import CUB as Dataset
    elif args.dataset == 'TieredImagenet':
        from feat.dataloader.tiered_imagenet import tieredImageNet as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    trainset = Dataset('train', args)
    train_loader = DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
    args.num_class = trainset.num_class
    valset = Dataset('val', args)
    # 200 validation episodes: num_class-way, 1 shot + 15 queries per class
    val_sampler = CategoriesSampler(valset.label, 200, valset.num_class, 1 + 15)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, num_workers=8, pin_memory=True)
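    # assumption: CategoriesSampler yields each episode shot-first, i.e. the first
    # num_class images are the 1-shot support set and the rest are queries
    # (this matches the data_shot/data_query split in the evaluation loop below)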
    args.way = valset.num_class
    args.shot = 1

    # construct model
    model = Classifier(args)
    if args.model_type == 'ConvNet':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
    elif args.model_type == 'ResNet':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0005)
    else:
        raise ValueError('No Such Encoder')
    criterion = torch.nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        if args.ngpu > 1:
            model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))

        model = model.cuda()
        criterion = criterion.cuda()

    def save_model(name):
        torch.save(dict(params=model.state_dict()), osp.join(args.save_path, name + '.pth'))

    def save_checkpoint(is_best, filename='checkpoint.pth.tar'):
        state = {'epoch': epoch + 1,
                 'args': args,
                 'state_dict': model.state_dict(),
                 'trlog': trlog,
                 'val_acc': trlog['max_acc'],
                 'optimizer': optimizer.state_dict(),
                 'global_count': global_count}
        torch.save(state, osp.join(args.save_path, filename))
        if is_best:
            shutil.copyfile(osp.join(args.save_path, filename), osp.join(args.save_path, 'model_best.pth.tar'))

    if args.resume:
        # load checkpoint
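        # note: this resumes from the best validation checkpoint (model_best.pth.tar),
        # not from the most recent epoch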
        state = torch.load(osp.join(args.save_path, 'model_best.pth.tar'))
        init_epoch = state['epoch']
        resumed_state = state['state_dict']
        # resumed_state = {'module.'+k: v for k, v in resumed_state.items()}
        model.load_state_dict(resumed_state)
        trlog = state['trlog']
        optimizer.load_state_dict(state['optimizer'])
        initial_lr = optimizer.param_groups[0]['lr']
        global_count = state['global_count']
    else:
        init_epoch = 1
        trlog = {}
        trlog['args'] = vars(args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0
        initial_lr = args.lr
        global_count = 0  # total training iterations, used as the TensorBoard step

    timer = Timer()
    writer = SummaryWriter(logdir=args.save_path)  # older tensorboardX versions use log_dir instead of logdir
    for epoch in range(init_epoch, args.max_epoch + 1):
        # step-decay the learning rate at the scheduled epochs
        if epoch in args.schedule:
            initial_lr *= args.gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = initial_lr

        model.train()
        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, label = [_.cuda() for _ in batch]
                label = label.type(torch.cuda.LongTensor)
            else:
                data, label = batch
                label = label.type(torch.LongTensor)
            logits = model(data)
            loss = criterion(logits, label)
            acc = count_acc(logits, label)
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tl = tl.item()
        ta = ta.item()

        # validate every epoch once past epoch 30; before that, only every 5 epochs
        if epoch > 30 or epoch % 5 == 0:
            model.eval()
            vl = Averager()
            va = Averager()
            print('best epoch {}, current best val acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
            # test performance with the few-shot protocol
            label = torch.arange(valset.num_class).repeat(15)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
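            # with the episode layout assumed above, query i belongs to class
            # i % num_class, so these labels line up with data_query below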
            with torch.no_grad():
                for i, batch in enumerate(tqdm(val_loader), 1):
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data, _ = batch
                    data_shot, data_query = data[:valset.num_class], data[valset.num_class:]  # num_class-way split
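                    # forward_proto is assumed to average the support embeddings into one
                    # prototype per class and score each query against the prototypes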
                    if args.ngpu > 1:
                        # DataParallel wraps the model, so reach the method via .module
                        logits = model.module.forward_proto(data_shot, data_query, valset.num_class)
                    else:
                        logits = model.forward_proto(data_shot, data_query, valset.num_class)
                    loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    vl.add(loss.item())
                    va.add(acc)

            vl = vl.item()
            va = va.item()
            writer.add_scalar('data/val_loss', float(vl), epoch)
            writer.add_scalar('data/val_acc', float(va), epoch)
            print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

            if va > trlog['max_acc']:
                trlog['max_acc'] = va
                trlog['max_acc_epoch'] = epoch
                save_model('max_acc')
                save_checkpoint(True)

            trlog['train_loss'].append(tl)
            trlog['train_acc'].append(ta)
            trlog['val_loss'].append(vl)
            trlog['val_acc'].append(va)
            save_model('epoch-last')

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    writer.close()