Allow torch.cuda.amp.GradScaler to support sparse gradients (pytorch#36786)

Summary:
Should close pytorch#35810.

I decided to keep sparse handling on the Python side for clarity, although it could be moved to the C++ side (into `_amp_non_finite_check_and_unscale_`) without much trouble.

For non-fp16 sparse grads the logic is simple: call `_amp_non_finite_check_and_unscale_` on `grad._values()` instead of on `grad` itself.  At least I hope it's that easy.
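
A minimal sketch of that path (not taken from the PR itself), assuming a CUDA build and using the internal single-tensor op with the same `(tensor, found_inf, inv_scale)` argument order the diff below uses; the tensor values are arbitrary:

```python
import torch

inv_scale = torch.tensor([0.25], device="cuda")
found_inf = torch.zeros((1,), device="cuda")

i = torch.tensor([[0, 1], [2, 0]], device="cuda")
v = torch.tensor([16., 32.], device="cuda")                  # fp32 values, already scaled
grad = torch.sparse_coo_tensor(i, v, (2, 3), device="cuda")

# _values() aliases the sparse tensor's storage, so unscaling it in place
# unscales the sparse grad itself.
torch._amp_non_finite_check_and_unscale_(grad._values(), found_inf, inv_scale)
assert found_inf.item() == 0.0
assert torch.equal(grad._values(), torch.tensor([4., 8.], device="cuda"))
```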

For fp16 sparse grads, it's trickier.  Sparse tensors can be uncoalesced.  From the [Note](https://pytorch.org/docs/master/sparse.html#torch.sparse.FloatTensor):
> Our sparse tensor format permits uncoalesced sparse tensors, where there may be duplicate coordinates in the indices; in this case, the interpretation is that the value at that index is the sum of all duplicate value entries.
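
A tiny illustration of that note (values arbitrary): the coordinate (0, 2) is listed twice, and `coalesce()` materializes the sum:

```python
import torch

i = torch.tensor([[0, 0],
                  [2, 2]])                # coordinate (0, 2), listed twice
v = torch.tensor([1., 2.])
s = torch.sparse_coo_tensor(i, v, (2, 3))

print(s.is_coalesced())                   # False: duplicates still present
print(s.coalesce()._values())             # tensor([3.]): duplicate entries summed
print(s.to_dense()[0, 2])                 # tensor(3.): the dense view agrees
```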

An uncoalesced scaled fp16 grad may have values at duplicate coordinates that are all finite but large, such that adding them to produce the coalesced version WOULD cause overflow.**  Checking `_values()` on the uncoalesced version alone might not report the overflow, but I think it should be reported.
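
For example (a made-up fp16 grad, assuming CUDA is available): each stored value fits comfortably in fp16, but the coalesced sum exceeds fp16's max of ~65504:

```python
import torch

i = torch.tensor([[0, 0],
                  [2, 2]], device="cuda")
v = torch.tensor([64000., 64000.], device="cuda", dtype=torch.float16)
g = torch.sparse_coo_tensor(i, v, (2, 3), device="cuda", dtype=torch.float16)

print(torch.isfinite(g._values()).all())             # True: the uncoalesced values look fine
print(torch.isfinite(g.coalesce()._values()).all())  # False: 64000 + 64000 overflows to inf
```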

So, if the grad is sparse, fp16, and uncoalesced, I still call `_amp_non_finite_check_and_unscale_` to unscale `grad._values()` in place, but I also double-check the coalesced version by calling a second `_amp_non_finite_check_and_unscale_` on `grad.coalesce()._values()`.  `coalesce()` is out-of-place, so this call doesn't redundantly affect `grad._values()`, but it does have the power to populate the same `found_inf` tensor.  The `is_coalesced()` check and `coalesce()` probably aren't great for performance, but if someone needs a giant embedding table in FP16, they're better than nothing.  Memory-wise, they only create a copy of the nnz gradient values and indices, which is still way better than changing the whole table to FP32.
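
A sketch of that scheme as described here (the merged change below ends up taking the out-of-place route instead); the helper name is made up, and `found_inf`/`inv_scale` are one-element tensors as in the earlier sketch:

```python
def unscale_sparse_check_(grad, found_inf, inv_scale):
    # Hypothetical helper illustrating the in-place double check described above.
    if grad.is_sparse:
        # Unscale the stored (possibly duplicated) values in place.
        torch._amp_non_finite_check_and_unscale_(grad._values(), found_inf, inv_scale)
        if grad.dtype == torch.float16 and not grad.is_coalesced():
            # coalesce() is out-of-place, so this extra call leaves grad._values() alone,
            # but it can still flip the shared found_inf tensor if the summed values overflow.
            torch._amp_non_finite_check_and_unscale_(grad.coalesce()._values(), found_inf, inv_scale)
    else:
        torch._amp_non_finite_check_and_unscale_(grad, found_inf, inv_scale)
```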

An `unscale` variant with the liberty to create unscaled grads out of place, and to replace `param.grad` instead of writing through it, could get away with just one `_amp_non_finite_check_and_unscale_`.  It could say `coalesced = grad.coalesce()`, do only the stronger `_amp_non_finite_check_and_unscale_` on `coalesced._values()`, and set `param.grad = coalesced`.  I could even avoid replacing `param.grad` itself by going one level deeper and setting `param.grad`'s indices and values to `coalesced`'s, but that seems brittle and still isn't truly "in place".
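
Sketched out, with `found_inf` and `inv_scale` as before, the merged `grad_scaler.py` change below takes essentially this route, coalescing only fp16 sparse grads:

```python
if param.grad.is_sparse:
    if param.grad.dtype is torch.float16:
        # coalesce() returns a new tensor, so this replaces .grad rather than writing through it.
        param.grad = param.grad.coalesce()
    to_unscale = param.grad._values()
else:
    to_unscale = param.grad
torch._amp_non_finite_check_and_unscale_(to_unscale, found_inf, inv_scale)
```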

** You could whiteboard an uncoalesced fp32 grad with the same property, but fp32's range is big enough that I don't think it's realistic.
Pull Request resolved: pytorch#36786

Reviewed By: ezyang

Differential Revision: D22202832

Pulled By: ngimel

fbshipit-source-id: b70961a4b6fc3a4c1882f65e7f34874066435735
definitelynotmcarilli authored and facebook-github-bot committed Jun 24, 2020
1 parent d855528 commit b4ccdef
Showing 2 changed files with 78 additions and 3 deletions.
60 changes: 60 additions & 0 deletions test/test_cuda.py
@@ -2057,6 +2057,66 @@ def test_grad_scaling_builtins(self, device="cuda", dtype=torch.float):
         self.assertEqual(growth_tracker, 0)
         self.assertEqual(scale, 2.0)
 
+    def test_grad_scaling_unscale_sparse(self, device="cuda", dtype=torch.float):
+        scaler = torch.cuda.amp.GradScaler()
+
+        inv_scale = torch.tensor([0.25], dtype=dtype, device=device)
+        found_inf = torch.empty((1,), dtype=dtype, device=device)
+        cur = found_inf.device
+
+        # As of d0c925f (4/16/20), docs are unclear about best API for sparse cuda tensor construction.
+        # https://pytorch.org/docs/master/tensors.html shows torch.sparse_coo_tensor(...), but it has no docstring.
+        # The same page shows several tensors with layout=torch.sparse_coo, but no constructors using that layout.
+        # Meanwhile, https://pytorch.org/docs/master/sparse.html shows torch.sparse.FloatTensor(...), which looks
+        # legacy and does not accept a device="cuda" kwarg. Going with torch.sparse_coo_tensor.
+        i = torch.tensor([[0, 1, 1],
+                          [2, 0, 2]], device="cuda", dtype=torch.int64)
+        v = torch.tensor([16., 32., 64.], device="cuda", dtype=torch.float)
+        s = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cuda", dtype=dtype)
+
+        p = s.clone()
+        assert p.is_sparse
+        opt = torch.optim.SGD([p], lr=1.)
+
+        p.grad = s.clone()
+        found_inf.zero_()
+        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
+        self.assertEqual(found_inf, 0.0)
+        self.assertTrue(torch.allclose(p.grad.to_dense(), (s / 4).to_dense()))
+
+        v = torch.FloatTensor([16., 32., float('inf')])
+        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cuda", dtype=dtype)
+        found_inf.zero_()
+        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
+        self.assertEqual(found_inf, 1.0)
+
+        v = torch.FloatTensor([16., 32., float('nan')])
+        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cuda", dtype=dtype)
+        found_inf.zero_()
+        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
+        self.assertEqual(found_inf, 1.0)
+
+        p = s.clone().half()
+        assert p.is_sparse
+        opt = torch.optim.SGD([p], lr=1.)
+
+        p.grad = s.clone().half()
+        found_inf.zero_()
+        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur]
+        self.assertEqual(found_inf, 0.0)
+        self.assertTrue(torch.allclose(p.grad.to_dense(), (s.half() / 4).to_dense()))
+
+        # Creates fp16 sparse tensor with duplicated indices (uncoalesced). The uncoalesced representation
+        # does not overflow in fp16, but the coalesced representation would, because 64000 + 64000 > fp16 max.
+        # _amp_non_finite_check_and_unscale_ should report an overflow here.
+        i = torch.LongTensor([[0, 1, 0],
+                              [2, 0, 2]])
+        v = torch.FloatTensor([64000., 32., 64000.])
+        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cuda", dtype=torch.float16)
+        found_inf.zero_()
+        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur]
+        self.assertEqual(found_inf, 1.0)
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_grad_scaling_device_as_key(self):
         # Ensure that different instances of "device" objects that point to the same device
21 changes: 18 additions & 3 deletions torch/cuda/amp/grad_scaler.py
@@ -186,9 +186,21 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
                     if (not allow_fp16) and param.grad.dtype == torch.float16:
                         raise ValueError("Attempting to unscale FP16 gradients.")
                     else:
-                        torch._amp_non_finite_check_and_unscale_(param.grad,
-                                                                 per_device_found_inf.get(param.grad.device),
-                                                                 per_device_inv_scale.get(param.grad.device))
+                        with torch.no_grad():
+                            if param.grad.is_sparse:
+                                # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                                # coalesce() deduplicates indices and adds all values that have the same index.
+                                # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                                # so we should check the coalesced _values().
+                                if param.grad.dtype is torch.float16:
+                                    param.grad = param.grad.coalesce()
+                                to_unscale = param.grad._values()
+                            else:
+                                to_unscale = param.grad
+
+                            torch._amp_non_finite_check_and_unscale_(to_unscale,
+                                                                     per_device_found_inf.get(param.grad.device),
+                                                                     per_device_inv_scale.get(param.grad.device))
 
         return per_device_found_inf._per_device_tensors
 
@@ -220,6 +232,9 @@ def unscale_(self, optimizer):
             :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
             and only after all gradients for that optimizer's assigned parameters have been accumulated.
             Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
+
+        .. warning::
+            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
         """
         if not self._enabled:
             return
