Fix validation happening twice at the end of epoch (facebookresearch#1934)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes validation running twice at the end of each epoch after the refactor: `train()` already validated via `validate_and_save()` at the end of the epoch, and `main()` then called `validate_and_save()` again. Spotted by freewym
 here: facebookresearch@b5dad3b#r38103577
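
For context, a minimal, runnable sketch of the control flow before and after the fix. The function bodies here are stand-ins, not fairseq's real `train()` / `validate_and_save()` implementations:

```python
import math

def validate_and_save():
    # Stand-in for end-of-epoch validation + checkpointing.
    print('validating')
    return [1.0]  # one loss per validation subset

def train_old():
    # Old train(): validated internally at the end of the epoch...
    validate_and_save()
    return False  # ...but returned only a stop flag

def train_new(max_update=math.inf):
    # New train(): still validates once, and returns the losses it computed.
    return validate_and_save()

# Old main loop: validated a second time right after train() returned.
should_end_training = train_old()
valid_losses = validate_and_save()  # <-- duplicate validation

# New main loop: reuses the losses train() already computed.
valid_losses = train_new()
```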

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃
Pull Request resolved: facebookresearch#1934

Reviewed By: myleott

Differential Revision: D20724205

Pulled By: louismartin

fbshipit-source-id: 8c26c39b9904508780e8542813797c8e1306ca80
louismartin authored and facebook-github-bot committed Apr 3, 2020
1 parent b5a6cef commit 18831f9
Showing 2 changed files with 16 additions and 15 deletions.
20 changes: 7 additions & 13 deletions fairseq_cli/train.py
@@ -88,18 +88,18 @@ def main(args, init_distributed=False):
 
     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
+    max_update = args.max_update or math.inf
     lr = trainer.get_lr()
     train_meter = meters.StopwatchMeter()
     train_meter.start()
     valid_subsets = args.valid_subset.split(',')
     while (
         lr > args.min_lr
         and epoch_itr.next_epoch_idx <= max_epoch
     ):
         # train for one epoch
-        should_end_training = train(args, trainer, task, epoch_itr)
-
-        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
-
+        valid_losses = train(args, trainer, task, epoch_itr, max_update)
+        if should_stop_early(args, valid_losses[0]) or trainer.get_num_updates() >= max_update:
+            break
         # only use first validation loss to update the learning rate
         lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
@@ -109,9 +109,6 @@ def main(args, init_distributed=False):
             # sharded data: get train iterator for next epoch
             load_dataset=(os.pathsep in getattr(args, 'data', '')),
         )
-
-        if should_end_training:
-            break
     train_meter.stop()
     logger.info('done training in {:.1f} seconds'.format(train_meter.sum))

@@ -141,8 +138,8 @@ def is_better(a, b):
 
 
 @metrics.aggregate('train')
-def train(args, trainer, task, epoch_itr):
-    """Train the model for one epoch."""
+def train(args, trainer, task, epoch_itr, max_update=math.inf):
+    """Train the model for one epoch and return validation losses."""
     # Initialize data iterator
     itr = epoch_itr.next_epoch_itr(
         fix_batches_to_gpus=args.fix_batches_to_gpus,
@@ -169,8 +166,6 @@ def train(args, trainer, task, epoch_itr):
     task.begin_epoch(epoch_itr.epoch, trainer.get_model())
 
     valid_subsets = args.valid_subset.split(',')
-    max_update = args.max_update or math.inf
-    should_end_training = False
     for samples in progress:
         with metrics.aggregate('train_inner'):
             log_output = trainer.train_step(samples)
@@ -189,7 +184,6 @@ def train(args, trainer, task, epoch_itr):
 
         valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
         if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
-            should_end_training = True
             break
 
     # log end-of-epoch stats
@@ -198,7 +192,7 @@ def train(args, trainer, task, epoch_itr):
 
     # reset epoch-level meters
     metrics.reset_meters('train')
-    return should_end_training
+    return valid_losses
 
 
 def validate_and_save(args, trainer, task, epoch_itr, valid_subsets):
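
The new stop condition in `main()` leans on `should_stop_early()`, which implements patience-based early stopping on the first validation loss. A rough, self-contained sketch of that idea (assumed names, closure-based state; fairseq's actual implementation differs in details):

```python
def make_should_stop_early(patience, maximize=False):
    # Rough sketch in the spirit of fairseq's should_stop_early();
    # not its exact implementation.
    best = None
    runs_without_improvement = 0

    def should_stop(valid_loss):
        nonlocal best, runs_without_improvement
        if patience <= 0:
            return False  # patience disabled: never stop early
        improved = best is None or (valid_loss > best if maximize else valid_loss < best)
        if improved:
            best = valid_loss
            runs_without_improvement = 0
        else:
            runs_without_improvement += 1
        return runs_without_improvement >= patience

    return should_stop

# Example: stop after `patience` validations with no improvement.
should_stop = make_should_stop_early(patience=2)
for epoch, loss in enumerate([3.0, 2.5, 2.6, 2.7], start=1):
    if should_stop(loss):
        print('early stop at epoch', epoch)
        break
```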
11 changes: 9 additions & 2 deletions tests/test_reproducibility.py
@@ -25,6 +25,11 @@ def _test_reproducibility(
         resume_checkpoint='checkpoint1.pt',
         max_epoch=3,
     ):
+        def get_last_log_stats_containing_string(log_records, search_string):
+            for log_record in logs.records[::-1]:
+                if search_string in log_record.msg:
+                    return json.loads(log_record.msg)
+
         if extra_flags is None:
             extra_flags = []
 
@@ -43,7 +48,8 @@ def _test_reproducibility(
                 '--max-epoch', str(max_epoch),
             ] + extra_flags,
         )
-        train_log, valid_log = map(lambda rec: json.loads(rec.msg), logs.records[-4:-2])
+        train_log = get_last_log_stats_containing_string(logs.records, 'train_loss')
+        valid_log = get_last_log_stats_containing_string(logs.records, 'valid_loss')
 
         # train epoch 2, resuming from previous checkpoint 1
         os.rename(
@@ -59,7 +65,8 @@ def _test_reproducibility(
                 '--max-epoch', str(max_epoch),
             ] + extra_flags,
         )
-        train_res_log, valid_res_log = map(lambda rec: json.loads(rec.msg), logs.records[-4:-2])
+        train_res_log = get_last_log_stats_containing_string(logs.records, 'train_loss')
+        valid_res_log = get_last_log_stats_containing_string(logs.records, 'valid_loss')
 
         for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']:
             self.assertAlmostEqual(float(train_log[k]), float(train_res_log[k]), delta=delta)
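
The old test grabbed stats from fixed positions (`logs.records[-4:-2]`), which broke once the refactor changed how many end-of-epoch records are emitted; the new helper instead scans from the end for the last record containing a given key. A minimal standalone sketch of that pattern (plain strings here, whereas the real test scans `logging.LogRecord` objects via `.msg`):

```python
import json

def get_last_log_stats_containing_string(log_records, search_string):
    # Scan newest-to-oldest and parse the first JSON record that matches.
    for msg in reversed(log_records):
        if search_string in msg:
            return json.loads(msg)

records = [
    json.dumps({'train_loss': 4.2, 'train_ppl': 18.4}),
    json.dumps({'valid_loss': 3.9}),
    json.dumps({'train_loss': 3.7, 'train_ppl': 13.0}),  # last train stats
]
assert get_last_log_stats_containing_string(records, 'train_loss')['train_loss'] == 3.7
```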
