Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow torch.cuda.amp.GradScaler to support sparse gradients (pytorch#…
…36786) Summary: Should close pytorch#35810. I decided to keep sparse handling on the Python side for clarity, although it could be moved to the C++ side (into `_amp_non_finite_check_and_unscale_`) without much trouble. For non-fp16 sparse grads the logic is simple (call `_amp_non_finite_check_and_unscale_` on `grad._values()`) instead of `grad` itself. At least I hope it's that easy. For fp16 sparse grads, it's tricker. Sparse tensors can be uncoalesced. From the [Note](https://pytorch.org/docs/master/sparse.html#torch.sparse.FloatTensor): > Our sparse tensor format permits uncoalesced sparse tensors, where there may be duplicate coordinates in the indices; in this case, the interpretation is that the value at that index is the sum of all duplicate value entries. An uncoalesced scaled fp16 grad may have values at duplicate coordinates that are all finite but large, such that adding them to make the coalesced version WOULD cause overflows.** If I checked `_values()` on the uncoalesced version, it might not report overflows, but I think it should. So, if the grad is sparse, fp16, and uncoalesced, I still call `_amp_non_finite_check_and_unscale_` to unscale `grad._values()` in-place, but I also double-check the coalesced version by calling a second `_amp_non_finite_check_and_unscale_` on `grad.coalesce()._values()`. `coalesce()` is out-of-place, so this call doesn't redundantly affect `grad._values()`, but it does have the power to populate the same `found_inf` tensor. The `is_coalesced()` check and `coalesce()` probably aren't great for performance, but if someone needs a giant embedding table in FP16, they're better than nothing and memorywise, they'll only create a copy of nnz gradient values+indices, which is still way better than changing the whole table to FP32. An `unscale` variant with liberty to create unscaled grads out-of-place, and replace `param.grad` instead of writing through it, could get away with just one `_amp_non_finite_check_and_unscale_`. It could say `coalesced = grad.coalesced()`, do only the stronger `_amp_non_finite_check_and_unscale_` on `coalesced._values()`, and set `param.grad = coalesced`. I could even avoid replacing `param.grad` itself by going one level deeper and setting `param.grad`'s indices and values to `coalesced`'s, but that seems brittle and still isn't truly "in place". ** you could whiteboard an uncoalesced fp32 grad with the same property, but fp32's range is big enough that I don't think it's realistic. Pull Request resolved: pytorch#36786 Reviewed By: ezyang Differential Revision: D22202832 Pulled By: ngimel fbshipit-source-id: b70961a4b6fc3a4c1882f65e7f34874066435735
- Loading branch information