forked from huggingface/accelerate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_big_modeling.py
486 lines (393 loc) · 18.4 KB
/
test_big_modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from tempfile import TemporaryDirectory
import torch
import torch.nn as nn
from accelerate.big_modeling import (
cpu_offload,
disk_offload,
dispatch_model,
init_empty_weights,
init_on_device,
load_checkpoint_and_dispatch,
)
from accelerate.hooks import remove_hook_from_submodules
from accelerate.test_utils import require_cuda, require_mps, require_multi_gpu, require_torch_min_version, slow
from accelerate.utils import offload_state_dict
from transformers import AutoModelForCausalLM, AutoTokenizer
class ModelForTest(nn.Module):
    """Tiny Linear(3->4) -> BatchNorm1d(4) -> Linear(4->5) network used by the offload tests."""

    def __init__(self):
        super().__init__()
        # Attribute names double as ``device_map`` keys in the tests, and the
        # registration order fixes parameter-initialization order and ``state_dict`` key order.
        self.linear1 = nn.Linear(3, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 5)

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.batchnorm(hidden)
        return self.linear2(hidden)
class ModelForTestTiedWeights(nn.Module):
    """Linear(4->4) -> BatchNorm1d(4) -> Linear(4->4) network.

    Both linear layers share the same square shape, so a test can tie their weights
    (``model.linear1.weight = model.linear2.weight``).
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        # Apply the three stages in order, threading the activation through each one.
        for stage in (self.linear1, self.batchnorm, self.linear2):
            x = stage(x)
        return x
class BiggerModelForTest(nn.Module):
    """Five-stage test network: Linear(3->4), Linear(4->5), BatchNorm1d(5), Linear(5->6), Linear(6->5)."""

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is: it fixes parameter-initialization order and
        # the order of keys in ``state_dict()``; the names are used as ``device_map`` keys.
        self.linear1 = nn.Linear(3, 4)
        self.linear2 = nn.Linear(4, 5)
        self.batchnorm = nn.BatchNorm1d(5)
        self.linear3 = nn.Linear(5, 6)
        self.linear4 = nn.Linear(6, 5)

    def forward(self, x):
        out = self.linear2(self.linear1(x))
        out = self.batchnorm(out)
        out = self.linear3(out)
        return self.linear4(out)
# To test preload_module_classes
class ModuleWithUnusedSubModules(nn.Module):
    """Affine layer whose ``forward`` reads ``self.linear``'s parameters without calling ``self.linear``.

    Because ``self.linear`` itself is never invoked, any hook registered on that
    submodule does not fire during the forward pass.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        layer = self.linear
        projected = x @ layer.weight.t()
        return projected + layer.bias
class ModelWithUnusedSubModulesForTest(nn.Module):
    """3 -> 4 -> 5 (BatchNorm1d) -> 6 -> 5 network built from ``ModuleWithUnusedSubModules`` layers.

    Used to exercise ``preload_module_classes`` in the offload/dispatch tests.
    """

    def __init__(self):
        super().__init__()
        # Same registration order as before: it fixes initialization and ``state_dict`` key order.
        self.linear1 = ModuleWithUnusedSubModules(3, 4)
        self.linear2 = ModuleWithUnusedSubModules(4, 5)
        self.batchnorm = nn.BatchNorm1d(5)
        self.linear3 = ModuleWithUnusedSubModules(5, 6)
        self.linear4 = ModuleWithUnusedSubModules(6, 5)

    def forward(self, x):
        hidden = self.batchnorm(self.linear2(self.linear1(x)))
        return self.linear4(self.linear3(hidden))
