Commit

Add training wall time meter
myleott committed Sep 3, 2018
1 parent f84e1ed commit 9c10278
Showing 2 changed files with 76 additions and 58 deletions.
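The new meter relies on a stopwatch-style interface: start() marks the beginning of a timed event, stop() folds the elapsed time into a running total, and .sum exposes the accumulated seconds. Below is a minimal sketch of that interface, inferred from how the diff uses it; the real class in fairseq.meters may track additional statistics:

import time

class StopwatchMeter(object):
    """Accumulates the total duration of timed events, in seconds."""

    def __init__(self):
        self.reset()

    def start(self):
        # mark the beginning of a timed event
        self.start_time = time.time()

    def stop(self):
        # add the elapsed time of the current event to the running total
        if self.start_time is not None:
            self.sum += time.time() - self.start_time
            self.start_time = None

    def reset(self):
        self.sum = 0
        self.start_time = None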
133 changes: 75 additions & 58 deletions fairseq/trainer.py
@@ -16,7 +16,7 @@
import torch

from fairseq import distributed_utils, optim, utils
from fairseq.meters import AverageMeter, TimeMeter
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


@@ -54,6 +54,7 @@ def __init__(self, args, task, model, criterion):
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds

self._buffered_stats = defaultdict(lambda: [])
self._flat_grads = None
@@ -109,16 +110,24 @@ def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=Fa
self.meters = extra_state['train_meters']
del extra_state['train_meters']

# reset TimeMeters, since their start times don't make sense anymore
for meter in self.meters.values():
if isinstance(meter, TimeMeter):
meter.reset()

return extra_state

def train_step(self, sample, update_params=True):
def train_step(self, sample, update_params=True, dummy_batch=False):
"""Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

if not dummy_batch:
self.meters['train_wall'].start()

# forward and backward pass
sample = self._prepare_sample(sample)
loss, sample_size, logging_output, oom_fwd = self._forward(sample)
@@ -132,62 +141,70 @@ def train_step(self, sample, update_params=True):

# update parameters
if update_params:
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)

if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None

# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

try:
# all-reduce and rescale gradients, then take an optimization step
grad_norm = self._all_reduce_and_rescale(grad_denom)
self._opt()

# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)

# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e))

self.clear_buffered_stats()

return agg_logging_output
agg_logging_output = self._update_params()
else:
return None # buffering updates
agg_logging_output = None # buffering updates

if not dummy_batch:
self.meters['train_wall'].stop()

return agg_logging_output

def _update_params(self):
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)

if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None

# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

try:
# all-reduce and rescale gradients, then take an optimization step
grad_norm = self._all_reduce_and_rescale(grad_denom)
self._opt()

# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)

# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e))

self.clear_buffered_stats()

return agg_logging_output

def _forward(self, sample, eval=False):
loss = None
@@ -320,7 +337,7 @@ def valid_step(self, sample):

def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False)
self.train_step(dummy_batch, update_params=False, dummy_batch=True)
self.zero_grad()
self.clear_buffered_stats()

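The dummy_batch flag keeps the warm-up step, which only primes the CUDA caching allocator, out of the new timer: only real batches are bracketed by start() and stop(), whether or not they also trigger a parameter update. A self-contained toy illustration of that pattern, reusing the StopwatchMeter sketch above (or fairseq.meters.StopwatchMeter); toy_step is a hypothetical stand-in for train_step:

meter = StopwatchMeter()

def toy_step(batch, dummy_batch=False):
    if not dummy_batch:
        meter.start()
    # forward/backward/optimizer work would happen here
    if not dummy_batch:
        meter.stop()

toy_step(None, dummy_batch=True)   # warm-up step: contributes nothing to meter.sum
for batch in range(3):             # stands in for iterating over real batches
    toy_step(batch)
print(round(meter.sum))            # cumulative seconds spent in real steps only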
1 change: 1 addition & 0 deletions train.py
@@ -185,6 +185,7 @@ def get_training_stats(trainer):
if trainer.get_meter('loss_scale') is not None:
stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
stats['train_wall'] = round(trainer.get_meter('train_wall').sum)
return stats


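In the progress stats, 'train_wall' is the rounded cumulative sum of the stopwatch, i.e. whole seconds spent inside real train_step calls, whereas the existing 'wall' stat comes from a TimeMeter and also covers validation, checkpoint saving, and data loading. Because load_checkpoint resets only TimeMeter instances, 'wall' restarts from zero after resuming while the restored 'train_wall' keeps accumulating across restarts (assuming StopwatchMeter is a separate class rather than a TimeMeter subclass). A hedged sketch of reading both meters directly, using the same accessors as get_training_stats and assuming an initialized trainer:

wall = round(trainer.get_meter('wall').elapsed_time)      # elapsed since the meter (re)started
train_wall = round(trainer.get_meter('train_wall').sum)   # time inside non-dummy train steps
overhead = wall - train_wall   # roughly: validation, checkpointing, data loading
# note: after resuming from a checkpoint, 'wall' restarts but 'train_wall' is restored,
# so this difference is only meaningful within a single run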
