fix val loss computation in megatron (NVIDIA#5871)
* fix val loss computation in megatron

* Fix NaN handling during validation

---------

Co-authored-by: ANMOL GUPTA <[email protected]>
Co-authored-by: Mikołaj Błaż <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
4 people authored Feb 6, 2023
1 parent 5c4e62b commit b81cd19
Showing 1 changed file with 18 additions and 12 deletions.
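
Editor's note for context: in this code path, `self.loss_func` returns the mean loss over the non-masked tokens of a micro-batch. The old code weighted that mean by an estimated number of valid *samples*; the fix below weights it by the number of valid *tokens* (`loss_mask.sum()`), so that summing `mean * token_count` across micro-batches and data-parallel ranks and dividing by the total token count recovers the exact token-level average. A minimal sketch of that identity (illustrative values, not NeMo code):

```python
import torch

# Mean loss per micro-batch and the number of non-masked tokens in each.
ub_means = torch.tensor([2.0, 3.0])
ub_tokens = torch.tensor([10.0, 30.0])

# Token-weighted aggregation (the fix): exact global token-level mean.
exact = (ub_means * ub_tokens).sum() / ub_tokens.sum()  # (20 + 90) / 40 = 2.75

# Plain average of micro-batch means: biased when token counts differ.
biased = ub_means.mean()  # 2.5
```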
nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -515,24 +515,30 @@ def fwd_output_and_loss_func(batch, model, checkpoint_activations_all_layers=None):
             )
 
             def loss_func(output_tensor):
-                loss_for_mb = self.loss_func(loss_mask, output_tensor)
+                # Loss for a micro-batch (ub)
+                loss_for_ub = self.loss_func(loss_mask, output_tensor)
                 if validation_step and not self.cfg.data.get('validation_drop_last', True):
-                    num_valid_samples_in_mb = int(loss_mask.sum() / loss_mask.numel() * loss_mask.shape[0])
-                    loss_sum_for_mb = num_valid_samples_in_mb * loss_for_mb
-                    loss_sum_and_mb_size_all_gpu = torch.cat(
+                    num_valid_tokens_in_ub = loss_mask.sum()
+                    if loss_for_ub.isnan():
+                        assert loss_mask.count_nonzero() == 0, 'Got NaN loss with non-empty input'
+                        loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
+                    else:
+                        loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub
+
+                    loss_sum_and_ub_size_all_gpu = torch.cat(
                         [
-                            loss_sum_for_mb.clone().detach().view(1),
-                            torch.tensor([num_valid_samples_in_mb]).cuda().clone().detach(),
+                            loss_sum_for_ub.clone().detach().view(1),
+                            torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(),
                         ]
                     )
                     # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds)
                     torch.distributed.all_reduce(
-                        loss_sum_and_mb_size_all_gpu, group=parallel_state.get_data_parallel_group()
+                        loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
                     )
-                    return loss_for_mb, {'loss_sum_and_mb_size': loss_sum_and_mb_size_all_gpu}
+                    return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
                 else:
-                    reduced_loss = average_losses_across_data_parallel_group([loss_for_mb])
-                    return loss_for_mb, {'avg': reduced_loss}
+                    reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
+                    return loss_for_ub, {'avg': reduced_loss}
 
             return output_tensor, loss_func
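
The NaN branch above covers a corner case introduced by `validation_drop_last=False`: the last global batch may contain fully padded micro-batches whose `loss_mask` is all zeros, so the masked mean divides by zero and comes back NaN. A standalone sketch of the new per-micro-batch computation (single process, so the `all_reduce` over the data-parallel group is omitted; names mirror the diff but this is not the NeMo source):

```python
import torch

def loss_sum_and_size(loss_for_ub: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Return a [loss_sum, num_valid_tokens] pair for one micro-batch (ub).

    loss_for_ub: scalar mean loss over non-masked tokens.
    loss_mask: float 0/1 mask over tokens.
    """
    num_valid_tokens_in_ub = loss_mask.sum()
    if loss_for_ub.isnan():
        # Only a fully padded micro-batch (all-zero mask) may produce NaN;
        # it contributes zero loss and zero tokens to the aggregate.
        assert loss_mask.count_nonzero() == 0, 'Got NaN loss with non-empty input'
        loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
    else:
        loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub
    return torch.stack([loss_sum_for_ub, num_valid_tokens_in_ub])

# A fully padded micro-batch no longer poisons the aggregate with NaN.
print(loss_sum_and_size(torch.tensor(float('nan')), torch.zeros(2, 8)))  # tensor([0., 0.])
```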

@@ -603,9 +609,9 @@ def validation_step(self, batch, batch_idx):
         else:
             # Get the total loss since micro batches sizes are not uniform
             loss_sum_tensors_list = [
-                loss_sum['loss_sum_and_mb_size']
+                loss_sum['loss_sum_and_ub_size']
                 for loss_sum in losses_reduced_per_micro_batch
-                if loss_sum['loss_sum_and_mb_size'][1] > 0
+                if loss_sum['loss_sum_and_ub_size'][1] > 0
             ]
             loss_sum = (
                 torch.vstack(loss_sum_tensors_list).sum(axis=0)
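
On the consuming side, each micro-batch now reports an all-reduced `[loss_sum, token_count]` pair; micro-batches with a zero token count are filtered out before aggregation, and the eventual validation loss is `total_sum / total_count` (the division happens downstream of this hunk). A hedged sketch of that aggregation with made-up numbers:

```python
import torch

# Per-micro-batch outputs in the validation_drop_last=False path:
# each value is a [loss_sum, num_valid_tokens] tensor.
losses_reduced_per_micro_batch = [
    {'loss_sum_and_ub_size': torch.tensor([20.0, 10.0])},
    {'loss_sum_and_ub_size': torch.tensor([0.0, 0.0])},  # fully padded ub
    {'loss_sum_and_ub_size': torch.tensor([90.0, 30.0])},
]

# Drop micro-batches that contained no valid tokens, then sum the pairs.
loss_sum_tensors_list = [
    x['loss_sum_and_ub_size']
    for x in losses_reduced_per_micro_batch
    if x['loss_sum_and_ub_size'][1] > 0
]
loss_sum = torch.vstack(loss_sum_tensors_list).sum(axis=0)  # tensor([110., 40.])
val_loss = loss_sum[0] / loss_sum[1]                        # 2.75
```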
