
Op lowering for Einsum #3843

Merged: 2 commits merged into master on Sep 16, 2022

Conversation

steventk-g
Collaborator

@steventk-g steventk-g commented Aug 8, 2022

Implements op lowering for einsum and einsum backward when (1) there are at most 2 inputs and (2) no equation (forward or backward) has an index in one element (input or output) that is absent from every other element. When these conditions are not met, we fall back to the at::native implementation, which breaks the einsum op down into its constituent operations.

If we want to relax condition (2), we need a change in XLA to support those kinds of einsum equations; currently, such equations lead to an INVALID_ARGUMENT status when trying to get the shape of the output. Likewise, if we want to relax condition (1), we either need a change in XLA or a change upstream to break down einsums with 3 or more inputs.
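
To make condition (2) concrete, here is how I read it, shown with the C++ frontend; this example is an illustration on my part, not code from the PR:

#include <torch/torch.h>

int main() {
  auto x = torch::tensor({1.0, 2.0, 3.0}, torch::requires_grad());
  auto y = torch::tensor({4.0, 5.0, 6.0}, torch::requires_grad());
  // "i,j->ij": every index appears in at least two elements (i in the first
  // input and the output, j in the second input and the output), so this
  // equation meets condition (2) and can use the XLA lowering.
  auto outer = at::einsum("i,j->ij", {x, y});
  // "ij->i": the index j appears only in the single input and in no other
  // element, so this equation takes the at::native fallback path.
  auto row_sums = at::einsum("ij->i", {outer});
  return 0;
}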

xla_native_functions.yaml (review thread, outdated and resolved)
@miladm
Collaborator

miladm commented Aug 15, 2022

FWIW, under the reviewer name there is a link that switches your PR to draft.

@steventk-g steventk-g marked this pull request as draft August 15, 2022 15:18
@steventk-g steventk-g force-pushed the einsum-op-lowering branch 6 times, most recently from 6ea496a to 78319cb on August 19, 2022 21:33
@steventk-g steventk-g force-pushed the einsum-op-lowering branch 6 times, most recently from 86e454e to 4cb81ab on August 24, 2022 20:01
XLA_FN_COUNTER("xla::");
// Einsum operations with more than 2 operands, like bilinear operations, are
// not currently supported in XLA
if (tensors.size() > 2) {
Collaborator

@bdhirsh This is a bit tricky: we want to overwrite einsum, but we can only support a certain type of einsum. Looking at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_autograd_ops.cpp#L31 it seems like we actually need to let pt/xla take care of falling back for both forward and backward. However, there is no such thing as einsum_backward, so what are we going to do with the backward case?

Collaborator

On second thought, we should not fall back but instead call the at::native function, like https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L3086, which will re-decompose einsum into smaller ops. I think this should solve the backward issue for unsupported einsum too.

Collaborator

Yes, I think you want to fall back by calling into at::native::einsum. Falling back to CPU seems like a big pessimisation; instead, you just want to fall back to the existing decomposition in core, and run XLA on the decomposed ops.

Collaborator

I also... think that should fix the requires_grad issue that you're seeing, but it's worth confirming 😛.

Collaborator Author

Updated to at::native::einsum here, but there's still the issue of requires_grad for cases with 1 or 2 operands. We want to use aten_autograd_ops::EinsumAutogradFunction::apply there to leverage XLA's einsum implementation.
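
For context, a minimal sketch of the dispatch being described, assuming a wrapper in aten_xla_type.cpp; the signature and the EinsumSupportedByXla helper are illustrative assumptions on my part, not the PR's actual code:

// Illustrative sketch only: route unsupported equations to the at::native
// decomposition, and supported ones to the custom autograd function so that
// XLA's einsum lowering (and its backward) is used.
at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
                                      at::TensorList tensors) {
  XLA_FN_COUNTER("xla::");
  // Hypothetical helper standing in for the PR's actual support check.
  if (tensors.size() > 2 || !EinsumSupportedByXla(equation)) {
    // at::native::einsum re-decomposes einsum into smaller aten ops, which
    // PyTorch/XLA already lowers and which autograd knows how to handle.
    return at::native::einsum(equation, tensors);
  }
  return aten_autograd_ops::EinsumAutogradFunction::apply(equation, tensors);
}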

Collaborator

ah, so it's working in the case where you call at::native::einsum, but it's losing the requires_grad in the case where it calls the custom autograd function?

If it's an issue with the autograd function, what I would check first is:
(1) Did you implement the autograd function the same way that the existing ones in pytorch_xla are implemented (e.g. max_pool2d)?
(2) If so, can you repro the issue with max_pool2d too?

Collaborator Author
@steventk-g steventk-g Aug 29, 2022

As far as I know, this einsum autograd function is implemented the same way as the max_pool functions. However, I cannot reproduce the issue with max_pool2d.

Max pool 2d was overridden as an autograd function in #2236. Since then it looks like we've removed scripts/gen.py, but I believe all I have to do to set up the code generation now is add einsum to the autograd section of xla_native_functions.yaml. Could there be additional setup required, either in this repo or the parent repo? CC @JackCaoG

Collaborator Author

It doesn't look like we explicitly do anything in the max_pool implementations to forward requires_grad to the output. Where is that happening?

Collaborator

@bdhirsh I think one difference is that pytorch does not have a backward formula for einsum; it will always get decomposed into smaller ops. Not sure if this affects the behavior of requires_grad.

@steventk-g
Collaborator Author

steventk-g commented Aug 25, 2022

Right now, there's an issue with the result of einsum not requiring grad, even when the inputs require grad

>>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
>>> y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
>>> z = torch.einsum('i,j->ij', x, y)
>>> z.requires_grad
True
>>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=xm.xla_device())
>>> y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=xm.xla_device())
>>> z = torch.einsum('i,j->ij', x, y)
>>> z.requires_grad
False
>>> 

I believe this may be because there is no einsum_backward in the upstream, so pytorch may not know how to set the grad function on the result tensor. Right now, we define einsum_backward in torch_xla/csrc/ops/xla_ops.cpp

We can also recreate this with cpp tests like AtenXlaTensorTest::TestEinsumOuterBackward

@JackCaoG
Collaborator

@bdhirsh Can you help with #3843 (comment)? I am actually not sure where the grad inheritance logic is coming from.

@JackCaoG
Collaborator

JackCaoG commented Sep 2, 2022

@ezyang I am wondering if you have any idea regarding #3843 (comment). This is blocking one of our benchmark experiments.

@ezyang
Collaborator

ezyang commented Sep 2, 2022

I'll take a look, but what I expected is for you to override AutogradXLA and use the C++ custom autograd function API to set up your derivative. Do you have all this?

@JackCaoG
Collaborator

JackCaoG commented Sep 2, 2022

> I'll take a look, but what I expected is for you to override AutogradXLA and use the C++ custom autograd function API to set up your derivative. Do you have all this?

Yup. We:

  1. put einsum under the autograd section of xla_native_functions.yaml
  2. use aten_autograd_ops::EinsumAutogradFunction::apply in aten_xla_type.cpp
  3. define EinsumAutogradFunction with both forward and backward functions, which create the corresponding XLATensor (IR)

I was trying to look up how requires_grad is being set, but I can't find any reference in the pytorch/xla repo, so I suspect this happens a layer above us.

The issue right now is that after we do torch.einsum (aten_autograd_ops::EinsumAutogradFunction::forward is called), the result tensor's requires_grad is always false. I tried to compare this with the existing maxpool2d, whose backward we also overwrite, but I can't find any difference.

@JackCaoG
Collaborator

JackCaoG commented Sep 2, 2022

I am able to confirm that with this PR, forward computes the correct result but the output does not have a grad_fn:

>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=xm.xla_device())
>>> y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=xm.xla_device())
>>> z = torch.einsum('i,j->ij', x, y)
torch::Tensor EinsumAutogradFunction::forward called
>>> z.requires_grad
False
>>> x.requires_grad
True
>>> y.requires_grad
True
>>> z
tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.],
        [12., 15., 18.]], device='xla:0')
>>> torch.einsum('i,j->ij', x.cpu(), y.cpu())
tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.],
        [12., 15., 18.]], grad_fn=<MulBackward0>)

@JackCaoG
Collaborator

JackCaoG commented Sep 2, 2022

If I do the same thing with maxpool2d, which is another function whose backward we overwrite using the autograd key, I see:

>>> import torch.nn as nn
>>> m = nn.MaxPool2d(3, stride=2)
>>> input = torch.randn(20, 16, 50, 32, requires_grad=True, device=xm.xla_device())
>>> m(input)

          ....
          ...,
          [ 2.0746,  1.7297,  1.5253,  ...,  1.7077,  2.0141,  2.0141],
          [ 1.2569,  1.5253,  1.5253,  ...,  1.1359,  1.1269,  1.6523],
          [ 0.5896,  1.4508,  1.2428,  ...,  1.9659,  1.9659,  1.6523]]]],
       device='xla:0', grad_fn=<MaxPool2dAutogradFunction>>)

@ezyang
Collaborator

ezyang commented Sep 3, 2022

I think the problem is that torch::autograd::Function doesn't support variable_list.

struct ExtractVariables : IterArgs<ExtractVariables> {
  std::vector<bool>& is_var_;
  variable_list& list_;
  ExtractVariables(std::vector<bool>& is_var, variable_list& list)
      : is_var_(is_var), list_(list) {}
  void operator()(const c10::optional<at::Tensor>& x) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (x.has_value() && x.value().defined()) {
      is_var_.push_back(true);
      list_.emplace_back(x.value());
    } else {
      is_var_.push_back(false);
    }
  }
  void operator()(const at::Tensor& x) {
    is_var_.push_back(true);
    list_.emplace_back(x);
  }
  template <typename T>
  void operator()(const T& x) {
    is_var_.push_back(false);
  }
};

you need to add support for variable_list to the operator() overloads

this is in torch/csrc/autograd/custom_function.h
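
For illustration, the missing overload could look roughly like the sketch below, mirroring the existing overloads quoted above; the actual change landed as pytorch/pytorch#84583 and may differ in detail:

// Hypothetical sketch of an ExtractVariables overload for variable_list:
// walk every tensor in the list so that TensorList arguments participate in
// requires_grad tracking the same way individual Tensor arguments do.
void operator()(const variable_list& xs) {
  for (const at::Tensor& x : xs) {
    if (x.defined()) {
      is_var_.push_back(true);
      list_.emplace_back(x);
    } else {
      is_var_.push_back(false);
    }
  }
}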

@steventk-g
Collaborator Author

steventk-g commented Sep 6, 2022

After adding support for variable_list to operator() in ExtractVariables (pytorch/pytorch#84583), I'm still seeing the same problem.

Edit:

It looks like void operator()(const T& x) is still being called, even after adding the variable_list method above.

output_shape = shape_one;
}

return output_shape;
Collaborator

can't we just call BuildEinsumBackward once and then walk through the return vector and decide if we need a tuple shape?

Collaborator Author

Done, but I needed to add an InferOutputShapes helper to handle the std::vector<xla::XlaOp> output from BuildEinsumBackward.
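
For reference, a rough sketch of what such a helper might look like, written against the plain XLA client API; the name, signature, and internals here are my assumptions and the real helper in torch_xla likely differs:

// Hypothetical sketch: create parameters for the given input shapes in a
// throwaway builder, run the lowering function, and read the output shape.
// Multiple outputs (e.g. the two einsum gradients) become a tuple shape.
xla::Shape InferOutputShapes(
    absl::Span<const xla::Shape> input_shapes,
    const std::function<std::vector<xla::XlaOp>(absl::Span<const xla::XlaOp>)>&
        lower_fn) {
  xla::XlaBuilder builder("InferOutputShapes");
  std::vector<xla::XlaOp> params;
  for (size_t i = 0; i < input_shapes.size(); ++i) {
    params.push_back(
        xla::Parameter(&builder, i, input_shapes[i], absl::StrCat("p", i)));
  }
  std::vector<xla::XlaOp> outputs = lower_fn(params);
  xla::XlaOp result =
      outputs.size() == 1 ? outputs.front() : xla::Tuple(&builder, outputs);
  return builder.GetShape(result).value();
}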

Collaborator
@JackCaoG JackCaoG left a comment

Thanks @steventk-g !

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 16, 2022
This is required to unblock pytorch/xla#3843, which lowers the einsum op for pytorch/xla.  Because one method input parameter is a TensorList, we need to support TensorLists here so that we can support einsum gradients.
Pull Request resolved: #84583
Approved by: https://github.com/soulitzer
@JackCaoG JackCaoG merged commit aa73509 into master Sep 16, 2022
mehtanirav pushed a commit to pytorch/pytorch that referenced this pull request Oct 4, 2022 (same change as #84583 above)