Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
cw1204772 committed Apr 4, 2018
2 parents 31d2d73 + 2a58fff commit a1a551a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
64 changes: 62 additions & 2 deletions ReID/ReID_CNN/train_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch.autograd import Variable
import torch.optim as optim
import torch.backends.cudnn as cudnn
from utils import TripletImage_Dataset, sv_comp_Dataset, BoxCars_Dataset, Unsupervised_TripletImage_Dataset, Get_train_DataLoader, Get_val_DataLoader
from utils import TripletImage_Dataset, sv_comp_Dataset, BoxCars_Dataset, Unsupervised_TripletImage_Dataset, Get_train_DataLoader, Get_val_DataLoader,Get_test_DataLoader
from loss import TripletLoss
from logger import Logger
import models
Expand All @@ -19,6 +19,8 @@ def train_joint(args, train_veri_dataloader,
train_compcars_dataloader, val_compcars_dataloader,
train_boxcars_dataloader, val_boxcars_dataloader,
train_aic_dataloader,
test_compcars_dataloader,
test_boxcars_dataloader,
base_net, veri_id_net, color_net, compcars_model_net, boxcars_model_net):

optimizer_base = optim.Adam(base_net.parameters(), lr=args.lr)
Expand All @@ -30,6 +32,7 @@ def train_joint(args, train_veri_dataloader,
criterion_ce = nn.CrossEntropyLoss()
logger = Logger(os.path.join(args.save_model_dir,'train'))
val_logger = Logger(os.path.join(args.save_model_dir,'val'))
test_logger = Logger(os.path.join(args.save_model_dir,'test'))

epoch_size = min(len(train_veri_dataloader), len(train_compcars_dataloader), len(train_boxcars_dataloader), len(train_aic_dataloader))
for e in range(args.n_epochs):
Expand Down Expand Up @@ -215,20 +218,67 @@ def train_joint(args, train_veri_dataloader,
acc_model = torch.cat(correct_model).float().mean()
print('BoxCars model val acc: %.3f' % acc_model)
val_logger.logg({'boxcars_model_acc':acc_model})
val_logger.write_log()

if e%25 == 0:
print('start testing')
test_logger.append_epoch(e)
pbar = tqdm(total=len(test_boxcars_dataloader),ncols=100,leave=True)
pbar.set_description('Test BoxCar')
correct_model = []
for i,sample in enumerate(test_boxcars_dataloader):
img = Variable(sample['img'],volatile=True).cuda()
gt_model = Variable(sample['model'],volatile=True).cuda()
pred_feat = base_net(img)
pred_model =boxcars_model_net(pred_feat)
_, pred_model = torch.max(pred_model,dim=1)
correct_model.append(pred_model.data == gt_model.data)
pbar.update(1)
pbar.close()
acc_model = torch.cat(correct_model).float().mean()
print('BoxCars model val acc: %.3f' % acc_model)
test_logger.logg({'boxcars_model_acc':acc_model})

pbar = tqdm(total=len(test_compcars_dataloader),ncols=100,leave=True)
pbar.set_description('Test CompCar_sv')
correct_model = []
correct_color = []
for i,sample in enumerate(test_compcars_dataloader):
img = Variable(sample['img'],volatile=True).cuda()
gt_model = Variable(sample['model'],volatile=True).cuda()
gt_color = Variable(sample['color'],volatile=True).cuda()
pred_feat = base_net(img)
pred_model = compcars_model_net(pred_feat)
pred_color = color_net(pred_feat)
_, pred_model = torch.max(pred_model,dim=1)
_, pred_color = torch.max(pred_color,dim=1)
correct_model.append(pred_model.data == gt_model.data)
correct_color.append(pred_color.data == gt_color.data)
pbar.update(1)
pbar.close()
acc_model = torch.cat(correct_model).float().mean()
acc_color = torch.cat(correct_color).float().mean()
print('CompCars model val acc: %.3f' % acc_model)
print('CompCars color val acc: %.3f' % acc_color)
test_logger.logg({'compcars_model_acc':acc_model})
test_logger.logg({'compcars_color_acc':acc_color})
test_logger.write_log()

base_net.train()
veri_id_net.train()
color_net.train()
compcars_model_net.train()
boxcars_model_net.train()
val_logger.write_log()


if __name__ == '__main__':
## Parse arg
parser = argparse.ArgumentParser(description='Train Re-ID net', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--veri_txt', required=True, help='txt for VeRi dataset')
parser.add_argument('--compcars_txt', required=True, help='txt for CompCars sv dataset')
parser.add_argument('--compcars_test_txt', required=True, help='txt for CompCars sv dataset')
parser.add_argument('--boxcars_txt', required=True, help='txt for BoxCars116k dataset')
parser.add_argument('--boxcars_test_txt', required=True, help='txt for BoxCars116k dataset')
parser.add_argument('--aic_pkl', required=True, help='pkl for AIC dataset')
parser.add_argument('--crop',type=bool,default=True,help='Whether crop the images')
parser.add_argument('--flip',type=bool,default=True,help='Whether randomly flip the image')
Expand Down Expand Up @@ -281,6 +331,14 @@ def train_joint(args, train_veri_dataloader,
imagenet_normalize=args.pretrain, batch_size=args.batch_size)
train_aic_dataloader = Get_train_DataLoader(aic_dataset, batch_size=1)

# Test Dataset & loader

compcars_dataset_test = sv_comp_Dataset(args.compcars_test_txt, crop=False, flip=False, jitter=False,imagenet_normalize=True)
test_compcars_dataloader = Get_test_DataLoader(compcars_dataset_test, batch_size=args.batch_size)

boxcars_dataset = BoxCars_Dataset(args.boxcars_test_txt, crop=False, flip=False, jitter=False, imagenet_normalize=True)
test_boxcars_dataloader = Get_test_DataLoader(boxcars_dataset, batch_size=args.batch_size)

# Get Model
base_net = models.FeatureResNet(n_layers=args.n_layer, pretrained=args.pretrain)
veri_id_net = models.NLayersFC(base_net.output_dim, veri_dataset.n_id)
Expand All @@ -303,5 +361,7 @@ def train_joint(args, train_veri_dataloader,
train_compcars_dataloader, val_compcars_dataloader, \
train_boxcars_dataloader, val_boxcars_dataloader, \
train_aic_dataloader, \
test_compcars_dataloader,\
test_boxcars_dataloader,\
base_net, veri_id_net, color_net, compcars_model_net, boxcars_model_net)

5 changes: 4 additions & 1 deletion ReID/ReID_CNN/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,14 @@ def __getitem__(self, idx):
def __len__(self):
return self.len

def Get_test_DataLoader(dataset,batch_size=128,shuffle=False,num_workers=6):
return DataLoader(dataset,batch_size = batch_size,shuffle=shuffle,num_workers=num_workers)

def Get_train_DataLoader(dataset,batch_size=128,shuffle=True,num_workers=6):
sampler = SubsetRandomSampler(dataset.train_index)
return DataLoader(dataset,batch_size = batch_size,sampler=sampler,num_workers=num_workers)

def Get_val_DataLoader(dataset,batch_size=128,shuffle=True,num_workers=6):
def Get_val_DataLoader(dataset,batch_size=128,shuffle=False,num_workers=6):
sampler = SubsetRandomSampler(dataset.val_index)
return DataLoader(dataset,batch_size = batch_size,sampler=sampler,num_workers=num_workers)

Expand Down

0 comments on commit a1a551a

Please sign in to comment.