
Commit

accelerate deepspeed and gradient accumulation integrate (huggingface#23236)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixes

* fix saving

* shift torch dynamo handling to accelerate

* shift deepspeed integration and save & load utils to accelerate

* fix accelerate launcher support

* oops

* fix 🐛

* save ckpt fix

* Trigger CI

* nasty 🐛 😅

* as deepspeed needs grad_acc fixes, transfer grad_acc to accelerate

* make tests happy

* quality ✨

* loss tracked needs to account for grad_acc

* fixing the deepspeed tests

* quality ✨

* 😅😅😅

* tests 😡

* quality ✨

* Trigger CI

* resolve comments and fix the issue with the previous merge from branch

* Trigger CI

* accelerate took over deepspeed integration

---------

Co-authored-by: Stas Bekman <[email protected]>
pacman100 and stas00 authored May 31, 2023
1 parent 88f50a1 commit a73b1d5
Showing 8 changed files with 167 additions and 164 deletions.
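
The common thread in the diff below is that the Trainer stops driving `deepspeed.initialize` and gradient accumulation itself and hands both over to the `Accelerator` it now owns. As orientation, the pattern being adopted looks roughly like the sketch below; it is not the Trainer's actual wiring, the toy model, config path and loop are illustrative, and a real run needs DeepSpeed installed plus an `accelerate launch` setup.

```python
# Minimal sketch: one Accelerator owns both the DeepSpeed engine and
# gradient accumulation, instead of the Trainer managing them separately.
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Illustrative config path; a real run points at an actual DeepSpeed JSON config.
ds_plugin = DeepSpeedPlugin(hf_ds_config="ds_config.json")
accelerator = Accelerator(gradient_accumulation_steps=4, deepspeed_plugin=ds_plugin)

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(
    [(torch.randn(16), torch.tensor(0)) for _ in range(32)], batch_size=8
)

# With a DeepSpeed plugin configured, prepare() is where deepspeed.initialize runs.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, labels in dataloader:
    # accumulate() only syncs gradients every `gradient_accumulation_steps` batches,
    # which is also why the tracked loss now has to account for grad_acc.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```
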
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -41,7 +41,7 @@ body:
Integrations:
- deepspeed: HF Trainer: @stas00, Accelerate: @pacman100
- deepspeed: HF Trainer/Accelerate: @pacman100
- ray/raytune: @richardliaw, @amogkam
- Big Model Inference: @sgugger @muellerzr
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -55,7 +55,7 @@ Library:
Integrations:
- deepspeed: HF Trainer: @stas00, Accelerate: @pacman100
- deepspeed: HF Trainer/Accelerate: @pacman100
- ray/raytune: @richardliaw, @amogkam
Documentation: @sgugger, @stevhliu and @MKhalusova
95 changes: 42 additions & 53 deletions src/transformers/deepspeed.py
@@ -17,7 +17,6 @@

import importlib.util
import weakref
from copy import deepcopy
from functools import partialmethod

from .dependency_versions_check import dep_version_check
@@ -256,24 +255,26 @@ def deepspeed_config():
return None


def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps):
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
"""
A convenience wrapper that deals with optimizer and lr scheduler configuration.
"""
from accelerate.utils import DummyOptim, DummyScheduler

config = hf_deepspeed_config.config

# Optimizer + Scheduler
# Currently supported combos:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Yes
# 3. DS scheduler + HF optimizer: Yes
# 4. HF scheduler + DS optimizer: Yes
# 4. HF scheduler + DS optimizer: No
#
# Unless Offload is enabled in which case it's:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Mostly*
# 3. DS scheduler + HF optimizer: Mostly*
# 4. HF scheduler + DS optimizer: Yes
# 4. HF scheduler + DS optimizer: No
#
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)

@@ -284,6 +285,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
"--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
"Only one optimizer can be configured."
)
optimizer = DummyOptim(params=model_parameters)
else:
if hf_deepspeed_config.is_offload():
logger.info(
@@ -297,21 +299,21 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
# To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
config["zero_allow_untested_optimizer"] = True

def _lr_scheduler_callable(optimizer):
return trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

lr_scheduler = None
if "scheduler" not in config:
if optimizer is None:
# Optimizer is not available, so use callable to defer lr_scheduler creation to DS init
lr_scheduler = _lr_scheduler_callable
else:
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
if "scheduler" in config:
lr_scheduler = DummyScheduler(optimizer)
else:
if isinstance(optimizer, DummyOptim):
raise ValueError(
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
"Please configure a scheduler in the DeepSpeed config."
)
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

return optimizer, lr_scheduler
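
An aside on the rewritten `deepspeed_optim_sched` above: whenever the DeepSpeed config owns the optimizer and/or scheduler, the function now returns `DummyOptim`/`DummyScheduler` placeholders so that the later `accelerator.prepare` call lets DeepSpeed construct the real objects; the one combination that is rejected is an HF scheduler paired with a DS optimizer. A condensed, hedged sketch of that decision table (not the literal source; `trainer`, `config` and the other names are as in the function above, and the adafactor check is omitted):

```python
from accelerate.utils import DummyOptim, DummyScheduler


def pick_optim_sched(trainer, config, model_parameters, num_training_steps):
    """Condensed sketch of which objects end up handed to accelerator.prepare()."""
    if "optimizer" in config:
        # DeepSpeed will build the real optimizer from its config later on.
        optimizer = DummyOptim(params=model_parameters)
    else:
        # HF-side optimizer (offload caveats apply, see the comment block above).
        optimizer = trainer.create_optimizer()

    if "scheduler" in config:
        lr_scheduler = DummyScheduler(optimizer)
    elif isinstance(optimizer, DummyOptim):
        # HF scheduler + DS optimizer is the unsupported combination.
        raise ValueError("Found `optimizer` in the DeepSpeed config but no `scheduler`.")
    else:
        lr_scheduler = trainer.create_scheduler(
            num_training_steps=num_training_steps, optimizer=optimizer
        )
    return optimizer, lr_scheduler
```
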


def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
def deepspeed_init(trainer, num_training_steps, inference=False):
"""
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
@@ -323,28 +325,22 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
inference: launch in inference mode (no optimizer and no lr scheduler)
Returns: model, optimizer, lr_scheduler
Returns: optimizer, lr_scheduler
We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612
"""
import deepspeed
from deepspeed.utils import logger as ds_logger

model = trainer.model
args = trainer.args

if hasattr(trainer, "hf_deepspeed_config_orig"):
hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
else:
hf_deepspeed_config = args.hf_deepspeed_config
trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)
hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config

# resume config update - some bits like `model` and `num_training_steps` only become available during train
hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
config = hf_deepspeed_config.config

# set the Deepspeed log level consistent with the Trainer
ds_logger.setLevel(args.get_process_log_level())
@@ -361,40 +357,33 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
model_parameters = None
else:
trainer.optimizer = None # important for when deepspeed_init is used as re-init
optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer, lr_scheduler = deepspeed_optim_sched(
trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
)

# keep for quick debug:
# from pprint import pprint; pprint(config)

kwargs = {
"model": model,
"model_parameters": model_parameters,
"config_params": config,
"optimizer": optimizer,
"lr_scheduler": lr_scheduler,
}

deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)

if resume_from_checkpoint is not None:
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
# a resume from a checkpoint and not just a local pretrained weight. So we check here if the
# path contains what looks like a deepspeed checkpoint
import glob

deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))

if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = deepspeed_engine.load_checkpoint(
resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
else:
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
return optimizer, lr_scheduler


return deepspeed_engine, optimizer, lr_scheduler
def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
# a resume from a checkpoint and not just a local pretrained weight. So we check here if the
# path contains what looks like a deepspeed checkpoint
import glob

deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))

if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {checkpoint_path}")
# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = deepspeed_engine.load_checkpoint(
checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
else:
raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
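
Taken together: `deepspeed_init` now only finalizes the config and builds the (possibly dummy) optimizer and scheduler, the actual `deepspeed.initialize` call happens inside `accelerator.prepare`, and resuming is split out into `deepspeed_load_checkpoint`. A hedged sketch of the resulting call order from a Trainer-like caller (an illustrative helper, not the Trainer's literal code, which lives in the Trainer changes that are part of this commit):

```python
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint


def prepare_deepspeed_training(trainer, num_training_steps, resume_dir=None):
    """Sketch of the new call order under DeepSpeed (illustrative names)."""
    # 1. Build (possibly dummy) optimizer/scheduler from the DeepSpeed config;
    #    deepspeed.initialize is no longer called in here.
    optimizer, lr_scheduler = deepspeed_init(trainer, num_training_steps)

    # 2. accelerator.prepare is where the DeepSpeed engine is now created.
    model, optimizer, lr_scheduler = trainer.accelerator.prepare(
        trainer.model, optimizer, lr_scheduler
    )

    # 3. Optionally restore engine/optimizer/scheduler state from a global_step* dir.
    if resume_dir is not None:
        deepspeed_load_checkpoint(model, resume_dir)

    return model, optimizer, lr_scheduler
```
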
7 changes: 7 additions & 0 deletions src/transformers/testing_utils.py
@@ -112,6 +112,10 @@
)


if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState


SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
@@ -1331,6 +1335,9 @@ def tearDown(self):
for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True)
self.teardown_tmp_dirs = []
if is_accelerate_available():
AcceleratorState._reset_state()
PartialState._reset_state()


def mockenv(**kwargs):
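
The `testing_utils.py` change is needed because `AcceleratorState` and `PartialState` are process-wide singletons: once one test configures them (for instance with a DeepSpeed plugin), later tests in the same process silently inherit that configuration unless the state is reset in tearDown. A small illustrative sketch of the behaviour being guarded against (assumes `accelerate` is installed; run on a single CPU process):

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState, PartialState

# A first "test" creates an Accelerator, which populates the shared singleton state.
_ = Accelerator()
print(AcceleratorState().distributed_type)  # state now exists for the whole process

# Without a reset, a second "test" would silently reuse the first one's configuration,
# so the tearDown hook above now clears it:
AcceleratorState._reset_state()
PartialState._reset_state()
```
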
