forked from horovod/horovod
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_torch.py
789 lines (677 loc) · 31.1 KB
/
test_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
# Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import itertools
import os
import tempfile
import torch
import torch.nn.functional as F
import unittest
import numpy as np
import horovod.torch as hvd
from common import mpi_env_rank_and_size
class TorchTests(unittest.TestCase):
"""
Tests for ops in horovod.torch.
"""
def test_horovod_rank(self):
    """Check that hvd.rank() agrees with the rank reported by the MPI environment."""
    # The reference value is read before Horovod is initialized.
    env_rank, _unused_size = mpi_env_rank_and_size()
    hvd.init()
    assert env_rank == hvd.rank()
def test_horovod_size(self):
    """Check that hvd.size() agrees with the size reported by the MPI environment."""
    # The reference value is read before Horovod is initialized.
    _unused_rank, env_size = mpi_env_rank_and_size()
    hvd.init()
    assert env_size == hvd.size()
def test_horovod_allreduce(self):
    """Test that the allreduce correctly sums 1D, 2D, 3D tensors.

    Every rank seeds the RNG with the same value before creating its
    tensor, so the test expects the summed result to equal
    ``tensor * size``.
    """
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        summed = hvd.allreduce(tensor, average=False)
        multiplied = tensor * size
        # Compare the magnitude of the error: a signed max() would let
        # results that are too small (negative differences) go unnoticed.
        max_difference = summed.data.sub(multiplied).abs().max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            # No tolerance is defined for 15 or more ranks; stop checking.
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def test_horovod_allreduce_average(self):
    """Test that the allreduce correctly averages 1D, 2D, 3D tensors.

    Every rank seeds the RNG with the same value before creating its
    tensor, so the test expects the averaged result to equal the input
    tensor itself.
    """
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        averaged = hvd.allreduce(tensor, average=True)
        # Compare the magnitude of the error: a signed max() would let
        # results that are too small (negative differences) go unnoticed.
        max_difference = averaged.data.sub(tensor).abs().max()

        # Threshold for floating point equality depends on number of
        # ranks, since the reduction accumulates rounding error per rank.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            # No tolerance is defined for 15 or more ranks; stop checking.
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def test_horovod_allreduce_inplace(self):
    """Test that the in-place allreduce correctly sums 1D, 2D, 3D tensors.

    The expected value ``tensor * size`` is computed before calling
    ``hvd.allreduce_``, because the call overwrites ``tensor``.
    """
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        multiplied = tensor * size
        hvd.allreduce_(tensor, average=False)
        # Compare the magnitude of the error: a signed max() would let
        # results that are too small (negative differences) go unnoticed.
        max_difference = tensor.sub(multiplied).abs().max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            # No tolerance is defined for 15 or more ranks; stop checking.
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def test_horovod_allreduce_async_fused(self):
    """Test that the allreduce correctly sums 1D, 2D, 3D tensors
    with Tensor Fusion.

    All reductions are enqueued first and only synchronized afterwards.
    The test also requires that at least one hvd.poll() call returned
    False, as evidence that the operation is asynchronous.
    """
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    tests = []
    is_hvd_poll_false_once = False
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        handle = hvd.allreduce_async(tensor, average=False)
        if not hvd.poll(handle):
            is_hvd_poll_false_once = True
        multiplied = tensor * size
        tests.append((dtype, multiplied, handle))

    # Make sure it's an asynchronous operation.
    assert is_hvd_poll_false_once, 'hvd.poll() always returns True, not an async op?'

    for dtype, multiplied, handle in tests:
        summed = hvd.synchronize(handle)
        # Compare the magnitude of the error: a signed max() would let
        # results that are too small (negative differences) go unnoticed.
        max_difference = summed.sub(multiplied).abs().max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            # No tolerance is defined for 15 or more ranks. Keep looping
            # (instead of breaking out) so that every outstanding handle
            # is still synchronized.
            continue

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def test_horovod_allreduce_multi_gpu(self):
    """Test that the allreduce works on multiple GPUs.

    Each local rank alternates between the two devices ``local_rank * 2``
    and ``local_rank * 2 + 1`` from one iteration to the next. The test is
    skipped when CUDA is unavailable or when the machine has fewer than
    two GPUs per local process.
    """
    # Only do this test if there are GPUs available.
    if not torch.cuda.is_available():
        return

    hvd.init()
    local_rank = hvd.local_rank()
    size = hvd.size()

    # The device index computed below can reach local_size * 2 - 1, so
    # skip the test when there are not enough GPUs on this machine.
    if torch.cuda.device_count() < hvd.local_size() * 2:
        return

    dtypes = [torch.cuda.IntTensor, torch.cuda.LongTensor,
              torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    # The counter starts at 1 and is named `iteration` to avoid shadowing
    # the builtin `iter`.
    for iteration, (dtype, dim) in enumerate(
            itertools.product(dtypes, dims), 1):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        device = local_rank * 2 + (iteration + local_rank) % 2
        tensor = tensor.cuda(device).type(dtype)
        multiplied = tensor * size
        hvd.allreduce_(tensor, average=False)
        # Compare the magnitude of the error: a signed max() would let
        # results that are too small (negative differences) go unnoticed.
        max_difference = tensor.sub(multiplied).abs().max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            # No tolerance is defined for 15 or more ranks; stop checking.
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def test_horovod_allreduce_error(self):
    """Verify that allreduce fails when ranks disagree on the tensor shape.

    Two mismatches are exercised: the same number of dimensions with
    rank-dependent extents, and the same element count with a different
    number of dimensions.
    """
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker has no peer to disagree with.
    if size == 1:
        return

    # Case 1: three dimensions on every rank, extent depends on the rank.
    torch.manual_seed(1234)
    mismatched = torch.FloatTensor(*([17 + rank] * 3)).random_(-100, 100)
    try:
        hvd.allreduce(mismatched)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.allreduce did not throw error'

    # Case 2: equal element count; rank 0 is 2D, every other rank is 3D.
    torch.manual_seed(1234)
    shape = [17, 23 * 57] if rank == 0 else [17, 23, 57]
    mismatched = torch.FloatTensor(*shape).random_(-100, 100)
    try:
        hvd.allreduce(mismatched)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.allreduce did not throw error'
def test_horovod_allreduce_type_error(self):
    """Verify that allreduce fails when ranks submit different tensor types.

    Even ranks submit an IntTensor and odd ranks a FloatTensor, both of
    the same shape.
    """
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker cannot produce a type mismatch.
    if size == 1:
        return

    # Identical shape on every rank; only the element type differs.
    shape = [17] * 3
    tensor_cls = torch.IntTensor if rank % 2 == 0 else torch.FloatTensor
    tensor = tensor_cls(*shape)

    try:
        hvd.allreduce(tensor)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.allreduce did not throw error'
def test_horovod_allreduce_cpu_gpu_error(self):
    """Verify that allreduce fails when some ranks reduce on CPU and others on GPU.

    Even ranks submit a CUDA FloatTensor and odd ranks a CPU FloatTensor,
    both of the same shape.
    """
    # Nothing to test without a GPU.
    if not torch.cuda.is_available():
        return

    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker cannot produce a device mismatch.
    if size == 1:
        return

    # Identical shape on every rank; only the device differs.
    shape = [17] * 3
    tensor_cls = torch.cuda.FloatTensor if rank % 2 == 0 else torch.FloatTensor
    tensor = tensor_cls(*shape)

    try:
        hvd.allreduce(tensor)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.allreduce did not throw error'
def test_horovod_allreduce_grad(self):
    """Check the gradient of a summing allreduce.

    Backpropagating a tensor of ones through the sum is expected to give
    a gradient of ``size`` in every element of the input.
    """
    hvd.init()
    size = hvd.size()

    # Gradients are only defined for floating point tensors.
    float_types = [torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        float_types += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor]

    for dtype, dim in itertools.product(float_types, [1, 2, 3]):
        shape = [17] * dim
        torch.manual_seed(1234)
        values = torch.FloatTensor(*shape).random_(-100, 100).type(dtype)
        leaf = torch.autograd.Variable(values, requires_grad=True)

        reduced = hvd.allreduce(leaf, average=False)
        reduced.backward(torch.ones(shape).type(dtype))
        grad_out = leaf.grad.data.cpu().numpy()

        expected = np.ones(shape) * size
        err = np.linalg.norm(expected - grad_out)
        message = ("gradient %s differs from expected %s, "
                   "error: %s" % (grad_out, expected, str(err)))
        self.assertLess(err, 0.00000001, message)
def test_horovod_allreduce_grad_average(self):
    """Check the gradient of an averaging allreduce.

    Backpropagating a tensor of ones through the average is expected to
    give a gradient of one in every element of the input.
    """
    hvd.init()

    # Gradients are only defined for floating point tensors.
    float_types = [torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        float_types += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor]

    for dtype, dim in itertools.product(float_types, [1, 2, 3]):
        shape = [17] * dim
        torch.manual_seed(1234)
        values = torch.FloatTensor(*shape).random_(-100, 100).type(dtype)
        leaf = torch.autograd.Variable(values, requires_grad=True)

        averaged = hvd.allreduce(leaf, average=True)
        averaged.backward(torch.ones(shape).type(dtype))
        grad_out = leaf.grad.data.cpu().numpy()

        expected = np.ones(shape)
        err = np.linalg.norm(expected - grad_out)
        message = ("gradient %s differs from expected %s, "
                   "error: %s" % (grad_out, expected, str(err)))
        self.assertLess(err, 0.00000001, message)
def test_horovod_allgather(self):
    """Verify allgather on 1D, 2D and 3D tensors of identical shape.

    Each rank contributes a tensor filled with its own rank number; the
    gathered result must hold one 17-row slice per rank, in rank order,
    along the first dimension.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    dtypes = [torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
              torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor,
                   torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor,
                   torch.cuda.DoubleTensor]

    for dtype, dim in itertools.product(dtypes, [1, 2, 3]):
        local = torch.FloatTensor(*([17] * dim)).fill_(1).mul_(rank).type(dtype)
        gathered = hvd.allgather(local)

        assert list(gathered.shape) == [17 * size] + [17] * (dim - 1)

        for peer in range(size):
            start = peer * 17
            chunk = gathered[start:start + 17]
            assert list(chunk.shape) == [17] * dim, \
                'hvd.allgather produces incorrect gathered shape'
            # Every element of the slice must equal the contributing rank.
            assert chunk.data.min() == peer, 'hvd.allgather produces incorrect gathered tensor'
            assert chunk.data.max() == peer, 'hvd.allgather produces incorrect gathered tensor'
def test_horovod_allgather_variable_size(self):
    """Verify allgather on 1D, 2D and 3D tensors whose first dimension
    differs from rank to rank.

    Each rank contributes a tensor filled with its own rank number; the
    gathered result must hold the per-rank slices back to back, in rank
    order, along the first dimension.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    dtypes = [torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
              torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor,
                   torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor,
                   torch.cuda.DoubleTensor]

    # The table of first-dimension sizes has 35 entries, so larger jobs
    # are not covered by this test.
    if size > 35:
        return

    first_dims = ([17, 32, 81, 12, 15, 23, 22] * 5)[:size]
    total_rows = sum(first_dims)

    for dtype, dim in itertools.product(dtypes, [1, 2, 3]):
        trailing = [17] * (dim - 1)
        local = torch.FloatTensor(*([first_dims[rank]] + trailing)).fill_(1).mul_(rank)
        local = local.type(dtype)
        gathered = hvd.allgather(local)

        assert list(gathered.shape) == [total_rows] + trailing

        # Walk the gathered tensor slice by slice with a running offset.
        offset = 0
        for peer, rows in enumerate(first_dims):
            chunk = gathered[offset:offset + rows]
            offset += rows
            assert list(chunk.shape) == [rows] + trailing
            assert chunk.data.min() == peer
            assert chunk.data.max() == peer
def test_horovod_allgather_error(self):
    """Verify that allgather fails when a dimension other than the first
    differs between the tensors being gathered."""
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker has no peer to disagree with.
    if size == 1:
        return

    # The second dimension depends on the rank; the others are fixed at 17.
    shape = [17, 10 * (rank + 1), 17]
    tensor = torch.FloatTensor(*shape).fill_(1).mul_(rank)

    try:
        hvd.allgather(tensor)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.allgather did not throw error'
def test_horovod_allgather_type_error(self):
    """Verify that allgather fails when the gathered tensor types differ
    between processes.

    Even ranks submit an IntTensor and odd ranks a FloatTensor, both of
    the same shape.
    """
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker cannot produce a type mismatch.
    if size == 1:
        return

    shape = [17] * 3
    tensor_cls = torch.IntTensor if rank % 2 == 0 else torch.FloatTensor
    tensor = tensor_cls(*shape)

    try:
        hvd.allgather(tensor)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.allgather did not throw error'
def test_horovod_allgather_grad(self):
    """Test the correctness of the allgather gradient.

    Each rank gathers a tensor whose first dimension is rank-specific and
    backpropagates an upstream gradient whose slice for rank ``r`` is
    filled with ``r``. The test expects the local gradient to be
    ``rank * size`` in every element, where ``size`` is the number of
    Horovod processes.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # Only Tensors of floating point dtype can require gradients
    dtypes = [torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    for dtype, dim in itertools.product(dtypes, dims):
        # Support tests up to MPI Size of 35
        if size > 35:
            break

        tensor_sizes = [3, 2, 7, 4, 6, 8, 10] * 5
        tensor_sizes = tensor_sizes[:size]

        tensor = torch.FloatTensor(
            *([tensor_sizes[rank]] + [17] * (dim - 1))).fill_(1).mul_(rank)
        tensor = tensor.type(dtype)
        tensor = torch.autograd.Variable(tensor, requires_grad=True)

        # The loop variable must not be named `size`: that would overwrite
        # the process count used for `expected` below and for the
        # `tensor_sizes[:size]` slice on the next outer iteration.
        grad_list = []
        for r, tensor_size in enumerate(tensor_sizes):
            grad_list.append(
                torch.ones([tensor_size] + [17] * (dim - 1)).type(dtype) * r)
        grad_ys = torch.cat(grad_list, dim=0)

        gathered = hvd.allgather(tensor)
        gathered.backward(grad_ys)
        grad_out = tensor.grad.data.cpu().numpy()

        expected = np.ones(
            [tensor_sizes[rank]] + [17] * (dim - 1)
        ) * rank * size
        err = np.linalg.norm(expected - grad_out)
        self.assertLess(err, 0.00000001,
                        "gradient %s differs from expected %s, "
                        "error: %s" % (grad_out, expected, str(err)))
def test_horovod_broadcast(self):
    """Verify broadcast of 1D, 2D and 3D tensors from every possible root.

    Each rank offers a tensor filled with its own rank number. The
    returned tensor must equal the root's tensor, and on non-root ranks
    the input tensor must be left untouched.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # With one worker there is nothing to broadcast.
    if size == 1:
        return

    dtypes = [torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
              torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor,
                   torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor,
                   torch.cuda.DoubleTensor]

    for dtype, dim, root_rank in itertools.product(dtypes, [1, 2, 3], range(size)):
        shape = [17] * dim
        local = torch.FloatTensor(*shape).fill_(1).mul_(rank)
        reference = torch.FloatTensor(*shape).fill_(1).mul_(root_rank)
        local = local.type(dtype)
        reference = reference.type(dtype)

        received = hvd.broadcast(local, root_rank)

        if rank != root_rank:
            # No element of the input may have taken on the root's value.
            assert (local == reference).max() == 0, \
                'hvd.broadcast modifies source tensor'
        assert (received.data == reference).min() == 1, \
            'hvd.broadcast produces incorrect broadcasted tensor'
def test_horovod_broadcast_inplace(self):
    """Verify in-place broadcast of 1D, 2D and 3D tensors from every possible root.

    Each rank offers a tensor filled with its own rank number. After
    ``hvd.broadcast_`` both the input tensor and the returned tensor must
    equal the root's tensor.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # With one worker there is nothing to broadcast.
    if size == 1:
        return

    dtypes = [torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
              torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor,
                   torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor,
                   torch.cuda.DoubleTensor]

    for dtype, dim, root_rank in itertools.product(dtypes, [1, 2, 3], range(size)):
        shape = [17] * dim
        local = torch.FloatTensor(*shape).fill_(1).mul_(rank)
        reference = torch.FloatTensor(*shape).fill_(1).mul_(root_rank)
        local = local.type(dtype)
        reference = reference.type(dtype)

        received = hvd.broadcast_(local, root_rank)

        assert (local == received).min() == 1, \
            'hvd.broadcast does not modify source tensor'
        assert (received == reference).min() == 1, \
            'hvd.broadcast produces incorrect broadcasted tensor'
def test_horovod_broadcast_error(self):
    """Verify that broadcast fails when the tensor shapes differ between
    ranks (here the second dimension depends on the rank)."""
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker has no peer to disagree with.
    if size == 1:
        return

    # The second dimension depends on the rank; the others are fixed at 17.
    shape = [17, 10 * (rank + 1), 17]
    tensor = torch.FloatTensor(*shape).fill_(1).mul_(rank)

    try:
        hvd.broadcast(tensor, 0)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.broadcast did not throw error'
def test_horovod_broadcast_type_error(self):
    """Verify that broadcast fails when the tensor types differ between
    processes.

    Even ranks submit an IntTensor and odd ranks a FloatTensor, both of
    the same shape.
    """
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker cannot produce a type mismatch.
    if size == 1:
        return

    shape = [17] * 3
    tensor_cls = torch.IntTensor if rank % 2 == 0 else torch.FloatTensor
    tensor = tensor_cls(*shape)

    try:
        hvd.broadcast(tensor, 0)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.broadcast did not throw error'
def test_horovod_broadcast_rank_error(self):
    """Verify that broadcast fails when ranks disagree on the root rank.

    Every rank names itself as the root.
    """
    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    # A single worker cannot disagree about the root.
    if size == 1:
        return

    tensor = torch.FloatTensor(*([17] * 3)).fill_(1)

    try:
        hvd.broadcast(tensor, rank)
    except (torch.FatalError, RuntimeError):
        pass
    else:
        assert False, 'hvd.broadcast did not throw error'
def test_horovod_broadcast_grad(self):
    """Check the gradient of broadcast for every possible root.

    Backpropagating a tensor of ones is expected to give a gradient of
    ``size`` in every element on the root rank and zero on all others.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # With one worker there is nothing to broadcast.
    if size == 1:
        return

    # Gradients are only defined for floating point tensors.
    float_types = [torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        float_types += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor]

    for dtype, dim, root_rank in itertools.product(float_types, [1, 2, 3], range(size)):
        shape = [17] * dim
        values = torch.FloatTensor(*shape).fill_(1).mul_(rank).type(dtype)
        leaf = torch.autograd.Variable(values, requires_grad=True)

        received = hvd.broadcast(leaf, root_rank)
        received.backward(torch.ones(shape).type(dtype))
        grad_out = leaf.grad.data.cpu().numpy()

        if rank == root_rank:
            scale = size
        else:
            scale = 0
        expected = np.ones(shape) * scale
        err = np.linalg.norm(expected - grad_out)
        message = ("gradient %s differs from expected %s, "
                   "error: %s" % (grad_out, expected, str(err)))
        self.assertLess(err, 0.00000001, message)
def test_broadcast_state(self):
    """Test that model parameters and optimizer state broadcast from rank 0.

    For each torch.optim optimizer under test (LBFGS and SparseAdam are
    excluded), the test:

    1. runs one optimization step on a small two-layer model,
    2. broadcasts the resulting model and optimizer values from rank 0 to
       obtain the reference values,
    3. saves a checkpoint on rank 0, builds a fresh model/optimizer on all
       ranks and restores the checkpoint on rank 0 only,
    4. checks that hvd.broadcast_parameters and
       hvd.broadcast_optimizer_state reproduce the reference values.
    """
    hvd.init()

    N, D_in, H, D_out = 64, 100, 10, 10
    x = torch.autograd.Variable(torch.randn(N, D_in), requires_grad=True)
    y = torch.autograd.Variable(torch.randn(N, D_out), requires_grad=False)

    def create_model(create_opt):
        # Build a fresh model and wrap its optimizer in a
        # DistributedOptimizer.
        model = torch.nn.Sequential(
            torch.nn.Linear(D_in, H),
            torch.nn.ReLU(),
            torch.nn.Linear(H, D_out),
        )

        optimizer = create_opt(model)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

        return model, optimizer

    def get_model_param_values(model):
        # Name-sorted (name, cloned value) pairs from the model state dict.
        params = sorted(model.state_dict().items())
        return [(k, v.clone()) for k, v in params]

    def get_optimizer_param_values(optimizer):
        # Flat list of (key, value) pairs of the per-parameter optimizer
        # state; tensors are cloned, other values are kept as they are.
        results = []
        state_dict = optimizer.state_dict()
        for group in state_dict['param_groups']:
            for param_id in group['params']:
                params = sorted(state_dict['state'][param_id].items())
                for k, v in params:
                    results.append(
                        (k, v.clone() if torch.is_tensor(v) else v))
        return results

    opt_params = dict(lr=0.2, momentum=0.9, weight_decay=0.1, centered=True)

    # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
    # where it exists and fall back to getargspec on Python 2.
    getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec

    def new_optimizer(cls):
        # Pass only those options that the optimizer's constructor accepts.
        accepted_args = getargspec(cls.__init__).args
        p = {
            k: v for k, v in opt_params.items()
            if k in accepted_args
        }
        return lambda m: cls(m.parameters(), **p)

    # L-BFGS is currently unsupported, as are sparse tensors, which are
    # required by SparseAdam optimizer
    optimizers = [
        (subclass.__name__, new_optimizer(subclass))
        for subclass in torch.optim.Optimizer.__subclasses__()
        if subclass.__module__.startswith('torch.optim') and
        subclass != torch.optim.LBFGS and
        subclass != torch.optim.SparseAdam
    ]
    # Sorted by class name; presumably so that every rank iterates the
    # optimizers in the same order.
    optimizers.sort()

    for opt_name, create_opt in optimizers:
        model, optimizer = create_model(create_opt)
        y_pred = model(x)
        loss = F.mse_loss(y_pred, y, size_average=False)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reference model values: clones overwritten with rank 0's values.
        model_param_values = get_model_param_values(model)
        for name, model_param_value in model_param_values:
            hvd.broadcast_(model_param_value, root_rank=0)

        # Reference optimizer values: non-tensor entries are wrapped in a
        # one-element tensor for the broadcast and converted back to
        # their original Python type afterwards.
        opt_param_values_updated = []
        opt_param_values = get_optimizer_param_values(optimizer)
        for name, opt_param_value in opt_param_values:
            is_tensor = torch.is_tensor(opt_param_value)
            if not is_tensor:
                t = type(opt_param_value)
                opt_param_value = torch.Tensor([opt_param_value])
            hvd.broadcast_(opt_param_value, root_rank=0)
            if not is_tensor:
                opt_param_value = t(opt_param_value.cpu().numpy()[0])
            opt_param_values_updated.append((name, opt_param_value))
        opt_param_values = opt_param_values_updated

        if hvd.rank() == 0:
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            # mkstemp returns an open file descriptor that the caller
            # must close; torch.save reopens the file by name.
            fd, fname = tempfile.mkstemp('.pt')
            os.close(fd)
            torch.save(state, fname)

        model, optimizer = create_model(create_opt)
        if hvd.rank() == 0:
            try:
                checkpoint = torch.load(fname)
                model.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
            finally:
                # Remove the checkpoint even if loading fails.
                os.remove(fname)

        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        model_param_values_after = get_model_param_values(model)
        for before, after in zip(model_param_values,
                                 model_param_values_after):
            name, model_param_value = before
            name_after, model_param_value_after = after
            self.assertEqual(name, name_after)
            self.assertEqual(type(model_param_value),
                             type(model_param_value_after))
            self.assertTrue(
                (model_param_value == model_param_value_after).all())

        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # The model has two Linear layers, each with a weight and a bias,
        # hence four per-parameter state entries.
        self.assertEqual(len(optimizer.state_dict()['state'].values()), 4)

        opt_param_values_after = get_optimizer_param_values(optimizer)
        for before, after in zip(opt_param_values, opt_param_values_after):
            name, opt_param_value = before
            name_after, opt_param_value_after = after
            self.assertEqual(name, name_after)
            self.assertEqual(type(opt_param_value),
                             type(opt_param_value_after))
            if torch.is_tensor(opt_param_value):
                self.assertTrue(
                    (opt_param_value == opt_param_value_after).all())
            else:
                self.assertEqual(opt_param_value, opt_param_value_after)