Skip to content

Commit

Permalink
broadcast the steps
Browse files Browse the repository at this point in the history
  • Loading branch information
justusc committed Jul 13, 2023
1 parent f711a08 commit d33f297
Showing 1 changed file with 28 additions and 13 deletions.
41 changes: 28 additions & 13 deletions training/dist_clm_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,23 +261,38 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
if do_sync_before_save:
pipe.dp_optim.rollback_parameters()

def calculate_training_steps(args, device, train_data_loader) -> int:
    """Compute the total number of training steps and broadcast it to all ranks.

    Precedence for the step count: an explicit ``args.total_steps`` wins over
    ``args.nepochs``; with neither set, one pass over ``train_data_loader`` is
    used. The locally computed value is then broadcast from rank 0 over the
    data-parallel and pipeline-parallel communicators so every rank agrees.

    Args:
        args: parsed argument namespace; reads total_steps, nepochs,
            batch_size, world_size, pipeline_group_size, seq_length.
        device: device on which the broadcast buffer is allocated.
        train_data_loader: training data loader; may be None on ranks that do
            not load data (those ranks receive the step count via broadcast).

    Returns:
        int: the synchronized total step count.
    """
    # Data parallelism is in play only when the world is larger than one pipeline.
    use_dp = (args.world_size != args.pipeline_group_size)
    total_steps_sync = torch.zeros(1, dtype=torch.int64).to(device)

    if args.total_steps is None and args.nepochs is None:
        # Default: a single epoch over the training data.
        total_steps = len(train_data_loader)
    elif args.total_steps is not None:
        if args.nepochs is not None:
            # Fix: the original message was a plain string missing the f-prefix
            # and referenced a misspelled attribute (args.toal_steps).
            print(f"WARNING: total_steps ({args.total_steps}) supersedes nepochs ({args.nepochs}).")
        total_steps = args.total_steps
    elif train_data_loader is not None:
        # Derive steps from epochs: steps = nepochs * tokens / tokens-per-step.
        token_count = train_data_loader.dataset.get_dataset_token_count()

        # Check the inputs to calculate the total steps
        if args.batch_size is None or args.world_size is None or args.pipeline_group_size is None or token_count is None or args.seq_length is None:
            print("Missing required arguments for calculating total steps based on epochs.")
            sys.exit(1)
        global_batch_size = int(args.batch_size * args.world_size / args.pipeline_group_size)
        total_steps = int((args.nepochs * token_count) / (global_batch_size * args.seq_length))
    else:
        # nepochs was given but this rank has no data loader to compute a local
        # value. Previously total_steps was left unbound here (NameError on the
        # buffer fill below). Use 0 as a placeholder; the broadcast from rank 0
        # overwrites it with the real count.
        total_steps = 0

    total_steps_sync.data[:] = total_steps

    # Synchronize so every rank uses the same step count (rank 0 is the source).
    if use_dp:
        get_data_parallel_comm().broadcast(total_steps_sync, 0)
    get_pipeline_parallel_comm().broadcast(total_steps_sync, 0)
    total_steps = total_steps_sync.item()

    print(f"Rank {get_pipeline_parallel_rank()} calculated {total_steps} total steps.")

    return total_steps

def main():
parser = argparse.ArgumentParser(description='Gpipe-GPT')
Expand Down Expand Up @@ -391,7 +406,7 @@ def main():
test_data_loader = None

# calculate total steps
args.total_steps = calculate_training_steps(args, train_data_loader)
args.total_steps = calculate_training_steps(args, device, train_data_loader)
if args.checkpoint_steps == 0 and args.num_checkpoints > 0:
args.checkpoint_steps = int(args.total_steps / args.num_checkpoints)
if args.checkpoint_steps < 1:
Expand Down

0 comments on commit d33f297

Please sign in to comment.