Skip to content

Commit 48c2a60

Browse files
authoredOct 9, 2024
* maybe fix oom * Revert "maybe fix oom" This reverts commit 0c09f83. * maybe fix oom 2 * maybe fix oom 3 * Revert "maybe fix oom 3" This reverts commit 4341086. * maybe fix oom 4 * maybe fix oom 5 * Revert "maybe fix oom 5" This reverts commit 55b17ec. * maybe fix oom 6 * maybe fix oom 6 * maybe fix oom 9
1 parent 9db81b0 commit 48c2a60

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed
 

‎src/zeroband/loss.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from torch import Tensor
2+
import torch
23
import torch.nn.functional as F
34

45

6+
@torch.compile
57
def cross_entropy_max_z_loss(
68
logits: Tensor,
79
targets: Tensor,

‎src/zeroband/train.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -351,24 +351,31 @@ def train(config: Config):
351351
flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
352352
flatten_labels = rearrange(labels, "b seq -> (b seq)")
353353

354-
if config.optim.z_loss is not None:
354+
if config.optim.z_loss:
355355
ce_loss, z_loss = cross_entropy_max_z_loss(
356356
flatten_logits, flatten_labels, config.optim.z_loss_weight
357357
)
358-
359358
ce_loss /= gradient_accumulation_steps
360359
z_loss /= gradient_accumulation_steps
361360

362-
loss_batch += ce_loss.detach()
363-
z_loss_batch += z_loss.detach()
364-
361+
del logits
365362
loss = ce_loss + z_loss
363+
loss.backward()
366364

367365
else:
368366
loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps
369-
loss_batch += loss.detach()
367+
del logits
368+
loss.backward()
369+
370+
if config.optim.z_loss:
371+
loss_batch += ce_loss.clone().detach()
372+
z_loss_batch += z_loss.clone().detach()
373+
else:
374+
loss_batch += loss.clone().detach()
370375

371-
loss.backward()
376+
dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
377+
if config.optim.z_loss:
378+
dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
372379

373380
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
374381
inner_optimizer.step()
@@ -379,9 +386,6 @@ def train(config: Config):
379386
training_progress.step += 1
380387
inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]
381388

382-
dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
383-
dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
384-
385389
# syncing loss across all data parallel rank within a nodes
386390

387391
new_tokens = config.data.seq_length * config.optim.batch_size

0 commit comments

Comments
 (0)
Please sign in to comment.