
Commit 05872e4: update model selection

xwen99 committed Jun 7, 2023
1 parent 5d1b2c9

Showing 3 changed files with 65 additions and 62 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -43,8 +43,10 @@ We also use generic object recognition datasets, including:
bash scripts/run_${DATASET_NAME}.sh
```

Then check the results starting with `Metrics with best model on test set:` in the logs.
This means the model is picked according to its performance on the test set, and then evaluated on the unlabelled instances of the train set.
~~Then check the results starting with `Metrics with best model on test set:` in the logs.
This means the model is picked according to its performance on the test set, and then evaluated on the unlabelled instances of the train set.~~

We found that picking the model according to 'Old' class performance can lead to over-fitting, and since 'New' class labels on the held-out validation set should be assumed unavailable, we suggest skipping model selection altogether and simply using the last-epoch model.

## Results
Our results in three independent runs:
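For readers following the README change above, here is a minimal sketch of the evaluation flow it implies: train to the final epoch, then evaluate the saved last-epoch checkpoint on the unlabelled instances of the train set. The names (`model`, `test`, `test_loader_unlabelled`, `args`) are the objects constructed in `train.py` below; the snippet is illustrative and not part of the commit.

```python
import torch

# Assumes `model`, `test`, `test_loader_unlabelled`, and `args` are the objects
# built in train.py, and that the checkpoint matches the save_dict written there
# (a dict containing the key 'model').
checkpoint = torch.load(args.model_path, map_location='cpu')  # last-epoch weights
model.load_state_dict(checkpoint['model'])
model.eval()

# Report transductive accuracy on the unlabelled train split; no model selection.
all_acc, old_acc, new_acc = test(model, test_loader_unlabelled,
                                 epoch=args.epochs - 1,
                                 save_name='Train ACC Unlabelled', args=args)
print(f'Last-epoch accuracies: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}')
```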
59 changes: 30 additions & 29 deletions train.py
@@ -39,12 +39,12 @@ def train(student, train_loader, test_loader, unlabelled_train_loader, args):
args.teacher_temp,
)

# inductive
best_test_acc_lab = 0
# transductive
best_train_acc_lab = 0
best_train_acc_ubl = 0
best_train_acc_all = 0
# # inductive
# best_test_acc_lab = 0
# # transductive
# best_train_acc_lab = 0
# best_train_acc_ubl = 0
# best_train_acc_all = 0

for epoch in range(args.epochs):
loss_record = AverageMeter()
@@ -111,12 +111,12 @@ def train(student, train_loader, test_loader, unlabelled_train_loader, args):

args.logger.info('Testing on unlabelled examples in the training data...')
all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
args.logger.info('Testing on disjoint test set...')
all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args)
# args.logger.info('Testing on disjoint test set...')
# all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args)


args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
# args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))

# Step schedule
exp_lr_scheduler.step()
@@ -130,23 +130,23 @@ def train(student, train_loader, test_loader, unlabelled_train_loader, args):
torch.save(save_dict, args.model_path)
args.logger.info("model saved to {}.".format(args.model_path))

if old_acc_test > best_test_acc_lab:

args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))

torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))

# inductive
best_test_acc_lab = old_acc_test
# transductive
best_train_acc_lab = old_acc
best_train_acc_ubl = new_acc
best_train_acc_all = all_acc

args.logger.info(f'Exp Name: {args.exp_name}')
args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
# if old_acc_test > best_test_acc_lab:
#
# args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
# args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
#
# torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
# args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
#
# # inductive
# best_test_acc_lab = old_acc_test
# # transductive
# best_train_acc_lab = old_acc
# best_train_acc_ubl = new_acc
# best_train_acc_all = all_acc
#
# args.logger.info(f'Exp Name: {args.exp_name}')
# args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')


def test(model, test_loader, epoch, save_name, args):
@@ -283,8 +283,8 @@ def test(model, test_loader, epoch, save_name, args):
sampler=sampler, drop_last=True, pin_memory=True)
test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
batch_size=256, shuffle=False, pin_memory=False)
test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
batch_size=256, shuffle=False, pin_memory=False)
# test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
# batch_size=256, shuffle=False, pin_memory=False)

# ----------------------
# PROJECTION HEAD
@@ -295,4 +295,5 @@ def test(model, test_loader, epoch, save_name, args):
# ----------------------
# TRAIN
# ----------------------
train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args)
# train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args)
train(model, train_loader, None, test_loader_unlabelled, args)
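Since `train()` in `train.py` is now called with `None` in place of the labelled test loader, the `Metrics with best model on test set:` line no longer appears in the logs; the number to report is the final `Train Accuracies` line. Below is a small helper to pull it out of a log file, assuming the plain-text format produced by `args.logger` above; the log path in the usage comment is hypothetical.

```python
import re

def last_epoch_accuracies(log_path: str):
    """Return (all, old, new) from the final 'Train Accuracies' line of a log."""
    pattern = re.compile(
        r'Train Accuracies: All (\d+\.\d+) \| Old (\d+\.\d+) \| New (\d+\.\d+)')
    last = None
    with open(log_path) as f:
        for line in f:
            match = pattern.search(line)
            if match:
                last = tuple(float(value) for value in match.groups())
    return last

# Usage (path is illustrative):
# print(last_epoch_accuracies('outputs/my_experiment/log.txt'))
```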
62 changes: 31 additions & 31 deletions train_mp.py
@@ -125,16 +125,16 @@ def main(args):
sample_weights = torch.DoubleTensor(sample_weights)
train_sampler = DistributedWeightedSampler(train_dataset, sample_weights, num_samples=len(train_dataset))
unlabelled_train_sampler = torch.utils.data.distributed.DistributedSampler(unlabelled_train_examples_test)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
# test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
# --------------------
# DATALOADERS
# --------------------
train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
sampler=train_sampler, drop_last=True, pin_memory=True)
unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, batch_size=256,
shuffle=False, sampler=unlabelled_train_sampler, pin_memory=False)
test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=256,
shuffle=False, sampler=test_sampler, pin_memory=False)
# test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=256,
# shuffle=False, sampler=test_sampler, pin_memory=False)

# ----------------------
# PROJECTION HEAD
@@ -169,29 +169,29 @@ def main(args):
args.teacher_temp,
)

# inductive
best_test_acc_lab = 0
# transductive
best_train_acc_lab = 0
best_train_acc_ubl = 0
best_train_acc_all = 0
# # inductive
# best_test_acc_lab = 0
# # transductive
# best_train_acc_lab = 0
# best_train_acc_ubl = 0
# best_train_acc_all = 0

for epoch in range(args.epochs):
train_sampler.set_epoch(epoch)
train(model, train_loader, optimizer, fp16_scaler, exp_lr_scheduler, cluster_criterion, epoch, args)

unlabelled_train_sampler.set_epoch(epoch)
test_sampler.set_epoch(epoch)
# test_sampler.set_epoch(epoch)
if dist.get_rank() == 0:
args.logger.info('Testing on unlabelled examples in the training data...')
all_acc, old_acc, new_acc = test(model, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
if dist.get_rank() == 0:
args.logger.info('Testing on disjoint test set...')
all_acc_test, old_acc_test, new_acc_test = test(model, test_loader, epoch=epoch, save_name='Test ACC', args=args)
# if dist.get_rank() == 0:
# args.logger.info('Testing on disjoint test set...')
# all_acc_test, old_acc_test, new_acc_test = test(model, test_loader, epoch=epoch, save_name='Test ACC', args=args)

if dist.get_rank() == 0:
args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
# args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))

save_dict = {
'model': model.state_dict(),
@@ -202,23 +202,23 @@ def main(args):
torch.save(save_dict, args.model_path)
args.logger.info("model saved to {}.".format(args.model_path))

if old_acc_test > best_test_acc_lab and dist.get_rank() == 0:
args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))

torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))

# inductive
best_test_acc_lab = old_acc_test
# transductive
best_train_acc_lab = old_acc
best_train_acc_ubl = new_acc
best_train_acc_all = all_acc

if dist.get_rank() == 0:
args.logger.info(f'Exp Name: {args.exp_name}')
args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
# if old_acc_test > best_test_acc_lab and dist.get_rank() == 0:
# args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
# args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
#
# torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
# args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
#
# # inductive
# best_test_acc_lab = old_acc_test
# # transductive
# best_train_acc_lab = old_acc
# best_train_acc_ubl = new_acc
# best_train_acc_all = all_acc
#
# if dist.get_rank() == 0:
# args.logger.info(f'Exp Name: {args.exp_name}')
# args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')


def train(student, train_loader, optimizer, scaler, scheduler, cluster_criterion, epoch, args):
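Putting the `train_mp.py` hunks together, the per-epoch loop that remains after this commit reduces to roughly the sketch below: reshuffle the distributed samplers, train, evaluate on the unlabelled train split, and log on rank 0 only. It is simplified (the exact placement of the rank checks and the checkpoint-saving code follow the original file) and is not a drop-in replacement.

```python
import torch.distributed as dist

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)              # re-seed the weighted train sampler
    train(model, train_loader, optimizer, fp16_scaler, exp_lr_scheduler,
          cluster_criterion, epoch, args)

    unlabelled_train_sampler.set_epoch(epoch)   # keep ranks in sync for evaluation
    all_acc, old_acc, new_acc = test(model, unlabelled_train_loader, epoch=epoch,
                                     save_name='Train ACC Unlabelled', args=args)

    if dist.get_rank() == 0:
        args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'
                         .format(all_acc, old_acc, new_acc))
        # Checkpoint saving (overwritten every epoch) proceeds as in the original
        # file, so args.model_path always holds the last-epoch model.
```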
