Adding optional trial argument to model_init (huggingface#7759)
* Adding optional trial argument to model_init

Co-authored-by: Sylvain Gugger <[email protected]>
madlag and sgugger authored Oct 13, 2020
1 parent 7e73c12 commit 2d6e2ad
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/transformers/trainer.py
@@ -173,6 +173,9 @@ class Trainer:
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
A function that instantiates the model to be used. If provided, each call to
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
The function may take zero arguments, or a single argument containing the optuna/Ray Tune trial object, so that it can
choose different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
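
For illustration, here is a minimal sketch (not part of the diff) of a model_init that uses the optional trial argument described above. The model class and the hyperparameter names passed to trial.suggest_* are assumptions chosen for the example, shown for the optuna backend, where the trial object exposes the suggest_* methods.

from transformers import BertConfig, BertForSequenceClassification

def model_init(trial=None):
    # Defaults used when no hyperparameter search is running (trial is None).
    num_layers, dropout = 12, 0.1
    if trial is not None:
        # With the optuna backend, `trial` is an optuna.Trial; the parameter
        # names below are illustrative choices, not fixed by the Trainer API.
        num_layers = trial.suggest_int("num_hidden_layers", 4, 12)
        dropout = trial.suggest_float("hidden_dropout_prob", 0.0, 0.3)
    config = BertConfig(num_hidden_layers=num_layers, hidden_dropout_prob=dropout)
    return BertForSequenceClassification(config)

With the Ray Tune backend the trial object is instead a dict of sampled hyperparameters, so a real implementation would branch on its type or read keys directly.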
@@ -212,15 +215,16 @@ def __init__(
assert (
model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
self.model_init = model_init
if model is None and model_init is not None:
model = model_init()
model = self.call_model_init()
self.model = model.to(args.device) if model is not None else None
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
self.model_init = model_init

self.compute_metrics = compute_metrics
self.optimizer, self.lr_scheduler = optimizers
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
@@ -532,6 +536,17 @@ def _tune_save_checkpoint(self):
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

def call_model_init(self, trial=None):
model_init_argcount = len(inspect.signature(self.model_init).parameters)
if model_init_argcount == 0:
model = self.model_init()
elif model_init_argcount == 1:
model = self.model_init(trial)
else:
raise RuntimeError("model_init should have 0 or 1 argument.")

return model

def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
"""
Main training entry point.
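
The call_model_init helper added above dispatches on the number of parameters in the user's model_init. A small self-contained demonstration of that inspection logic, using toy functions rather than real models:

import inspect

def zero_arg_init():
    return "model built without a trial"

def one_arg_init(trial):
    return f"model built from trial {trial!r}"

for fn in (zero_arg_init, one_arg_init):
    argcount = len(inspect.signature(fn).parameters)
    # Same dispatch as call_model_init: 0 parameters -> fn(), 1 parameter -> fn(trial).
    model = fn() if argcount == 0 else fn("trial-0")
    print(model)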
@@ -550,7 +565,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
set_seed(self.args.seed)
model = self.model_init()

model = self.call_model_init(trial)

self.model = model.to(self.args.device)

# Reinitializes optimizer and scheduler
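
End to end, the trial reaches model_init through Trainer.hyperparameter_search, which re-runs train() once per trial. A hedged sketch of that usage, assuming the one-argument model_init from the first example and pre-built train_dataset / eval_dataset objects (placeholders here), with illustrative argument values:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="hp_search", num_train_epochs=1)

trainer = Trainer(
    args=training_args,
    model_init=model_init,        # one-argument version sketched earlier
    train_dataset=train_dataset,  # placeholder: assumed to be prepared elsewhere
    eval_dataset=eval_dataset,    # placeholder: assumed to be prepared elsewhere
)

# With the optuna backend, each trial triggers train(), which now calls
# self.call_model_init(trial) so model_init can read the trial's suggestions.
best_run = trainer.hyperparameter_search(n_trials=10, direction="minimize")
print(best_run.hyperparameters)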