@require_torch_min_version(version="1.9.0")
class BigModelingTester(unittest.TestCase):
    """Tests for the big-model utilities in ``accelerate.big_modeling``.

    Covers empty/on-device initialization, CPU and disk offload, ``dispatch_model`` and
    ``load_checkpoint_and_dispatch``. This includes models whose ``forward`` reads submodule
    parameters directly, which are exercised through ``preload_module_classes``.
    """

    def test_init_empty_weights(self):
        """init_empty_weights puts parameters (and optionally buffers) on the meta device."""
        # base use
        with init_empty_weights():
            module = nn.Linear(4, 5)
        self.assertEqual(module.weight.device, torch.device("meta"))

        # base use with buffers, they are not touched
        with init_empty_weights():
            module = nn.BatchNorm1d(4)
        self.assertEqual(module.weight.device, torch.device("meta"))
        self.assertEqual(module.running_mean.device, torch.device("cpu"))

        # Use with include_buffers=True
        with init_empty_weights(include_buffers=True):
            module = nn.BatchNorm1d(4)
        self.assertEqual(module.weight.device, torch.device("meta"))
        self.assertEqual(module.running_mean.device, torch.device("meta"))

        # Double check we didn't break PyTorch
        module = nn.BatchNorm1d(4)
        self.assertEqual(module.weight.device, torch.device("cpu"))
        self.assertEqual(module.running_mean.device, torch.device("cpu"))

    def test_init_empty_weights_very_large_model(self):
        """Building a huge model under init_empty_weights must not allocate real memory."""
        # This is a 100 billion parameters model.
        with init_empty_weights():
            _ = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])

    @require_cuda
    def test_init_on_device_cuda(self):
        """init_on_device creates parameters directly on the requested CUDA device."""
        device = torch.device("cuda:0")
        with init_on_device(device):
            model = nn.Linear(10, 10)
        self.assertEqual(model.weight.device, device)
        self.assertEqual(model.bias.device, device)

        # Modules created after leaving the context manager use PyTorch's default device again.
        model = nn.Linear(10, 10)
        self.assertEqual(model.weight.device, torch.device("cpu"))

    @require_mps
    def test_init_on_device_mps(self):
        """init_on_device creates parameters directly on the requested MPS device."""
        device = torch.device("mps:0")
        with init_on_device(device):
            model = nn.Linear(10, 10)
        self.assertEqual(model.weight.device, device)
        self.assertEqual(model.bias.device, device)

        # Modules created after leaving the context manager use PyTorch's default device again.
        model = nn.Linear(10, 10)
        self.assertEqual(model.weight.device, torch.device("cpu"))

    def test_cpu_offload(self):
        """cpu_offload must reproduce eager outputs, with and without offloaded buffers."""
        model = ModelForTest()
        x = torch.randn(2, 3)
        expected = model(x)

        device = torch.device(0 if torch.cuda.is_available() else "cpu")

        cpu_offload(model, execution_device=device)
        output = model(x)
        self.assertTrue(
            torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
        )

        # Clean up for next test.
        remove_hook_from_submodules(model)

        cpu_offload(model, execution_device=device, offload_buffers=True)
        output = model(x)
        self.assertTrue(
            torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
        )

    def test_cpu_offload_with_unused_submodules(self):
        """cpu_offload with preload_module_classes handles modules that bypass their submodules."""
        model = ModelWithUnusedSubModulesForTest()
        x = torch.randn(2, 3)
        expected = model(x)

        device = torch.device(0 if torch.cuda.is_available() else "cpu")

        cpu_offload(model, execution_device=device, preload_module_classes=["ModuleWithUnusedSubModules"])
        output = model(x)
        self.assertTrue(
            torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
        )

        # Clean up for next test.
        remove_hook_from_submodules(model)

        cpu_offload(
            model,
            execution_device=device,
            offload_buffers=True,
            preload_module_classes=["ModuleWithUnusedSubModules"],
        )
        output = model(x)
        self.assertTrue(
            torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
        )

    @slow
    @require_cuda
    def test_cpu_offload_gpt2(self):
        """End-to-end generation with a CPU-offloaded GPT-2."""
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        inputs = tokenizer("Hello world! My name is", return_tensors="pt").to(0)

        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        cpu_offload(gpt2, execution_device=0)
        outputs = gpt2.generate(inputs["input_ids"])
        self.assertEqual(
            tokenizer.decode(outputs[0].tolist()),
            "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
        )

    def test_disk_offload(self):
        """disk_offload must reproduce eager outputs, with and without offloaded buffers."""
        model = ModelForTest()
        x = torch.randn(2, 3)
        expected = model(x)

        device = torch.device(0 if torch.cuda.is_available() else "cpu")

        with TemporaryDirectory() as tmp_dir:
            disk_offload(model, tmp_dir, execution_device=device)
            output = model(x)
            self.assertTrue(
                torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
            )

            # Clean up for next test. This runs while the offload folder still exists,
            # because detaching the hooks may need to read the offloaded weights back.
            remove_hook_from_submodules(model)

        with TemporaryDirectory() as tmp_dir:
            disk_offload(model, tmp_dir, execution_device=device, offload_buffers=True)
            output = model(x)
            self.assertTrue(
                torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
            )

    def test_disk_offload_with_unused_submodules(self):
        """disk_offload with preload_module_classes handles modules that bypass their submodules."""
        model = ModelWithUnusedSubModulesForTest()
        x = torch.randn(2, 3)
        expected = model(x)

        device = torch.device(0 if torch.cuda.is_available() else "cpu")

        with TemporaryDirectory() as tmp_dir:
            disk_offload(
                model, tmp_dir, execution_device=device, preload_module_classes=["ModuleWithUnusedSubModules"]
            )
            output = model(x)
            self.assertTrue(
                torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
            )

            # Clean up for next test. This runs while the offload folder still exists,
            # because detaching the hooks may need to read the offloaded weights back.
            remove_hook_from_submodules(model)

        with TemporaryDirectory() as tmp_dir:
            disk_offload(
                model,
                tmp_dir,
                execution_device=device,
                offload_buffers=True,
                preload_module_classes=["ModuleWithUnusedSubModules"],
            )
            output = model(x)
            self.assertTrue(
                torch.allclose(expected, output.cpu(), 1e-4, 1e-5), msg=f"Expected: {expected}\nActual: {output.cpu()}"
            )

    @slow
    @require_cuda
    def test_disk_offload_gpt2(self):
        """End-to-end generation with a disk-offloaded GPT-2."""
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        inputs = tokenizer("Hello world! My name is", return_tensors="pt").to(0)

        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        with TemporaryDirectory() as tmp_dir:
            disk_offload(gpt2, tmp_dir, execution_device=0)
            outputs = gpt2.generate(inputs["input_ids"])
            self.assertEqual(
                tokenizer.decode(outputs[0].tolist()),
                "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
            )

    @require_cuda
    def test_dispatch_model(self):
        """dispatch_model over a disk/CPU/GPU device_map must reproduce eager outputs."""
        model = ModelForTest()
        device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            dispatch_model(model, device_map, offload_dir=tmp_dir)
            output = model(x)
            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @require_cuda
    def test_dispatch_model_tied_weights(self):
        """dispatch_model must keep tied parameters tied."""
        model = ModelForTestTiedWeights()
        model.linear1.weight = model.linear2.weight
        device_map = {"linear1": 0, "batchnorm": 0, "linear2": 0}

        dispatch_model(model, device_map)
        self.assertIs(model.linear2.weight, model.linear1.weight)

    @require_multi_gpu
    def test_dispatch_model_multi_gpu(self):
        """dispatch_model across CPU, disk and two GPUs must reproduce eager outputs."""
        model = BiggerModelForTest()
        device_map = {"linear1": "cpu", "linear2": "disk", "batchnorm": "cpu", "linear3": 0, "linear4": 1}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            dispatch_model(model, device_map, offload_dir=tmp_dir)
            output = model(x)
            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @slow
    @require_multi_gpu
    def test_dispatch_model_gpt2_on_two_gpus(self):
        """GPT-2 generation split over two GPUs, then with CPU offload, then with CPU and disk offload."""
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        inputs = tokenizer("Hello world! My name is", return_tensors="pt").to(0)

        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        # Dispatch on GPUs 0 and 1
        device_map = {
            "transformer.wte": 0,
            "transformer.wpe": 0,
            "transformer.ln_f": 1,
            "lm_head": 0,
        }
        for i in range(12):
            device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1

        gpt2 = dispatch_model(gpt2, device_map)
        outputs = gpt2.generate(inputs["input_ids"])
        self.assertEqual(
            tokenizer.decode(outputs[0].tolist()),
            "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
        )

        # Dispatch with a bit of CPU offload
        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        for i in range(4):
            device_map[f"transformer.h.{i}"] = "cpu"
        gpt2 = dispatch_model(gpt2, device_map)
        outputs = gpt2.generate(inputs["input_ids"])
        self.assertEqual(
            tokenizer.decode(outputs[0].tolist()),
            "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
        )

        # Dispatch with a bit of CPU and disk offload
        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        for i in range(2):
            device_map[f"transformer.h.{i}"] = "disk"

        with TemporaryDirectory() as tmp_dir:
            # Pre-save only the weights of the disk-mapped blocks h.0 and h.1. The match uses
            # the full "transformer.h.<i>." prefix because a bare substring test for
            # "transformer.h.1" would also pick up h.10 and h.11, which are mapped to GPU 1.
            disk_prefixes = tuple(f"transformer.h.{i}." for i in range(2))
            state_dict = {k: p for k, p in gpt2.state_dict().items() if k.startswith(disk_prefixes)}
            offload_state_dict(tmp_dir, state_dict)
            gpt2 = dispatch_model(gpt2, device_map, offload_dir=tmp_dir)
            outputs = gpt2.generate(inputs["input_ids"])
            self.assertEqual(
                tokenizer.decode(outputs[0].tolist()),
                "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
            )

    @require_cuda
    def test_dispatch_model_with_unused_submodules(self):
        """dispatch_model with preload_module_classes on a single GPU must reproduce eager outputs."""
        model = ModelWithUnusedSubModulesForTest()
        device_map = {"linear1": "cpu", "linear2": "disk", "batchnorm": "cpu", "linear3": 0, "linear4": 0}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            dispatch_model(
                model, device_map, offload_dir=tmp_dir, preload_module_classes=["ModuleWithUnusedSubModules"]
            )
            output = model(x)
            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @require_multi_gpu
    def test_dispatch_model_with_unused_submodules_multi_gpu(self):
        """dispatch_model with preload_module_classes across two GPUs must reproduce eager outputs."""
        model = ModelWithUnusedSubModulesForTest()
        device_map = {"linear1": "cpu", "linear2": "disk", "batchnorm": "cpu", "linear3": 0, "linear4": 1}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            dispatch_model(
                model, device_map, offload_dir=tmp_dir, preload_module_classes=["ModuleWithUnusedSubModules"]
            )
            output = model(x)
            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @require_cuda
    def test_load_checkpoint_and_dispatch(self):
        """load_checkpoint_and_dispatch places weights per device_map and reproduces eager outputs."""
        model = ModelForTest()
        device_map = {"linear1": "cpu", "batchnorm": "cpu", "linear2": 0}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            checkpoint = os.path.join(tmp_dir, "pt_model.bin")
            torch.save(model.state_dict(), checkpoint)

            new_model = ModelForTest()
            new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=device_map)

        # CPU-offloaded weights are on the meta device while waiting for the forward pass.
        self.assertEqual(new_model.linear1.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear2.weight.device, torch.device(0))

        output = new_model(x)
        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @require_multi_gpu
    def test_load_checkpoint_and_dispatch_multi_gpu(self):
        """load_checkpoint_and_dispatch across CPU and two GPUs reproduces eager outputs."""
        model = BiggerModelForTest()
        device_map = {"linear1": "cpu", "linear2": "cpu", "batchnorm": 0, "linear3": 0, "linear4": 1}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            checkpoint = os.path.join(tmp_dir, "pt_model.bin")
            torch.save(model.state_dict(), checkpoint)

            new_model = BiggerModelForTest()
            new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=device_map)

        # CPU-offloaded weights are on the meta device while waiting for the forward pass.
        self.assertEqual(new_model.linear1.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear2.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear3.weight.device, torch.device(0))
        self.assertEqual(new_model.linear4.weight.device, torch.device(1))

        output = new_model(x)
        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @require_cuda
    def test_load_checkpoint_and_dispatch_with_unused_submodules(self):
        """load_checkpoint_and_dispatch with preload_module_classes on a single GPU."""
        model = ModelWithUnusedSubModulesForTest()
        device_map = {"linear1": "cpu", "linear2": "cpu", "batchnorm": 0, "linear3": 0, "linear4": 0}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            checkpoint = os.path.join(tmp_dir, "pt_model.bin")
            torch.save(model.state_dict(), checkpoint)

            new_model = ModelWithUnusedSubModulesForTest()
            new_model = load_checkpoint_and_dispatch(
                new_model, checkpoint, device_map=device_map, preload_module_classes=["ModuleWithUnusedSubModules"]
            )

        # CPU-offloaded weights are on the meta device while waiting for the forward pass.
        self.assertEqual(new_model.linear1.linear.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear2.linear.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear3.linear.weight.device, torch.device(0))
        self.assertEqual(new_model.linear4.linear.weight.device, torch.device(0))

        output = new_model(x)
        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

    @require_multi_gpu
    def test_load_checkpoint_and_dispatch_multi_gpu_with_unused_submodules(self):
        """load_checkpoint_and_dispatch with preload_module_classes across two GPUs."""
        model = ModelWithUnusedSubModulesForTest()
        device_map = {"linear1": "cpu", "linear2": "cpu", "batchnorm": 0, "linear3": 0, "linear4": 1}

        x = torch.randn(2, 3)
        expected = model(x)

        with TemporaryDirectory() as tmp_dir:
            checkpoint = os.path.join(tmp_dir, "pt_model.bin")
            torch.save(model.state_dict(), checkpoint)

            new_model = ModelWithUnusedSubModulesForTest()
            new_model = load_checkpoint_and_dispatch(
                new_model, checkpoint, device_map=device_map, preload_module_classes=["ModuleWithUnusedSubModules"]
            )

        # CPU-offloaded weights are on the meta device while waiting for the forward pass.
        self.assertEqual(new_model.linear1.linear.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear2.linear.weight.device, torch.device("meta"))
        self.assertEqual(new_model.linear3.linear.weight.device, torch.device(0))
        self.assertEqual(new_model.linear4.linear.weight.device, torch.device(1))

        output = new_model(x)
        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))