Refactor run_transformer bert minimal and gpt minimal tests (NVIDIA#1540)

* working test_bert_minimal.py

* remove some debugging statements

* working test_gpt_minimal.py

* test_dynamic_batchsize.py having issues with torch.backends.cudnn.allow_tf32

* working test_dynamic_batchsize.py

* refactor test_bert_minimal.py, need to investigate rng of MANUAL_SEED for nccl only pipeline with virtual_pipeline_model_parallel_size = 2

* add test_bert_minimal_alt.py for visibility

* update test_gpt_minimal.py

* lint

* update loss cutoff for bert test

* split with / without interleaving tests for bert

* use skipTest

* remove ONCE

* add ignore_unknown_args=True

* remove old testing files

* add num_devices logic to override_args
Fuzzkatt authored Dec 8, 2022
1 parent 190a446 commit 85af77c
Showing 8 changed files with 588 additions and 645 deletions.
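
The refactored tests drive configuration through the new `override_args` path instead of a long fabricated command line. A minimal sketch of that test-side pattern, assuming hypothetical override keys and deriving the world size from the visible GPU count (the "num_devices logic" in the commit messages); `set_global_variables` and its new parameters come from the diffs below:

```python
import torch

from apex.transformer.testing.global_vars import set_global_variables

# Assumed example: size the run from the devices actually available.
num_devices = torch.cuda.device_count()
override_args = {
    "world_size": num_devices,          # hypothetical override keys;
    "micro_batch_size": 2,              # the real tests choose their own set
    "tensor_model_parallel_size": 1,
}

# ignore_unknown_args=True lets argparse tolerate extra flags injected by
# the test launcher instead of erroring out.
set_global_variables(override_args=override_args, ignore_unknown_args=True)
```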
8 changes: 7 additions & 1 deletion apex/transformer/testing/arguments.py
@@ -20,7 +20,7 @@

import torch

def parse_args(extra_args_provider=None, defaults={},
def parse_args(extra_args_provider=None, defaults={}, override_args={},
ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
@@ -59,16 +59,22 @@ def parse_args(extra_args_provider=None, defaults={},
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))

for key in override_args:
setattr(args, key, override_args[key])

# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
' ({}) is not divisible by tensor model parallel size ({})'.format(
args.world_size, args.tensor_model_parallel_size)

# Pipeline model parallel size.
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))

args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
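
The `for key in override_args` loop added above overwrites parsed attributes before any derived values (such as the clamped tensor model parallel size) are computed. A self-contained sketch of that mechanism on a plain argparse namespace, using a toy argument set rather than the full Megatron-style one:

```python
import argparse

def parse_toy_args(override_args={}):
    parser = argparse.ArgumentParser(description="toy override_args demo")
    parser.add_argument("--world-size", type=int, default=1)
    parser.add_argument("--tensor-model-parallel-size", type=int, default=4)
    args = parser.parse_args([])  # empty command line for the demo

    # Same pattern as the diff: apply overrides first ...
    for key in override_args:
        setattr(args, key, override_args[key])

    # ... so derived values see the overridden numbers.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    return args

print(parse_toy_args({"world_size": 2}))
# Namespace(tensor_model_parallel_size=2, world_size=2)
```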
6 changes: 4 additions & 2 deletions apex/transformer/testing/global_vars.py
@@ -84,11 +84,12 @@ def get_timers():
return _GLOBAL_TIMERS


def set_global_variables(extra_args_provider=None, args_defaults={},
def set_global_variables(extra_args_provider=None, args_defaults={}, override_args={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
override_args=override_args,
ignore_unknown_args=ignore_unknown_args)
# _build_num_microbatches_calculator(args)
# if args.vocab_file:
@@ -98,13 +99,14 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
_set_timers()


def _parse_args(extra_args_provider=None, defaults={},
def _parse_args(extra_args_provider=None, defaults={}, override_args={},
ignore_unknown_args=False):
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults,
override_args=override_args,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS

254 changes: 0 additions & 254 deletions tests/L0/run_transformer/run_bert_minimal_test.py

This file was deleted.
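
Per the "use skipTest" commit message, the replacement tests skip rather than fail when the machine cannot satisfy the requested parallel configuration. A hedged sketch of that pattern; the class name, device threshold, and test body below are placeholders, not the actual contents of test_bert_minimal.py:

```python
import unittest

import torch


class BertMinimalTest(unittest.TestCase):
    def test_bert_with_interleaving(self):
        # Placeholder guard: an interleaved schedule with
        # virtual_pipeline_model_parallel_size = 2 needs several devices.
        if torch.cuda.device_count() < 4:
            self.skipTest("requires at least 4 GPUs")
        # ... set up model-parallel groups and run forward/backward here ...


if __name__ == "__main__":
    unittest.main()
```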
