
Commit 61818cf
hotfix: Set resuming optimizer parameters with a workaround.
kylegao91 committed Aug 31, 2017
1 parent 1b24f35 commit 61818cf
Showing 3 changed files with 10 additions and 5 deletions.
7 changes: 7 additions & 0 deletions seq2seq/trainer/supervised_trainer.py
@@ -164,6 +164,13 @@ def train(self, model, data, num_epochs=5,
             resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
             model = resume_checkpoint.model
             self.optimizer = resume_checkpoint.optimizer
+
+            # A workaround to set the optimizer parameters properly
+            resume_optim = self.optimizer.optimizer
+            defaults = resume_optim.param_groups[0]
+            defaults.pop('params', None)
+            self.optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)
+
            start_epoch = resume_checkpoint.epoch
            step = resume_checkpoint.step
        else:
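
Why the workaround is needed: the unpickled optimizer still holds references to the parameter tensors of the model object that was alive when the checkpoint was saved, so its updates would never touch the freshly loaded model. Re-instantiating the optimizer class with the saved hyperparameters rebinds it to model.parameters(). A minimal sketch of the same idea, assuming a plain torch.optim optimizer (the names below are illustrative, not from the repository):

import torch

# Illustrative model and optimizer standing in for the checkpointed ones.
model = torch.nn.Linear(4, 2)
resume_optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Copy the hyperparameters out of the restored optimizer, drop the stale
# parameter references, and rebuild the same optimizer class around the
# parameters of the newly loaded model.
defaults = dict(resume_optim.param_groups[0])
defaults.pop('params', None)
rebound = resume_optim.__class__(model.parameters(), **defaults)

assert rebound.param_groups[0]['lr'] == 0.1

Note that the committed code pops 'params' from the live param group rather than from a copy; copying the dict first, as above, avoids mutating the restored optimizer in place.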
4 changes: 1 addition & 3 deletions seq2seq/util/checkpoint.py
@@ -70,8 +70,7 @@ def save(self, experiment_dir):
             os.makedirs(path)
         torch.save({'epoch': self.epoch,
                     'step': self.step,
-                    'optimizer': self.optimizer,
-                    'optim_state_dict': self.optimizer.optimizer.state_dict()
+                    'optimizer': self.optimizer
                     },
                    os.path.join(path, self.TRAINER_STATE_NAME))
         torch.save(self.model, os.path.join(path, self.MODEL_NAME))

@@ -100,7 +99,6 @@ def load(cls, path):
         with open(os.path.join(path, cls.OUTPUT_VOCAB_FILE), 'rb') as fin:
             output_vocab = dill.load(fin)
         optimizer = resume_checkpoint['optimizer']
-        optimizer.optimizer.load_state_dict(resume_checkpoint['optim_state_dict'])
         return Checkpoint(model=model, input_vocab=input_vocab,
                           output_vocab=output_vocab,
                           optimizer=optimizer,
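
This simplification works because torch.save pickles the whole optimizer object, state included, so the separate 'optim_state_dict' entry duplicated information the pickled optimizer already carried. A small round-trip sketch (the path is illustrative; recent PyTorch versions may additionally need torch.load(..., weights_only=False) to unpickle arbitrary objects):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

trainer_path = '/tmp/trainer_states.pt'  # illustrative path

# Save: the pickled optimizer already contains its own state.
torch.save({'epoch': 5, 'step': 10, 'optimizer': optimizer}, trainer_path)

# Load: no separate load_state_dict() call is required.
resumed = torch.load(trainer_path)
optimizer = resumed['optimizer']
assert optimizer.param_groups[0]['lr'] == 1e-3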
4 changes: 2 additions & 2 deletions tests/test_checkpoint.py
@@ -38,7 +38,7 @@ def test_save_checkpoint_calls_torch_save(self, mock_open, mock_dill, mock_torch
        epoch = 5
        step = 10
        optim = mock.Mock()
-       state_dict = {'epoch': epoch, 'step': step, 'optimizer': optim, 'optim_state_dict': optim.optimizer.state_dict()}
+       state_dict = {'epoch': epoch, 'step': step, 'optimizer': optim}

        mock_model = mock.Mock()
        mock_vocab = mock.Mock()

@@ -68,7 +68,7 @@ def test_save_checkpoint_calls_torch_save(self, mock_open, mock_dill, mock_torch
    def test_load(self, mock_open, mock_dill, mock_torch):
        dummy_vocabulary = mock.Mock()
        mock_optimizer = mock.Mock()
-       torch_dict = {"optimizer": mock_optimizer, "epoch": 5, "step": 10, 'optim_state_dict': mock_optimizer}
+       torch_dict = {"optimizer": mock_optimizer, "epoch": 5, "step": 10}
        mock_open.return_value = mock.MagicMock()
        mock_torch.load.return_value = torch_dict
        mock_dill.load.return_value = dummy_vocabulary
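
The mock_open, mock_dill, and mock_torch parameters suggest the usual stacked @mock.patch pattern, where decorators inject mocks bottom-up. A hedged sketch of that setup (the patch targets are assumptions about the checkpoint module, not shown in the diff):

import unittest
from unittest import mock  # the 2017 tests used the standalone `mock` package


class CheckpointTest(unittest.TestCase):

    # Decorators apply bottom-up, so the bottom-most patch becomes the first
    # mock argument, matching (self, mock_open, mock_dill, mock_torch).
    @mock.patch('seq2seq.util.checkpoint.torch')
    @mock.patch('seq2seq.util.checkpoint.dill')
    @mock.patch('seq2seq.util.checkpoint.open', create=True)
    def test_save_checkpoint_calls_torch_save(self, mock_open, mock_dill, mock_torch):
        pass  # body as in the diff above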
