Support replicating multi-GPU modules (pytorch#18687)
Summary:
If the input `network` resides on multiple GPUs, `devices` must be a 2D list with `devices[0]` matching `network`'s devices. See pytorch#18591.
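
For illustration, a minimal sketch of the new call pattern, modeled on the tests added in this PR; the two-GPU `Linear` layout, the `TwoGpuModule` name, and the device ids are assumptions for the example, and at least four GPUs are assumed to be visible:

```python
import torch.nn as nn
import torch.nn.parallel as dp

class TwoGpuModule(nn.Module):
    """Example module whose parameters are split across cuda:0 and cuda:1."""
    def __init__(self):
        super(TwoGpuModule, self).__init__()
        self.net1 = nn.Linear(10, 5).cuda(0)
        self.net2 = nn.Linear(5, 5).cuda(1)

    def forward(self, x):
        # Move the input to whichever device each sub-layer lives on.
        out = self.net1(x.cuda(self.net1.weight.get_device()))
        return self.net2(out.cuda(self.net2.weight.get_device()))

module = TwoGpuModule()

# `devices` is 2D: one inner list per replica, and devices[0] lists the GPUs
# the input module already occupies.
replicas = dp.replicate(module, [[0, 1], [2, 3]])
# replicas[0] keeps the original placement (GPUs 0/1); replicas[1] is a copy
# of the whole module on GPUs 2/3.
```

A flat 1D `devices` list (e.g. `(0, 1)`) keeps the existing single-GPU-module behavior, as exercised by the updated `test_replicate` below.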
Pull Request resolved: pytorch#18687

Differential Revision: D14706162

Pulled By: mrshenli

fbshipit-source-id: dca630d3308f2dbcf8b75629c452d7a64092ba42
mrshenli authored and facebook-github-bot committed Apr 3, 2019
1 parent eabd9ea commit 7ae0263
Showing 3 changed files with 328 additions and 22 deletions.
1 change: 1 addition & 0 deletions test/common_cuda.py
@@ -7,6 +7,7 @@

TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
TEST_GEQ4GPU = TEST_CUDA and torch.cuda.device_count() >= 4
CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
158 changes: 146 additions & 12 deletions test/test_nn.py
@@ -30,7 +30,7 @@
from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
get_function_arglist, load_tests
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_GEQ4GPU, TEST_CUDNN, \
TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, loss_reference_fns, get_reduction, \
@@ -3336,23 +3336,157 @@ def test_replicate(self):
module = nn.Linear(10, 5).float().cuda()
input = Variable(torch.randn(2, 10).float().cuda())
expected_output = module(input).data
replicas = dp.replicate(module, (0, 1))
for i, replica in enumerate(replicas):
for p in replica.parameters():
self.assertEqual(p.get_device(), i)
replica_input = input.cuda(i)
self.assertEqual(replica(replica_input).data, expected_output)
for devices in [(0, 1), [[0], [1]]]:
replicas = dp.replicate(module, devices)
for i, replica in enumerate(replicas):
for p in replica.parameters():
self.assertEqual(p.get_device(), i)
replica_input = input.cuda(i)
self.assertEqual(replica(replica_input).data, expected_output)

@unittest.skipIf(not TEST_GEQ4GPU, "less than 4 GPUs")
def test_replicate_multi_gpu_module(self):
class MultiGpuModule(nn.Module):
def __init__(self):
super(MultiGpuModule, self).__init__()
self.net1 = torch.nn.Linear(10, 5).cuda(0)
self.net2 = torch.nn.Linear(5, 5).cuda(1)
self.bn = nn.BatchNorm2d(10).cuda(0)

def forward(self, x):
out = self.net1(x.cuda(self.net1.weight.get_device()))
return self.net2(out.cuda(self.net2.weight.get_device()))

module = MultiGpuModule()

input = torch.rand(2, 10).cuda(0)
expected_output = module(input).cpu()

for devices in ([[0, 1], [2, 3]], [[1, 0], [3, 2]]):
replicas = dp.replicate(module, devices)
for i, replica in enumerate(replicas):
self.assertEqual(replica.net1.weight.get_device(), 2 * i)
self.assertEqual(replica.net1.bias.get_device(), 2 * i)
self.assertEqual(replica.net2.weight.get_device(), 2 * i + 1)
self.assertEqual(replica.net2.bias.get_device(), 2 * i + 1)
self.assertEqual(replica.bn.running_mean.get_device(), 2 * i)
self.assertEqual(replica.bn.running_var.get_device(), 2 * i)
self.assertEqual(
replica.bn.num_batches_tracked.get_device(), 2 * i)

replica_input = input.cuda(2 * i)
replica_output = replica(replica_input)
self.assertEqual(replica_output.get_device(), 2 * i + 1)
self.assertEqual(replica_output.cpu(), expected_output)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_replicate_device_indices(self):
from torch.nn.parallel.replicate import _to_device_index as f

self.assertEqual(
f([['cuda:0', 'cuda:1', 'cuda:2'],
['cuda:4', 'cuda:3', 'cuda:6']]),
[[0, 1, 2], [4, 3, 6]])

self.assertEqual(f(('cuda:0', 'cuda:1', 'cuda:2')), [0, 1, 2])

self.assertEqual(
len(set([0, 1, 2]).intersection(f({'cuda:0', 'cuda:1', 'cuda:2'}))),
3)
self.assertEqual(
f([['cuda:0'], ['cuda:1'], ['cuda:2']]), [[0], [1], [2]])

msg = "empty device list"
for devices in (None, (), [], [[]]):
with self.assertRaisesRegex(RuntimeError, msg):
f(devices)

msg = "unidentical number of devices"
for devices in ([[0, 1], [2]], [[0], [1, 2]]):
with self.assertRaisesRegex(AssertionError, msg):
f(devices)

msg = "shared by multiple replicas"
for devices in ([[0, 1], [1, 2]], [[0], [1], [0]]):
with self.assertRaisesRegex(AssertionError, msg):
f(devices)

msg = "Duplicated device ids"
for devices in ([[0, 1, 2, 1]], [0, 1, 1], [0, 0]):
with self.assertRaisesRegex(AssertionError, msg):
f(devices)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate_tensor_grouping_multi_gpu(self):
from torch.nn.parallel.replicate import _group_by_device

a = torch.Tensor(1).cuda(0)
b = torch.Tensor(2).cuda(0)
c = torch.Tensor(3).cuda(1)
d = torch.Tensor(4).cuda(0)
e = torch.Tensor(5).cuda(1)

tensors = [a, b, c, d, e]
for devices in ([[0, 1], [2, 3]], [[1, 4, 0], [3, 5, 2]]):
grouped_tensors, grouped_devices, original_index = \
_group_by_device(tensors, devices)

self.assertEqual(grouped_tensors, [[a, b, d], [c, e]])
self.assertEqual(grouped_devices, [[0, 2], [1, 3]])
self.assertEqual(original_index, [[0, 1, 3], [2, 4]])

msg = "missing from devices"
for devices in ([[0, 2], [1, 3]], [[1, 2], [0, 3]], [[2, 3], [0, 1]]):
with self.assertRaisesRegex(AssertionError, msg):
grouped_tensors, grouped_devices, original_index = \
_group_by_device(tensors, devices)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_replicate_tensor_grouping(self):
from torch.nn.parallel.replicate import _group_by_device

a = torch.Tensor(1).cuda(0)
b = torch.Tensor(2).cuda(0)
c = torch.Tensor(3).cuda(0)

tensors = [a, b, c]

grouped_tensors, grouped_devices, original_index = \
_group_by_device(tensors, [0, 1])

self.assertEqual(grouped_tensors, [[a, b, c]])
self.assertEqual(grouped_devices, [[0, 1]])
self.assertEqual(original_index, [[0, 1, 2]])

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate_reshape(self):
from torch.nn.parallel.replicate import _broadcast_coalesced_reshape

a = torch.Tensor(1).cuda(0)
b = torch.Tensor(2).cuda(0)
c = torch.Tensor(3).cuda(1)
d = torch.Tensor(4).cuda(0)
e = torch.Tensor(5).cuda(1)

tensors = [a, b, c, d, e]
outputs = _broadcast_coalesced_reshape(tensors, [[0, 1], [1, 0]])

self.assertEqual(len(outputs), 2)
self.assertEqual(outputs[0], [a, b, c, d, e])
self.assertEqual(
outputs[1], [a.cuda(1), b.cuda(1), c.cuda(0), d.cuda(1), e.cuda(0)])

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate_buffers(self):
net = nn.Module()
net.bn = nn.BatchNorm2d(10)
net.cuda()
replicas = dp.replicate(net, (0, 1))
for i, replica in enumerate(replicas):
self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
for devices in [(0, 1), [[0], [1]]]:
replicas = dp.replicate(net, devices)
for i, replica in enumerate(replicas):
self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm