From f35cbd1f023db2c7a4972388df3a34274cca7939 Mon Sep 17 00:00:00 2001
From: Alex Brooks
Date: Wed, 31 Jul 2024 14:03:23 -0600
Subject: [PATCH] Enable Unwrapping for Model State Dicts (FSDP) (#2959)

Signed-off-by: Alex-Brooks
---
 src/accelerate/accelerator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 308861589d4..9d674c6394c 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -3289,6 +3289,8 @@ def get_state_dict(self, model, unwrap=True):
             from torch.distributed.fsdp import FullStateDictConfig, StateDictType
             from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
+            if unwrap:
+                model = self.unwrap_model(model)
             full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
             with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
                 state_dict = model.state_dict()