Commit f238cb0

yiyixuxu and pcuenca authored
cpu_offload: remove all hooks before offload (huggingface#7448)
* add remove_all_hooks
* a few more fix and tests
* up
* Update src/diffusers/pipelines/pipeline_utils.py
  Co-authored-by: Pedro Cuenca <[email protected]>
* split tests
* add

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent d78acde commit f238cb0
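The gist of the change: re-enabling offloading on a pipeline that already has accelerate offload hooks installed now strips the old hooks first, so repeated or mixed calls to `enable_model_cpu_offload` / `enable_sequential_cpu_offload` leave the components in a clean state. A minimal usage sketch of the scenario this targets, assuming a CUDA machine with `diffusers` and `accelerate >= 0.17.0` installed; the checkpoint id and prompt below are placeholders, not part of the commit:

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id, for illustration only.
pipe = DiffusionPipeline.from_pretrained("some-org/some-pipeline", torch_dtype=torch.float16)

# The first call installs a CpuOffload hook on each component in the offload sequence.
pipe.enable_model_cpu_offload()
output_1 = pipe("a photo of an astronaut", num_inference_steps=2)

# With this commit, the second call runs remove_all_hooks() first, so hooks left over
# from the first call no longer linger on the modules; the two outputs should match.
pipe.enable_model_cpu_offload()
output_2 = pipe("a photo of an astronaut", num_inference_steps=2)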

2 files changed: +110 additions, −14 deletions

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 18 additions & 14 deletions
@@ -371,9 +371,7 @@ def module_is_sequentially_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
 
-            return hasattr(module, "_hf_hook") and not isinstance(
-                module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
-            )
+            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
 
         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -939,6 +937,16 @@ def _execution_device(self):
                     return torch.device(module._hf_hook.execution_device)
         return self.device
 
+    def remove_all_hooks(self):
+        r"""
+        Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
+        """
+        for _, model in self.components.items():
+            if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
+                is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
+                accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
+        self._all_hooks = []
+
     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -963,6 +971,8 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
+        self.remove_all_hooks()
+
         torch_device = torch.device(device)
         device_index = torch_device.index
 
@@ -979,15 +989,13 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         device = torch.device(f"{device_type}:{self._offload_gpu_id}")
         self._offload_device = device
 
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            device_mod = getattr(torch, self.device.type, None)
-            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+        self.to("cpu", silence_dtype_warnings=True)
+        device_mod = getattr(torch, device.type, None)
+        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+            device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
 
-        self._all_hooks = []
         hook = None
         for model_str in self.model_cpu_offload_seq.split("->"):
             model = all_model_components.pop(model_str, None)
@@ -1021,11 +1029,6 @@ def maybe_free_model_hooks(self):
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
 
-        for hook in self._all_hooks:
-            # offload model and remove hook from model
-            hook.offload()
-            hook.remove()
-
         # make sure the model is in the same state as before calling it
        self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
 
@@ -1048,6 +1051,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
             from accelerate import cpu_offload
         else:
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+        self.remove_all_hooks()
 
         torch_device = torch.device(device)
         device_index = torch_device.index
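For orientation, the core of the new `remove_all_hooks` above, restated as a standalone helper sketch (the function name here is made up for illustration): sequential CPU offload installs `AlignDevicesHook`s on submodules recursively, so only that case is removed with `recurse=True`, while model offload leaves a single `CpuOffload` hook on the top-level module.

import accelerate
import torch


def strip_offload_hooks(components: dict) -> None:
    """Remove accelerate offload hooks from every torch.nn.Module in `components`."""
    for model in components.values():
        if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
            # AlignDevicesHook => installed by sequential CPU offload on each submodule,
            # so remove recursively; CpuOffload (model offload) sits only on the top module.
            is_sequential = isinstance(model._hf_hook, accelerate.hooks.AlignDevicesHook)
            accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential)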

tests/pipelines/test_pipelines_common.py

Lines changed: 92 additions & 0 deletions
@@ -1107,6 +1107,98 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
             f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
         )
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+    )
+    def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
+        )
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        (
+            self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        )
+
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        (
+            self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
+        )
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+    )
+    def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running sequential offloading second time should not affect the inference results"
+        )
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        (
+            self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
+        )
+
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        (
+            self.assertTrue(
+                all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
+            ),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
+        )
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
