access to pipeline_model_parallel_split_rank (NVIDIA#1300)
crcrpar authored Feb 23, 2022
1 parent ab1a93a commit 069ff33
Showing 2 changed files with 27 additions and 0 deletions.
6 changes: 6 additions & 0 deletions apex/transformer/parallel_state.py
@@ -342,6 +342,12 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
 
 
+def get_pipeline_model_parallel_split_rank():
+    """Return my rank for the pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
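For reference, a minimal usage sketch of the new accessor (not part of this commit; the single-process gloo group and the split-rank value of 1 below are illustrative assumptions):

import torch
from apex.transformer import parallel_state

# Before initialization the module-level global is unset, so the accessor returns None.
assert parallel_state.get_pipeline_model_parallel_split_rank() is None

# Illustrative single-process group; the real code paths run multi-process on GPUs.
torch.distributed.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
parallel_state.initialize_model_parallel(pipeline_model_parallel_split_rank_=1)

# After initialization the accessor returns the value passed to initialize_model_parallel.
assert parallel_state.get_pipeline_model_parallel_split_rank() == 1

parallel_state.destroy_model_parallel()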
21 changes: 21 additions & 0 deletions tests/L0/run_transformer/run_initialize_test.py
@@ -80,6 +80,25 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
     # Checks
     src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
     assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
+    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
+    assert split_rank is None
+
+    # Reset groups
+    parallel_state.destroy_model_parallel()
+
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print('>> passed the test :-)')
+
+
+def test_pipeline_model_parallel_split_rank():
+    pipeline_model_parallel_split_rank_ = 1
+    assert not parallel_state.model_parallel_is_initialized()
+    parallel_state.initialize_model_parallel(pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_)
+    assert parallel_state.model_parallel_is_initialized()
+
+    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
+    assert split_rank is pipeline_model_parallel_split_rank_
+
     # Reset groups
     parallel_state.destroy_model_parallel()
@@ -101,4 +120,6 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
         test_initialize_model_parallel(tensor_model_parallel_size)
         print_separator('test model parallel source rank')
         test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
+        print_separator('test pipeline model parallel split rank')
+        test_pipeline_model_parallel_split_rank()
         tensor_model_parallel_size *= 2
