@@ -1107,6 +1107,98 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
        )

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+    )
+    def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
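+        # First pass: run inference with model-level CPU offloading enabled.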
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
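+        # Second pass: enabling offloading again must not change the outputs.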
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running CPU offloading a second time should not affect the inference results"
+        )
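+        # Every module managed by offloading should have been returned to the CPU.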
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        )
+
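+        # Model offloading should have installed accelerate's CpuOffload hook on each module.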
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        self.assertTrue(
+            all(isinstance(v._hf_hook, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v._hf_hook, accelerate.hooks.CpuOffload)]}",
+        )
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+    )
+    def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
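+        # First pass: run inference with sequential (module-level) CPU offloading enabled.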
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
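+        # Second pass: re-enabling sequential offloading must not change the outputs.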
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running sequential offloading a second time should not affect the inference results"
+        )
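+        # With sequential offloading, offloaded weights live on the meta device between uses.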
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
+        )
+
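+        # Sequential offloading should have installed accelerate's AlignDevicesHook on each module.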
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        self.assertTrue(
+            all(isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook)]}",
+        )
+
    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",