Skip to content

Commit

Permalink
Fixed #70 (#71)
Browse files Browse the repository at this point in the history
* Fixed #70
  • Loading branch information
kylegao91 authored Sep 7, 2017
1 parent 33bd085 commit b8104bc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
8 changes: 6 additions & 2 deletions seq2seq/trainer/supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
total_steps = steps_per_epoch * n_epochs

step = start_step
step_elapsed = 0
for epoch in range(start_epoch, n_epochs + 1):
log.debug("Epoch: %d, Step: %d" % (epoch, step))

Expand All @@ -100,6 +101,7 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
model.train(True)
for batch in batch_generator:
step += 1
step_elapsed += 1

input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
target_variables = getattr(batch, seq2seq.tgt_field_name)
Expand All @@ -110,8 +112,8 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
print_loss_total += loss
epoch_loss_total += loss

if step % self.print_every == 0:
print_loss_avg = print_loss_total / min(self.print_every, step - start_step)
if step % self.print_every == 0 and step_elapsed > self.print_every:
print_loss_avg = print_loss_total / self.print_every
print_loss_total = 0
log_msg = 'Progress: %d%%, Train %s: %.4f' % (
step / total_steps * 100,
Expand All @@ -127,6 +129,8 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
input_vocab=data.fields[seq2seq.src_field_name].vocab,
output_vocab=data.fields[seq2seq.tgt_field_name].vocab).save(self.expt_dir)

if step_elapsed == 0: continue

epoch_loss_avg = epoch_loss_total / min(steps_per_epoch, step - start_step)
epoch_loss_total = 0
log_msg = "Finished epoch %d: Train %s: %.4f" % (epoch, self.loss.name, epoch_loss_avg)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,18 @@ def test_batch_num_when_resuming(self, mock_checkpoint, mock_func):

self.assertEqual(steps_per_epoch - step, mock_func.call_count)

@mock.patch('seq2seq.trainer.SupervisedTrainer._train_batch', return_value=0)
@mock.patch('seq2seq.util.checkpoint.Checkpoint.save')
def test_resume_from_multiple_of_epoches(self, mock_checkpoint, mock_func):
mock_model = mock.Mock()
mock_optim = mock.Mock()

trainer = SupervisedTrainer(batch_size=16)
trainer.optimizer = mock_optim
n_epoches = 1
start_epoch = 1
step = 7
trainer._train_epoches(self.dataset, mock_model, n_epoches, start_epoch, step)

if __name__ == '__main__':
unittest.main()

0 comments on commit b8104bc

Please sign in to comment.