Skip to content

Commit

Permalink
[MPS] Fix crashes in several backward ops (pytorch#94343)
Browse files Browse the repository at this point in the history
This should fix the hard crashes in several backward-pass ops for sigmoid, tanh, masked_fill, linear, prelu, etc.
The test cases that this patch fixes are part of a bigger change in TestConsistency and will be upstreamed as a separate PR.

Pull Request resolved: pytorch#94343
Approved by: https://github.com/kulinseth, https://github.com/malfet
  • Loading branch information
razarmehr authored and pytorchmergebot committed Feb 8, 2023
1 parent 61ecaf1 commit 877482e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
9 changes: 9 additions & 0 deletions aten/src/ATen/native/mps/operations/Activation.mm
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ Tensor relu_mps(const Tensor& self) {
using namespace mps;
TORCH_CHECK(grad_input.is_mps());

if (grad_output.numel() == 0) {
return;
}
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
Expand Down Expand Up @@ -496,6 +499,9 @@ Tensor relu_mps(const Tensor& self) {
using namespace mps;
TORCH_CHECK(grad_input.is_mps());

if (grad_output.numel() == 0) {
return;
}
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
Expand Down Expand Up @@ -1686,6 +1692,9 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {

Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous);
if (grad_output.numel() == 0) {
return std::tuple<Tensor, Tensor>{grad_input, weight_grad};
}

struct CachedGraph : public MPSCachedGraph
{
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mps/operations/Indexing.mm
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,10 @@ Tensor index_select_mps(const Tensor & self,

Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value) {
using namespace mps;

if (self.numel() == 0) {
return self;
}
TORCH_CHECK(self.device() == mask.device(), "expected self and mask to be on the same device, but got mask on ",
mask.device(), " and self on ", self.device());
TORCH_CHECK(mask.scalar_type() == kByte || mask.scalar_type() == kBool,
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/native/mps/operations/Linear.mm
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ Tensor _mps_linear_backward_input(
c10::nullopt,
grad_output.suggest_memory_format());
TORCH_CHECK(output.is_mps());
if (grad_output.numel() == 0) {
return output;
}

MPSGraphCache *cache_ = MPSGraphCache::getInstance();

Expand Down Expand Up @@ -259,6 +262,11 @@ Tensor _mps_linear_backward_input(
TORCH_CHECK(output.is_mps());
TORCH_CHECK(bias.is_mps());

if (grad_output.numel() == 0) {
output.zero_();
bias.zero_();
return std::tuple<Tensor, Tensor>{ output, bias };
}
MPSGraphCache *cache_ = MPSGraphCache::getInstance();

MPSStream *stream= getCurrentMPSStream();
Expand Down

0 comments on commit 877482e

Please sign in to comment.