Add an option to specify epochs rather than total steps
justusc committed Jul 13, 2023
1 parent fbd9c54 commit 1bd272c
Showing 2 changed files with 29 additions and 7 deletions.
26 changes: 22 additions & 4 deletions training/dist_clm_train.py
@@ -260,7 +260,24 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
save_checkpoint(pipe, args)
if do_sync_before_save:
pipe.dp_optim.rollback_parameters()


def calculate_training_steps(args, train_data_loader) -> int:
if args.total_steps is None and args.nepochs is None:
return len(train_data_loader)

if args.total_steps is not None:
if args.nepochs is not None:
print("WARNING: total_steps ({args.toal_steps}) supercedes nepochs ({args.nepochs}).")
return args.total_steps

token_count = train_data_loader.get_dataset_token_count()

# Check the inputs to calculate the total steps
if args.batch_size is None or args.world_size is None or args.pipeline_group_size is None or token_count is None or args.seq_length is None:
print("Missing required arguments for calculating total steps based on epochs.")
sys.exit(1)
global_batch_size = int(args.batch_size * args.world_size / args.pipeline_group_size)
return int((args.nepochs * token_count) / (global_batch_size * args.seq_length))

def main():
parser = argparse.ArgumentParser(description='Gpipe-GPT')
@@ -283,6 +300,7 @@ def main():
help='task name')
parser.add_argument('--warmup-steps', type=int, default=0, help='-')
parser.add_argument('--train-warmup-steps', type=int, default=0, help='-')
parser.add_argument('--nepochs', type=int, default=None, help='-')
parser.add_argument('--total-steps', type=int, default=None, help='-')
parser.add_argument('--load-pretrained-model',
type=lambda x: x.lower()=='true', default=True, metavar='S',
@@ -368,9 +386,9 @@ def main():
test_data_loader = get_eval_data_loader(args, tokenizer)
else:
test_data_loader = None
if args.total_steps is None:
args.total_steps = len(train_data_loader)

# calculate total steps
args.total_steps = calculate_training_steps(args, train_data_loader)

use_dp = (args.world_size != args.pipeline_group_size)
if use_dp:
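
For context, here is a minimal worked example of the arithmetic the new calculate_training_steps helper performs when --nepochs is given. All numbers below are hypothetical and serve only to illustrate the epoch-to-steps conversion; they do not come from the commit.

# Hypothetical values, chosen only to illustrate the conversion.
nepochs = 3
token_count = 10_000_000      # total tokens in the training dataset
batch_size = 8                # per-replica batch size
world_size = 16               # total number of ranks
pipeline_group_size = 4       # ranks in one pipeline group
seq_length = 2048

# Data-parallel replicas = world_size / pipeline_group_size, so the
# effective global batch is batch_size * (world_size / pipeline_group_size).
global_batch_size = int(batch_size * world_size / pipeline_group_size)   # 32

# Each step consumes global_batch_size * seq_length tokens.
total_steps = int((nepochs * token_count) / (global_batch_size * seq_length))
print(total_steps)   # 457
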
10 changes: 7 additions & 3 deletions training/tasks/data_loaders/data_utils.py
@@ -227,6 +227,7 @@ def __init__(self, task_names, datasets, sample_probs, tokenizer, seq_length=102
self.seq_length = seq_length
self.print_sample_every_n = print_sample_every_n
self.post_processor = post_processor
self.token_count = None

self.it = None

@@ -282,7 +283,10 @@ def tokenize_function(examples):
# Compute the number of tokens in a dataset using a Torch tokenizer
# - return: the sum of tokens from the text field of each sample in the dataset
def get_dataset_token_count(self) -> int:
token_count = 0
if self.token_count is not None:
return self.token_count

self.token_count = 0

if self.task_names is None:
return self.token_count
@@ -303,12 +307,12 @@ def get_dataset_token_count(self) -> int:
)

for item in tokenized_datasets:
token_count += len(item['input_ids'])
self.token_count += len(item['input_ids'])

# clean up cache
raw_datasets.cleanup_cache_files()

return token_count
return self.token_count

def get_dataset_example_count(self) -> int:
num_lines = 0
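
A self-contained sketch of the caching pattern this file change introduces in get_dataset_token_count: compute the token count once, store it in self.token_count, and serve later calls from the cache. ToyLoader and its tokenizer are hypothetical stand-ins, not the repository's actual data loader.

# Hypothetical stand-in for the data loader; only the caching pattern matters.
class ToyLoader:
    def __init__(self, samples, tokenize):
        self.samples = samples        # e.g. a list of text strings
        self.tokenize = tokenize      # callable: str -> list of tokens
        self.token_count = None       # cache, filled on the first call

    def get_dataset_token_count(self) -> int:
        if self.token_count is not None:
            return self.token_count   # reuse the cached count
        self.token_count = 0
        for text in self.samples:
            self.token_count += len(self.tokenize(text))
        return self.token_count

loader = ToyLoader(["hello world", "epochs to steps"], str.split)
print(loader.get_dataset_token_count())   # 5 (computed)
print(loader.get_dataset_token_count())   # 5 (served from cache, no re-tokenizing)
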
