
Commit

[Hetero] Fix bug during loading ckpt if enabling CPU communication (FlagOpen#248)

Fix bug during loading ckpt if enabling CPU communication

Co-authored-by: lizhiyu <[email protected]>
heavyrain-lzy and lizhiyu authored Nov 1, 2024
1 parent de75f52 commit 9e5b3cb
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/llama/conf/train/train_llama3_8b.yaml
@@ -37,7 +37,7 @@ model:
normalization: RMSNorm
rotary_interleaved_patch: False
position_embedding_type: rope
-rotary_base: 500000.0
+rotary_base: 500000
untie_embeddings_and_output_weights: True
init_method_std: 0.02
attention_dropout: 0.0
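The YAML hunk above swaps a float literal for an integer one. With a YAML loader such as PyYAML, the two literals deserialize to different Python types, which can matter when the value later feeds integer arithmetic. A minimal illustration, assuming PyYAML is installed (`before`/`after` are illustrative names, not part of the patch):

```python
import yaml  # PyYAML

# The same key parsed from the old and the new literal.
before = yaml.safe_load("rotary_base: 500000.0")
after = yaml.safe_load("rotary_base: 500000")

print(type(before["rotary_base"]).__name__)  # float
print(type(after["rotary_base"]).__name__)   # int
```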
2 changes: 1 addition & 1 deletion megatron/megatron/training/checkpointing.py
@@ -244,7 +244,7 @@ def read_metadata(tracker_filename):

# Get the max iteration retrieved across the ranks.
if torch.distributed.is_initialized():
-iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda')
+iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda' if 'nccl' in torch.distributed.get_backend() else 'cpu')
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()

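The one-line fix keys the tensor's device off the active process-group backend: NCCL only reduces CUDA tensors, while CPU backends such as gloo need CPU tensors. The selection logic can be isolated as a small sketch (`pick_reduce_device` is a hypothetical helper name, not part of the patch):

```python
# Sketch of the device-selection logic in the fix. NCCL requires the
# all_reduce tensor to live on a CUDA device; CPU backends (e.g. gloo)
# require a CPU tensor.
def pick_reduce_device(backend: str) -> str:
    """Return the device string for collective tensors given the backend name."""
    return 'cuda' if 'nccl' in backend else 'cpu'

print(pick_reduce_device('gloo'))  # cpu
```

Substring matching rather than equality appears deliberate: `torch.distributed.get_backend()` can return a composite string such as `cpu:gloo,cuda:nccl` when per-device backends are configured.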
