Commit

added restart

walid0925 committed Jul 14, 2021
1 parent 744d4c1 commit 9ab2f2d
Showing 3 changed files with 70 additions and 10 deletions.
18 changes: 12 additions & 6 deletions chemberta/train/flags.py
@@ -144,12 +144,12 @@ def train_flags():
help="The batch size per GPU/TPU core/CPU for training.",
module_name="training",
)
flags.DEFINE_integer(
name="save_steps",
default=100,
help="Number of updates steps before two checkpoint saves if save_strategy='steps'",
module_name="training",
)
# flags.DEFINE_integer(
# name="save_steps",
# default=100,
# help="Number of updates steps before two checkpoint saves if save_strategy='steps'",
# module_name="training",
# )
flags.DEFINE_integer(
name="save_total_limit",
default=None,
@@ -164,3 +164,9 @@ def train_flags():
help="Masking rate",
module_name="training",
)
flags.DEFINE_string(
name="cloud_directory",
default=None,
help="If provided, syncs the run directory here using a callback.",
module_name="training",
)
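For illustration only (not part of this commit), a minimal absl sketch of how the new cloud_directory flag is meant to behave: when it is left at its default of None the run stays local, and when it is set the run directory is mirrored via a callback. The bucket path in the comment is hypothetical.

from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="cloud_directory",
    default=None,
    help="If provided, syncs the run directory here using a callback.",
)


def main(argv):
    # e.g. --cloud_directory=s3://my-bucket/my-run (hypothetical bucket)
    if FLAGS.cloud_directory is not None:
        print(f"Run directory will be mirrored to {FLAGS.cloud_directory}")
    else:
        print("No cloud directory given; checkpoints stay local.")


if __name__ == "__main__":
    app.run(main)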
39 changes: 35 additions & 4 deletions chemberta/train/train_roberta.py
@@ -18,8 +18,11 @@
>
"""

import glob
import os
import subprocess

import s3fs
import torch
import yaml
from absl import app, flags
@@ -32,7 +35,7 @@
tokenizer_flags,
train_flags,
)
from chemberta.train.utils import DatasetArguments, create_trainer
from chemberta.train.utils import AwsS3Callback, DatasetArguments, create_trainer

# Model params
flags.DEFINE_enum(
@@ -59,6 +62,8 @@
def main(argv):
torch.manual_seed(0)
run_dir = os.path.join(FLAGS.output_dir, FLAGS.run_name)
if not os.path.isdir(run_dir):
os.makedirs(run_dir)

model_config = RobertaConfig(
vocab_size=FLAGS.vocab_size,
@@ -95,15 +100,32 @@ def main(argv):
num_train_epochs=FLAGS.num_train_epochs,
per_device_train_batch_size=FLAGS.per_device_train_batch_size,
per_device_eval_batch_size=FLAGS.per_device_train_batch_size,
save_steps=FLAGS.save_steps,
# save_steps=FLAGS.save_steps,
save_total_limit=FLAGS.save_total_limit,
fp16=torch.cuda.is_available(), # fp16 only works on CUDA devices
)

callbacks = [
EarlyStoppingCallback(early_stopping_patience=FLAGS.early_stopping_patience)
EarlyStoppingCallback(early_stopping_patience=FLAGS.early_stopping_patience),
]

if FLAGS.cloud_directory is not None:
# check if remote directory exists, pull down
fs = s3fs.S3FileSystem()
if fs.exists(FLAGS.cloud_directory):
subprocess.check_call(
[
"aws",
"s3",
"sync",
FLAGS.cloud_directory,
run_dir,
]
)
callbacks.append(
AwsS3Callback(local_directory=run_dir, s3_directory=FLAGS.cloud_directory)
)

trainer = create_trainer(
FLAGS.model_type, model_config, training_args, dataset_args, callbacks
)
@@ -119,7 +141,16 @@ def main(argv):
yaml.dump(flags_dict, f)
print(f"Saved command-line flags to {flags_file_path}")

trainer.train()
# if there is a checkpoint available, use it
checkpoints = glob.glob(os.path.join(run_dir, "checkpoint-*"))
if checkpoints:
iters = [int(x.split("-")[-1]) for x in checkpoints if "checkpoint" in x]
iters.sort()
latest_checkpoint = os.path.join(run_dir, f"checkpoint-{iters[-1]}")
print(f"Loading model from latest checkpoint: {latest_checkpoint}")
trainer.train(resume_from_checkpoint=latest_checkpoint)
else:
trainer.train()
trainer.save_model(os.path.join(run_dir, "final"))


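As a standalone restatement of the restart logic added above, a minimal sketch follows; the helper name latest_checkpoint is hypothetical and not part of this commit. Passing the returned path to trainer.train(resume_from_checkpoint=...) restores the model together with its optimizer and scheduler state.

import glob
import os


def latest_checkpoint(run_dir):
    """Return the checkpoint-<step> directory with the largest step, or None if none exist."""
    checkpoints = glob.glob(os.path.join(run_dir, "checkpoint-*"))
    if not checkpoints:
        return None
    # Compare by the numeric step suffix rather than lexicographically,
    # so that checkpoint-900 does not outrank checkpoint-10000.
    return max(checkpoints, key=lambda path: int(path.rsplit("-", 1)[-1]))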
23 changes: 23 additions & 0 deletions chemberta/train/utils.py
@@ -1,4 +1,5 @@
import json
import subprocess
from dataclasses import dataclass
from typing import List

@@ -9,6 +10,7 @@
RobertaForSequenceClassification,
RobertaTokenizerFast,
Trainer,
TrainerCallback,
)

from chemberta.utils.data_collators import multitask_data_collator
@@ -218,3 +220,24 @@ def create_train_test_split(dataset, frac_train):
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])
return train_dataset, eval_dataset


class AwsS3Callback(TrainerCallback):
def __init__(self, local_directory, s3_directory):
self.local_directory = local_directory
self.s3_directory = s3_directory

def on_evaluate(self, args, state, control, **kwargs):
# sync local and remote directories
subprocess.check_call(
[
"aws",
"s3",
"sync",
self.local_directory,
self.s3_directory,
"--acl",
"bucket-owner-full-control",
]
)
return
