@@ -351,24 +351,31 @@ def train(config: Config):
351
351
flatten_logits = rearrange (logits , "b seq vocab -> (b seq) vocab" )
352
352
flatten_labels = rearrange (labels , "b seq -> (b seq)" )
353
353
354
- if config .optim .z_loss is not None :
354
+ if config .optim .z_loss :
355
355
ce_loss , z_loss = cross_entropy_max_z_loss (
356
356
flatten_logits , flatten_labels , config .optim .z_loss_weight
357
357
)
358
-
359
358
ce_loss /= gradient_accumulation_steps
360
359
z_loss /= gradient_accumulation_steps
361
360
362
- loss_batch += ce_loss .detach ()
363
- z_loss_batch += z_loss .detach ()
364
-
361
+ del logits
365
362
loss = ce_loss + z_loss
363
+ loss .backward ()
366
364
367
365
else :
368
366
loss = F .cross_entropy (flatten_logits , flatten_labels ) / gradient_accumulation_steps
369
- loss_batch += loss .detach ()
367
+ del logits
368
+ loss .backward ()
369
+
370
+ if config .optim .z_loss :
371
+ loss_batch += ce_loss .clone ().detach ()
372
+ z_loss_batch += z_loss .clone ().detach ()
373
+ else :
374
+ loss_batch += loss .clone ().detach ()
370
375
371
- loss .backward ()
376
+ dist .all_reduce (tensor = loss_batch , op = dist .ReduceOp .AVG , group = elastic_device_mesh .local_pg )
377
+ if config .optim .z_loss :
378
+ dist .all_reduce (tensor = z_loss_batch , op = dist .ReduceOp .AVG , group = elastic_device_mesh .local_pg )
372
379
373
380
torch .nn .utils .clip_grad_norm_ (model .parameters (), 1.0 )
374
381
inner_optimizer .step ()
@@ -379,9 +386,6 @@ def train(config: Config):
379
386
training_progress .step += 1
380
387
inner_lr = [group ["lr" ] for group in inner_optimizer .param_groups ][0 ]
381
388
382
- dist .all_reduce (tensor = loss_batch , op = dist .ReduceOp .AVG , group = elastic_device_mesh .local_pg )
383
- dist .all_reduce (tensor = z_loss_batch , op = dist .ReduceOp .AVG , group = elastic_device_mesh .local_pg )
384
-
385
389
# syncing loss across all data-parallel ranks within a node
386
390
387
391
new_tokens = config .data .seq_length * config .optim .batch_size
0 commit comments