
Commit e959321

Move distributed from utils to training (pytorch#1461)
RdoubleA authored Sep 3, 2024
1 parent 4dc1cd0 commit e959321
Showing 17 changed files with 121 additions and 115 deletions.
18 changes: 18 additions & 0 deletions docs/source/api_ref_training.rst
@@ -38,6 +38,24 @@ Utilities for working in a reduced precision setting.
validate_expected_param_dtype
get_quantizer_mode

.. _dist_label:

Distributed
-----------

Utilities for enabling and working with distributed training.

.. autosummary::
:toctree: generated/
:nosignatures:

FSDPPolicyType
init_distributed
is_distributed
get_world_size_and_rank
get_full_finetune_fsdp_wrap_policy
lora_fsdp_wrap_policy


Memory Management
-----------------
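The Distributed helpers documented above now live under `torchtune.training`. Below is a minimal sketch of the new import path; it assumes a launcher such as `torchrun` has set the usual rank/world-size environment variables, and the argument-free `init_distributed()` call and the print statements are illustrative rather than taken from this commit.

```python
# Sketch: using the relocated distributed helpers from torchtune.training.
# Assumes a distributed launcher (e.g. torchrun) has set RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT. The no-argument init_distributed() call is an
# assumption, not something this diff shows.
from torchtune import training

if training.is_distributed():
    training.init_distributed()  # wraps process-group initialization

    world_size, rank = training.get_world_size_and_rank()
    if rank == 0:
        print(f"initialized distributed training on {world_size} ranks")
else:
    print("single-process run; distributed utilities not initialized")
```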
21 changes: 1 addition & 20 deletions docs/source/api_ref_utilities.rst
@@ -4,24 +4,6 @@ torchtune.utils

.. currentmodule:: torchtune.utils

.. _dist_label:

Distributed
-----------

Utilities for enabling and working with distributed training.

.. autosummary::
:toctree: generated/
:nosignatures:

FSDPPolicyType
init_distributed
is_distributed
get_world_size_and_rank
get_full_finetune_fsdp_wrap_policy
lora_fsdp_wrap_policy

.. _ac_label:

Memory Management
@@ -54,15 +36,14 @@ of your finetuning job.

.. _gen_label:


Miscellaneous
-------------

.. autosummary::
:toctree: generated/
:nosignatures:

get_logger
get_device
get_logger
generate
torch_version_ge
20 changes: 10 additions & 10 deletions recipes/dev/lora_finetune_fsdp2.py
@@ -120,7 +120,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
@@ -363,7 +363,7 @@ def _setup_model(
fully_shard(model, **fsdp_kwargs)

if lora_weights_state_dict:
lora_missing, lora_unexpected = utils.load_from_full_model_state_dict(
lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
model, lora_weights_state_dict, self._device, self._is_rank_zero
)
else:
@@ -384,7 +384,7 @@ def _setup_model(
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()

base_missing, base_unexpected = utils.load_from_full_model_state_dict(
base_missing, base_unexpected = training.load_from_full_model_state_dict(
model, base_model_state_dict, self._device, self._is_rank_zero
)
is_dora = False
@@ -404,7 +404,7 @@ def _setup_model(
lora_unexpected=lora_unexpected,
)
# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)
training.validate_no_params_on_meta_device(model)

if self._is_rank_zero:
log.info(
@@ -423,7 +423,7 @@ def _setup_optimizer(
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
utils.load_from_full_optimizer_state_dict(
training.load_from_full_optimizer_state_dict(
optimizer,
opt_state_dict,
self._device,
@@ -460,7 +460,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -518,13 +518,13 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = utils.get_full_model_state_dict(
cpu_state_dict = training.get_full_model_state_dict(
self._model,
self._is_rank_zero,
)

if intermediate_checkpoint:
opt_state_dict = utils.get_full_optimizer_state_dict(
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
)
@@ -588,7 +588,7 @@ def train(self) -> None:
# clean up before training begins
utils.cleanup_before_training()

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -706,7 +706,7 @@ def recipe_main(cfg: DictConfig) -> None:
- Parameters specified in config (see available configs through ``tune ls``)
- Overwritten by arguments from the command-line
"""
if not utils.is_distributed():
if not training.is_distributed():
raise RuntimeError(
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
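The same renames apply to the FSDP checkpoint helpers used in `save_checkpoint` above. A hedged sketch of that rank-zero consolidation pattern outside the recipe class follows; `model` and `optimizer` stand in for the recipe's sharded `self._model` and `self._optimizer`, and the positional arguments mirror the calls shown in the diff.

```python
# Sketch: rank-zero consolidation of FSDP state dicts after the move to
# torchtune.training. `model` and `optimizer` are placeholders for the
# recipe's already-sharded module and its optimizer.
import torch
from torchtune import training


def gather_checkpoint_state(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    save_optimizer: bool = True,
) -> dict:
    _, rank = training.get_world_size_and_rank()
    is_rank_zero = rank == 0

    # Full (unsharded) model weights are materialized on CPU only on rank 0,
    # which keeps GPU memory from spiking during checkpoint save.
    ckpt = {"model": training.get_full_model_state_dict(model, is_rank_zero)}

    if save_optimizer:
        ckpt["optimizer"] = training.get_full_optimizer_state_dict(
            optimizer, is_rank_zero
        )
    return ckpt
```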
22 changes: 11 additions & 11 deletions recipes/full_finetune_distributed.py
@@ -118,7 +118,7 @@ def __init__(self, cfg: DictConfig) -> None:

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Training cfg
@@ -415,7 +415,7 @@ def _is_layer_fqn(s: str) -> bool:
if custom_sharded_layers:
fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]

utils.shard_model(
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
@@ -430,12 +430,12 @@ def _is_layer_fqn(s: str) -> bool:

# This method will convert the full model state dict into a sharded state
# dict and load into the model
utils.load_from_full_model_state_dict(
training.load_from_full_model_state_dict(
model, model_state_dict, self._device, self._is_rank_zero, strict=True
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)
training.validate_no_params_on_meta_device(model)

if self._is_rank_zero:
log.info(
@@ -454,7 +454,7 @@ def _setup_optimizer(
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
utils.load_from_full_optimizer_state_dict(
training.load_from_full_optimizer_state_dict(
optimizer,
opt_state_dict,
self._device,
@@ -475,7 +475,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -529,13 +529,13 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = utils.get_full_model_state_dict(
cpu_state_dict = training.get_full_model_state_dict(
self._model,
self._is_rank_zero,
)

if intermediate_checkpoint:
opt_state_dict = utils.get_full_optimizer_state_dict(
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
)
@@ -574,7 +574,7 @@ def train(self) -> None:
# clean up before training begins
utils.cleanup_before_training()

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -721,7 +721,7 @@ def recipe_main(cfg: DictConfig) -> None:
- Parameters specified in config (see available configs through ``tune ls``)
- Overwritten by arguments from the command-line
"""
if not utils.is_distributed():
if not training.is_distributed():
raise RuntimeError(
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
@@ -730,7 +730,7 @@ def recipe_main(cfg: DictConfig) -> None:
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
utils.set_torch_num_threads()
training.set_torch_num_threads()

config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

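The guard at the bottom of `recipe_main` is also a reusable pattern: refuse to run outside a distributed launcher, and widen intra-op CPU parallelism only when FSDP CPU offload is enabled. A sketch under those assumptions is below; the `cfg` dictionary is a plain-dict stand-in for the recipe's OmegaConf `DictConfig`.

```python
# Sketch: launcher guard plus optional CPU-thread tuning, following the
# recipe_main changes above. `cfg` is a plain dict standing in for the
# recipe's DictConfig.
from torchtune import training


def check_launch(cfg: dict) -> None:
    if not training.is_distributed():
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
        )
    if cfg.get("fsdp_cpu_offload", False):
        # Use all available CPU cores for intra-op parallelism; the diff notes
        # roughly a 2x speedup for fused AdamW running on CPU.
        training.set_torch_num_threads()
```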
12 changes: 6 additions & 6 deletions recipes/lora_dpo_distributed.py
@@ -124,7 +124,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
@@ -359,7 +359,7 @@ def _setup_model(

model = FSDP(
module=model,
auto_wrap_policy=utils.lora_fsdp_wrap_policy(
auto_wrap_policy=training.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerSelfAttentionLayer}
),
sharding_strategy=self._fsdp_sharding_strategy,
@@ -379,7 +379,7 @@ def _setup_model(
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)
training.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
utils.set_activation_checkpointing(
@@ -437,7 +437,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -584,7 +584,7 @@ def train(self) -> None:
# clean up before training begins
utils.cleanup_before_training()

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -724,7 +724,7 @@ def recipe_main(cfg: DictConfig) -> None:
- Parameters specified in config (see available configs through ``tune ls``)
- Overwritten by arguments from the command-line
"""
if not utils.is_distributed():
if not training.is_distributed():
raise RuntimeError(
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
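The FSDP1-based LoRA recipes wrap the model the same way before and after this commit; only the namespace of the wrap policy changes. A hedged sketch of that wrapping call is shown below; the helper function, the `device_id` choice, and relying on FSDP defaults for the sharding strategy are assumptions for illustration, while the `lora_fsdp_wrap_policy` call and `modules_to_wrap` mirror the diff.

```python
# Sketch: wrapping a LoRA model with FSDP using the relocated wrap policy.
# Only the lora_fsdp_wrap_policy call and modules_to_wrap come from the diff;
# the rest (device_id, default sharding strategy) is illustrative.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchtune import modules, training


def wrap_lora_model(model: torch.nn.Module) -> FSDP:
    return FSDP(
        module=model,
        auto_wrap_policy=training.lora_fsdp_wrap_policy(
            modules_to_wrap={modules.TransformerSelfAttentionLayer}
        ),
        device_id=torch.cuda.current_device(),  # assumption: one GPU per rank
    )
```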
12 changes: 6 additions & 6 deletions recipes/lora_finetune_distributed.py
@@ -118,7 +118,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
@@ -452,7 +452,7 @@ def _setup_model(

model = FSDP(
module=model,
auto_wrap_policy=utils.lora_fsdp_wrap_policy(
auto_wrap_policy=training.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerSelfAttentionLayer}
),
sharding_strategy=self._fsdp_sharding_strategy,
@@ -472,7 +472,7 @@ def _setup_model(
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)
training.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
utils.set_activation_checkpointing(
@@ -530,7 +530,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -658,7 +658,7 @@ def train(self) -> None:
# clean up before training begins
utils.cleanup_before_training()

_, rank = utils.get_world_size_and_rank()
_, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -804,7 +804,7 @@ def recipe_main(cfg: DictConfig) -> None:
- Parameters specified in config (see available configs through ``tune ls``)
- Overwritten by arguments from the command-line
"""
if not utils.is_distributed():
if not training.is_distributed():
raise RuntimeError(
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
(Diff for the remaining changed files not shown.)
