
Commit 848ed80

Improve FSDP config usability (huggingface#2288)
* Improve FSDP config usability
* quality ✨
* Update tests
* fix cmd arg
* fix
* update docs
* address comments
1 parent ad957ce commit 848ed80

File tree

6 files changed, +46 -17 lines changed


docs/source/usage_guides/fsdp.md (+2, -2)

@@ -46,10 +46,10 @@ downcast_bf16: 'no'
 fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_backward_prefetch_policy: BACKWARD_PRE
-  fsdp_forward_prefetch: true
+  fsdp_forward_prefetch: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_offload_params: false
-  fsdp_sharding_strategy: 1
+  fsdp_sharding_strategy: FULL_SHARD
   fsdp_state_dict_type: SHARDED_STATE_DICT
   fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: BertLayer
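
For readers who configure FSDP from Python rather than through the YAML file, a rough programmatic counterpart of the updated example is sketched below. This is not part of the commit; it assumes the documented keys map onto `FullyShardedDataParallelPlugin` fields of the same name (minus the `fsdp_` prefix) and that the `forward_prefetch` field exists on the plugin.

```python
# Hedged sketch: programmatic counterpart of the updated YAML example (run under `accelerate launch`).
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # named strategy, matching the new `FULL_SHARD` value
    forward_prefetch=False,                         # matches the corrected `fsdp_forward_prefetch: false`
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```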

src/accelerate/commands/config/cluster.py (+2, -3)

@@ -327,8 +327,7 @@ def get_cluster_input():
         fsdp_config["fsdp_sharding_strategy"] = _ask_options(
             sharding_strategy_query,
             FSDP_SHARDING_STRATEGY,
-            lambda x: int(x) + 1,
-            default=1,
+            lambda x: FSDP_SHARDING_STRATEGY[int(x)],
         )
         fsdp_config["fsdp_offload_params"] = _ask_field(
             "Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
@@ -362,7 +361,7 @@ def get_cluster_input():
             default=100000000,
         )
         fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
-        fsdp_config["fsdp_backward_prefetch_policy"] = _ask_options(
+        fsdp_config["fsdp_backward_prefetch"] = _ask_options(
             fsdp_backward_prefetch_query,
             FSDP_BACKWARD_PREFETCH,
             lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
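
The effect of the new conversion lambda is that `accelerate config` now writes the strategy name into the config file instead of its 1-based integer. A standalone illustration (the list contents are an assumption mirroring `accelerate.utils.constants.FSDP_SHARDING_STRATEGY`, which is not shown in this diff):

```python
# Illustrative only: how the selected prompt index is converted before and after this commit.
FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]

old_convert = lambda x: int(x) + 1                      # previously wrote e.g. 1 into the config file
new_convert = lambda x: FSDP_SHARDING_STRATEGY[int(x)]  # now writes e.g. "FULL_SHARD"

print(old_convert("0"), new_convert("0"))  # -> 1 FULL_SHARD
```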

src/accelerate/commands/launch.py (+8, -2)

@@ -482,8 +482,8 @@ def launch_command_parser(subparsers=None):
     )
     fsdp_args.add_argument(
         "--fsdp_sharding_strategy",
-        type=int,
-        default=1,
+        type=str,
+        default="FULL_SHARD",
         help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).",
     )
     fsdp_args.add_argument(
@@ -503,6 +503,12 @@ def launch_command_parser(subparsers=None):
         "--fsdp_backward_prefetch_policy",
         default=None,
         type=str,
+        help="This argument is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use `fsdp_backward_prefetch` instead.",
+    )
+    fsdp_args.add_argument(
+        "--fsdp_backward_prefetch",
+        default=None,
+        type=str,
         help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
     )
     fsdp_args.add_argument(
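
A trimmed-down argparse sketch (not the real parser) of how the flags behave after this change: the sharding strategy is now a string defaulting to `FULL_SHARD`, and both the deprecated and the new backward-prefetch flags are accepted.

```python
# Minimal stand-in for the relevant part of launch_command_parser (illustrative, not the actual CLI code).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fsdp_sharding_strategy", type=str, default="FULL_SHARD")
parser.add_argument("--fsdp_backward_prefetch_policy", type=str, default=None)  # deprecated alias
parser.add_argument("--fsdp_backward_prefetch", type=str, default=None)

args = parser.parse_args(["--fsdp_sharding_strategy", "SHARD_GRAD_OP"])
print(args.fsdp_sharding_strategy)  # SHARD_GRAD_OP -- a name rather than an integer
```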

src/accelerate/utils/dataclasses.py (+13, -5)

@@ -30,7 +30,7 @@

 import torch

-from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
+from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
 from .imports import is_cuda_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
@@ -439,6 +439,7 @@ class CustomDtype(enum.Enum):
     r"""
     An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.
     """
+
     FP8 = "fp8"
     INT4 = "int4"

@@ -918,7 +919,7 @@ class FullyShardedDataParallelPlugin:
         },
     )
     limit_all_gathers: bool = field(
-        default=False,
+        default=True,
         metadata={
             "help": "If False, then FSDP allows the CPU thread to schedule all-gathers "
             "without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent "
@@ -929,9 +930,10 @@ class FullyShardedDataParallelPlugin:
     use_orig_params: bool = field(
         default=True,
         metadata={
-            "help": "If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. "
+            "help": "If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. "
             "Useful in cases such as parameter-efficient fine-tuning. "
-            "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)"
+            "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). "
+            "This also enables to have different optimizer param groups. This should be `True` when creating optimizer object before preparing/wrapping the model with FSDP."
         },
     )
     param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(
@@ -969,7 +971,13 @@ def __post_init__(self):

         prefix = "FSDP_"
         if self.sharding_strategy is None:
-            self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1)))
+            sharding_strategy = os.environ.get(prefix + "SHARDING_STRATEGY", "FULL_SHARD")
+            sharding_strategy = (
+                FSDP_SHARDING_STRATEGY.index(sharding_strategy) + 1
+                if not sharding_strategy.isdigit()
+                else int(sharding_strategy)
+            )
+            self.sharding_strategy = ShardingStrategy(sharding_strategy)

         if self.cpu_offload is None:
             if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
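
The new `__post_init__` logic accepts either a strategy name or the old 1-based integer from the `FSDP_SHARDING_STRATEGY` environment variable. A self-contained sketch of that resolution follows; the constant list is an assumption (it mirrors `accelerate.utils.constants`), and running it requires a PyTorch build with FSDP available.

```python
# Stand-alone illustration of the name-or-integer resolution added to __post_init__.
import os

from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]


def resolve_sharding_strategy() -> ShardingStrategy:
    raw = os.environ.get("FSDP_SHARDING_STRATEGY", "FULL_SHARD")
    value = FSDP_SHARDING_STRATEGY.index(raw) + 1 if not raw.isdigit() else int(raw)
    return ShardingStrategy(value)


os.environ["FSDP_SHARDING_STRATEGY"] = "SHARD_GRAD_OP"
print(resolve_sharding_strategy())  # ShardingStrategy.SHARD_GRAD_OP

os.environ["FSDP_SHARDING_STRATEGY"] = "2"  # the old numeric form keeps working
print(resolve_sharding_strategy())  # ShardingStrategy.SHARD_GRAD_OP
```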

src/accelerate/utils/launch.py (+9, -1)

@@ -15,6 +15,7 @@
 import argparse
 import os
 import sys
+import warnings
 from ast import literal_eval
 from typing import Any, Dict, List, Tuple

@@ -188,7 +189,14 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
         if args.fsdp_transformer_layer_cls_to_wrap is not None:
             current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
         if args.fsdp_backward_prefetch_policy is not None:
-            current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy)
+            warnings.warn(
+                "`fsdp_backward_prefetch_policy` is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use"
+                " `fsdp_backward_prefetch` instead",
+                FutureWarning,
+            )
+            args.fsdp_backward_prefetch = args.fsdp_backward_prefetch_policy
+        if args.fsdp_backward_prefetch is not None:
+            current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch)
         if args.fsdp_state_dict_type is not None:
             current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
         current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()

tests/fsdp/test_fsdp.py (+12, -4)

@@ -69,10 +69,18 @@ def setUp(self):
     def test_sharding_strategy(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

+        # check that giving enums works fine
         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
             env = self.dist_env.copy()
             env["FSDP_SHARDING_STRATEGY"] = f"{i + 1}"
-            env["FSDP_SHARDING_STRATEGY_NAME"] = strategy
+            with mockenv_context(**env):
+                fsdp_plugin = FullyShardedDataParallelPlugin()
+                self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1))
+
+        # check that giving names works fine
+        for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
+            env = self.dist_env.copy()
+            env["FSDP_SHARDING_STRATEGY"] = strategy
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1))
@@ -201,7 +209,7 @@ def test_performance(self):
             cmd_config = cmd.copy()
             for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
                 if strategy.lower() in config:
-                    cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+                    cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
                     break

             if "fp32" in config:
@@ -247,7 +255,7 @@ def test_checkpointing(self):

         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
             cmd_config = cmd.copy()
-            cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+            cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
             if strategy != "FULL_SHARD":
                 continue
             state_dict_config_index = len(cmd_config)
@@ -301,7 +309,7 @@ def test_peak_memory_usage(self):
             cmd_config.extend(["--use_fsdp"])
             for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
                 if strategy.lower() in spec:
-                    cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+                    cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
                     break

             if "cpu_offload" in spec:
