[RELAND] Change AccumulateGrad to yield .grads that match weights' memory layout (pytorch#40129)

Summary:
pytorch#34904 was reverted because it had a misconfigured 4 GPU test that for some reason wasn't caught by external CI ([example failure](https://app.circleci.com/pipelines/github/pytorch/pytorch/181719/workflows/cfb37cd9-9a0c-4738-898b-d683934cd308/jobs/5868948/steps)).

This PR reverts the revert, and adds diffs that should repair the misconfigured test.
Pull Request resolved: pytorch#40129

Differential Revision: D22079377

Pulled By: albanD

fbshipit-source-id: 9bd2b7e0c34fdaf887497b52037cfe82cba709c1
definitelynotmcarilli authored and facebook-github-bot committed Jun 17, 2020
1 parent 5200814 commit 1ec8ece
Showing 16 changed files with 716 additions and 116 deletions.
52 changes: 52 additions & 0 deletions docs/source/autograd.rst
@@ -54,6 +54,58 @@ Locally disabling gradient computation

.. autoclass:: set_grad_enabled

.. _default-grad-layouts:

Default gradient layouts
^^^^^^^^^^^^^^^^^^^^^^^^

When a non-sparse ``param`` receives a non-sparse gradient during
:func:`torch.autograd.backward` or :func:`torch.Tensor.backward`,
``param.grad`` is accumulated as follows.

If ``param.grad`` is initially ``None``:

1. If ``param``'s memory is non-overlapping and dense, ``.grad`` is
   created with strides matching ``param`` (thus matching ``param``'s
   layout).
2. Otherwise, ``.grad`` is created with rowmajor-contiguous strides.

If ``param`` already has a non-sparse ``.grad`` attribute:

3. If ``create_graph=False``, ``backward()`` accumulates into ``.grad``
   in-place, which preserves its strides.
4. If ``create_graph=True``, ``backward()`` replaces ``.grad`` with a
   new tensor ``.grad + new grad``, which attempts (but does not guarantee)
   matching the preexisting ``.grad``'s strides.

The default behavior (letting ``.grad``\ s be ``None`` before the first
``backward()``, such that their layout is created according to 1. or 2.,
and retained over time according to 3. or 4.) is recommended for best performance.
Calls to ``model.zero_grad()`` or ``optimizer.zero_grad()`` will not affect ``.grad``
layouts.

In fact, resetting all ``.grad``\ s to ``None`` before each
accumulation phase, e.g.::

    for iterations...
        ...
        for param in model.parameters():
            param.grad = None
        loss.backward()

such that they're recreated according to 1. or 2. every time,
is a valid alternative to ``model.zero_grad()`` or ``optimizer.zero_grad()``
that may improve performance for some networks.
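
As an illustrative sketch (a toy ``Conv2d`` layer and random data, used here
purely for demonstration), rules 1. and 3. imply that checks like the
following should pass for a ``channels_last`` parameter::

    import torch

    conv = torch.nn.Conv2d(8, 8, 3).to(memory_format=torch.channels_last)
    inp = torch.randn(4, 8, 16, 16).to(memory_format=torch.channels_last)

    # First backward(): .grad is created with strides matching the weight (rule 1.).
    conv(inp).sum().backward()
    assert conv.weight.grad.is_contiguous(memory_format=torch.channels_last)

    # Second backward(): accumulation happens in-place, preserving those strides (rule 3.).
    conv(inp).sum().backward()
    assert conv.weight.grad.is_contiguous(memory_format=torch.channels_last)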

Manual gradient layouts
-----------------------

If you need manual control over ``.grad``'s strides,
assign ``param.grad =`` a zeroed tensor with desired strides
before the first ``backward()``, and never reset it to ``None``.
3. guarantees your layout is preserved as long as ``create_graph=False``.
4. indicates your layout is *likely* preserved even if ``create_graph=True``.
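
A minimal sketch of this pattern (again with a toy ``Conv2d`` and random data,
purely for illustration)::

    import torch

    conv = torch.nn.Conv2d(8, 8, 3)  # the weight itself keeps rowmajor-contiguous strides
    # Pin .grad to channels_last strides before the first backward().
    conv.weight.grad = torch.zeros_like(conv.weight, memory_format=torch.channels_last)

    inp = torch.randn(4, 8, 16, 16)
    conv(inp).sum().backward()
    # With create_graph=False, accumulation is in-place (rule 3.), so the chosen strides persist.
    assert conv.weight.grad.is_contiguous(memory_format=torch.channels_last)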

In-place operations on Tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

175 changes: 163 additions & 12 deletions test/distributed/test_c10d.py
Expand Up @@ -13,7 +13,7 @@
from datetime import timedelta
from sys import platform

from itertools import groupby
from itertools import groupby, product
from functools import reduce
import operator

@@ -1367,7 +1367,6 @@ def test_allgather_coalesced_checks(self):
"Invalid function argument.*output_tensor_lists"):
c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)


def test_reduce_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
@@ -1888,6 +1887,32 @@ def forward(self, x):
return F.softmax(x, dim=1).to(dev0)


class ConvNet(nn.Module):
    def __init__(self, gpus, layouts, dtypes):
        super(ConvNet, self).__init__()
        self.dtypes = dtypes
        if isinstance(gpus, list):
            self.layer_gpus = gpus
        else:
            gpus = [gpus] * 4
        self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(device=gpus[0], memory_format=layouts[0], dtype=dtypes[0])
        self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(device=gpus[1], memory_format=layouts[1], dtype=dtypes[1])
        self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(device=gpus[2], memory_format=layouts[2], dtype=dtypes[2])
        self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(device=gpus[3], memory_format=layouts[3], dtype=dtypes[3])

    def forward(self, x):
        x = x.to(self.dtypes[0])
        # Could say
        # x = self.conv0(x).to(device=self.conv1.weight.device, dtype=self.dtypes[1])
        # etc. But I don't want to appeal to the weights' devices directly, because part of this test's purpose
        # is to verify weights are where expected if the model gets replicated.
        gpus = self.layer_gpus if hasattr(self, "layer_gpus") else [x.device] * 4
        x = self.conv0(x).to(device=gpus[1], dtype=self.dtypes[1])
        x = self.conv1(x).to(device=gpus[2], dtype=self.dtypes[2])
        x = self.conv2(x).to(device=gpus[3], dtype=self.dtypes[3])
        return self.conv3(x)


@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment")
class DistributedDataParallelTest(MultiProcessTestCase):
def setUp(self):
@@ -2342,7 +2367,6 @@ def run_and_verify_grad(model):
self.assertIsNotNone(t1_p.grad)
self.assertIsNone(task_unused_p.grad)


store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)

@@ -2552,15 +2576,15 @@ def step_model(model, input, target):
# Skip gradients sync without calling prepare_for_backward
step_model(
ddp_model.module,
input[self.rank : (self.rank + 1)],
target[self.rank : (self.rank + 1)])
input[self.rank:(self.rank + 1)],
target[self.rank:(self.rank + 1)])
for i, j in zip(model.parameters(), ddp_model.parameters()):
self.assertNotEqual(i.grad, j.grad)
else:
step_model(
ddp_model,
input[self.rank : (self.rank + 1)],
target[self.rank : (self.rank + 1)])
input[self.rank:(self.rank + 1)],
target[self.rank:(self.rank + 1)])
for i, j in zip(model.parameters(), ddp_model.parameters()):
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
self.assertEqualIgnoreType(i.grad, j.grad)
@@ -2762,6 +2786,132 @@ def forward(self, x):
ddp_parameter = next(ddp_model.parameters())
self.assertEqual(vanilla_parameter.grad, ddp_parameter.grad)

    def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size):
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

        global_batch_size = local_batch_size * self.world_size

        # Carry out some trials with small buckets and some with big buckets.
        bucketsizes = (0.000001, 25)
        # Tuples of lists. Each list describes per-layer characteristics for one trial.
        layer_formats = ([torch.contiguous_format] * 4,
                         [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
                         [torch.channels_last] * 4)
        layer_dtypes = ([torch.float] * 4,
                        [torch.float] * 2 + [torch.half] * 2,
                        [torch.half] * 4)

        input_dev = layer_devs[0] if isinstance(layer_devs, list) else layer_devs
        target_dev = layer_devs[-1] if isinstance(layer_devs, list) else layer_devs
        input = torch.randn(global_batch_size, 8, 8, 8).to(input_dev)
        target = torch.randn(global_batch_size, 8, 4, 4).to(target_dev)
        local_batch_start = self.rank * local_batch_size
        local_batch_end = (self.rank + 1) * local_batch_size

        with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False):
            for formats, dtypes, bucketsize in product(layer_formats, layer_dtypes, bucketsizes):
                model_msg = "rank = {} formats = {} dtypes = {} bucketsize = {} ".format(self.rank, formats,
                                                                                         dtypes, bucketsize)
                try:
                    m = ConvNet(layer_devs, formats, dtypes)
                    m_ddp = DistributedDataParallel(copy.deepcopy(m),
                                                    device_ids=replica_devices,
                                                    process_group=process_group,
                                                    bucket_cap_mb=bucketsize)
                    has_half = any(p.dtype is torch.half for p in m.parameters())
                    tol = 1.e-3 if has_half else 1.e-5
                except BaseException:
                    # Prints case-specific debugging info to narrow down failing case.
                    print("Caught exception during model creation for " + model_msg, flush=True)
                    raise
                # 3 iters: First iter creates grads, second iter retests after rebucketing,
                # third iter tries zeroed grads.
                for it in range(3):
                    iter_msg = "iter = {} ".format(it) + model_msg
                    try:
                        F.mse_loss(m(input).float(), target).backward()
                        F.mse_loss(m_ddp(input[local_batch_start: local_batch_end]).float(),
                                   target[local_batch_start: local_batch_end]).backward()
                        for i, ((layer_name, m_child), m_ddp_child) in enumerate(zip(m.named_children(),
                                                                                     m_ddp.module.children())):
                            named_msg = layer_name + ".weight" + " " + iter_msg
                            self.assertTrue(m_child.weight.grad.is_contiguous(memory_format=formats[i]), named_msg)
                            self.assertTrue(m_ddp_child.weight.grad.is_contiguous(memory_format=formats[i]), named_msg)
                            for j, ((param_name, p), p_ddp) in enumerate(zip(m_child.named_parameters(),
                                                                             m_ddp_child.parameters())):
                                named_msg = layer_name + "." + param_name + " " + iter_msg
                                self.assertEqual(p.grad, p_ddp.grad, msg=named_msg, rtol=tol, atol=tol)
                                if it == 0:
                                    p.grad = None
                                    p_ddp.grad = None
                                else:
                                    m.zero_grad()
                                    m_ddp.zero_grad()
                    except BaseException:
                        # Makes sure we still get info if an error occurred somewhere other than the asserts.
                        print("Caught exception during iterations at " + iter_msg, flush=True)
                        raise

    @requires_nccl()
    @skip_if_not_multigpu
    @skip_if_rocm
    def test_grad_layout_1devicemodule_1replicaperprocess(self):
        dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0]))
        # Tells DDP to use just one device.
        replica_devices = [dev0]
        # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
        layer_devs = dev0
        local_batch_size = 8
        self._test_grad_layout(replica_devices, layer_devs, local_batch_size)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_if_rocm
    def test_grad_layout_1devicemodule_2replicaperprocess(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        dev0 = torch.device("cuda:" + str(int_devices[0]))
        dev1 = torch.device("cuda:" + str(int_devices[1]))
        # Tells DDP to replicate the model to both of this process's devices.
        replica_devices = [dev0, dev1]
        # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
        layer_devs = dev0
        local_batch_size = 16
        self._test_grad_layout(replica_devices, layer_devs, local_batch_size)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_if_rocm
    def test_grad_layout_2devicemodule(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        dev0 = torch.device("cuda:" + str(int_devices[0]))
        dev1 = torch.device("cuda:" + str(int_devices[1]))
        # DDP's default behavior for a multi-device module is "don't replicate."
        replica_devices = None
        # Tells _test_grad_layout to construct this process's ConvNet on 2 devices, with 2 layers on each device.
        layer_devs = [dev0] * 2 + [dev1] * 2
        local_batch_size = 8
        self._test_grad_layout(replica_devices, layer_devs, local_batch_size)

    @requires_nccl()
    @skip_if_not_multigpu
    @skip_if_rocm
    def test_param_layout_mismatch_error(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

        dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0]))
        layer_devs = dev0
        layer_formats = [torch.contiguous_format] * 4 if self.rank == 0 else [torch.channels_last] * 4
        layer_dtypes = [torch.float] * 4

        m = ConvNet(layer_devs, layer_formats, layer_dtypes)
        if self.rank == 0:
            m_ddp = DistributedDataParallel(m, device_ids=[dev0], process_group=process_group)
        else:
            with self.assertRaisesRegex(RuntimeError, ".* appears not to match strides of the same param in process 0"):
                m_ddp = DistributedDataParallel(m, device_ids=[dev0], process_group=process_group)


class ReducerModule(nn.Module):
def __init__(self):
@@ -2943,6 +3093,7 @@ def test_multi_limit_multi_dtype(self):
result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
self.assertEqual([[0], [1], [2, 4], [3, 5]], result)


@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment")
class NcclErrorHandlingTest(MultiProcessTestCase):
def setUp(self):
@@ -3028,31 +3179,31 @@ def _test_nccl_errors_blocking(self, func):
@requires_nccl_version(2400, "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_errors_blocking_clean_exit(self):
self._test_nccl_errors_blocking(lambda : sys.exit(0))
self._test_nccl_errors_blocking(lambda: sys.exit(0))

@requires_nccl()
@requires_nccl_version(2400, "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_errors_blocking_nonzero_exit(self):
self._test_nccl_errors_blocking(lambda : sys.exit(1))
self._test_nccl_errors_blocking(lambda: sys.exit(1))

@requires_nccl()
@requires_nccl_version(2400, "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_errors_blocking_abort(self):
self._test_nccl_errors_blocking(lambda : os.abort())
self._test_nccl_errors_blocking(lambda: os.abort())

@requires_nccl()
@requires_nccl_version(2400, "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_errors_blocking_sigkill(self):
self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGKILL))
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))

@requires_nccl()
@requires_nccl_version(2400, "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_errors_blocking_sigterm(self):
self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGTERM))
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))

def _run_invalid_nccl_blocking_wait_env(self, val):
os.environ["NCCL_BLOCKING_WAIT"] = val