Skip to content

Commit

Permalink
fix criterion name check when resuming from checkpoint
Browse files Browse the repository at this point in the history
Summary:
I tried resuming a run from a checkpoint in f250883864, but ran into:

AssertionError: Criterion does not match; please reset the optimizer (--reset-optimizer). DistributedTimeoutWrapper vs ContrastiveLabelsCriterion

Based on this, I believe since D25836853 (facebookresearch@d68a353) we are no longer saving the actual criterion's name, but DistributedTimeoutWrapper in the checkpoint.

This is kind of weird though, as I would expect more people to run into this issue. Not sure if I am doing something wrong, let me know if so, thanks!

Reviewed By: myleott

Differential Revision: D26478656

fbshipit-source-id: bc3c7c925f5505140d9df4438af3a73d65d4f531
  • Loading branch information
Alex Xiao authored and facebook-github-bot committed Feb 19, 2021
1 parent 3ef1888 commit c6b5c00
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions fairseq/trainer.py
Original file line number Diff line number Diff line change
@@ -284,7 +284,7 @@ def save_checkpoint(self, filename, extra_state):
filename,
self.cfg,
self.model.state_dict(),
self.criterion,
self.get_criterion(),
self.optimizer,
self.lr_scheduler,
self.get_num_updates(),
@@ -375,10 +375,10 @@ def load_checkpoint(
last_optim = self._optim_history[-1]
assert (
last_optim["criterion_name"] == self.get_criterion().__class__.__name__
), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
assert (
last_optim["optimizer_name"] == self.optimizer.__class__.__name__
), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."
), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"

if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])

0 comments on commit c6b5c00

Please sign in to comment.