Op lowering for Einsum #3843
Conversation
FWIW, under the reviewer name there is a link that switches your PR to draft.
Force-pushed from 6ea496a to 78319cb (compare)
Force-pushed from 86e454e to 4cb81ab (compare)
torch_xla/csrc/aten_xla_type.cpp (outdated)
  XLA_FN_COUNTER("xla::");
  // Einsum operations with more than 2 operands, like bilinear operations, are
  // not currently supported in XLA
  if (tensors.size() > 2) {
@bdhirsh This is a bit tricky: we want to override einsum, but we can only support a certain type of einsum. Looking at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_autograd_ops.cpp#L31, it seems like we actually need to let pt/xla take care of falling back for both forward and backward. However, there is no such thing as einsum_backward, so what are we going to do in the backward case?
On second thought, we should not fall back but instead call the at::native function, like https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L3086, which will re-decompose einsum into smaller ops. I think this should solve the backward issue for unsupported einsum too.
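For illustration, a minimal sketch of what that fallback could look like in aten_xla_type.cpp. The class name, the at::native::einsum signature, and the DispatchXlaEinsum helper are assumptions for this sketch, not the exact PR code:

```cpp
// Hedged sketch: for einsum expressions XLA cannot lower directly, hand the call
// back to the upstream decomposition instead of falling back to CPU, so the
// smaller ops it produces are lowered to XLA individually.
at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
                                      at::TensorList tensors) {
  XLA_FN_COUNTER("xla::");
  // Einsum operations with more than 2 operands, like bilinear operations, are
  // not currently supported in XLA.
  if (tensors.size() > 2) {
    return at::native::einsum(equation, tensors);
  }
  // Supported case (1 or 2 operands): take the XLA einsum lowering. The call
  // here is a placeholder; see the autograd-function sketch further down.
  return DispatchXlaEinsum(equation, tensors);  // hypothetical helper
}
```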
Yes, I think you want to fall back by calling into at::native::einsum. Falling back to CPU seems like a big pessimisation; instead, you just want to fall back to the existing decomposition in core, and run XLA on the decomposed ops.
I also... think that should fix the requires_grad issue that you're seeing, but it's worth confirming 😛.
Updated to at::native::einsum here, but there's still the issue of requires_grad for cases with 1 or 2 operands. We want to use aten_autograd_ops::EinsumAutogradFunction::apply there to leverage XLA's einsum implementation.
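As a rough illustration of that plan, here is a sketch of what such an autograd function could look like, modeled on the existing max_pool pattern in aten_autograd_ops.cpp. The saved-data key, the Dispatch* helpers, and the backward return convention (an undefined gradient slot for the non-tensor equation argument) are assumptions here, not the actual PR code:

```cpp
// Hedged sketch of an einsum custom autograd function (names illustrative).
struct EinsumAutogradFunction
    : public torch::autograd::Function<EinsumAutogradFunction> {
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               c10::string_view equation,
                               at::TensorList tensors) {
    ctx->saved_data["equation"] = std::string(equation);
    // Save the operands so backward can build the einsum gradients in XLA.
    ctx->save_for_backward(
        std::vector<torch::Tensor>(tensors.begin(), tensors.end()));
    return DispatchXlaEinsumForward(equation, tensors);  // hypothetical helper
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    std::string equation = ctx->saved_data["equation"].toStringRef();
    torch::autograd::variable_list saved = ctx->get_saved_variables();
    // One gradient per saved operand, computed through the XLA lowering.
    torch::autograd::variable_list grads =
        DispatchXlaEinsumBackward(equation, grad_output[0], saved);  // hypothetical
    // Undefined gradient for the non-tensor `equation` argument, assumed to
    // follow the same convention as the max_pool autograd functions.
    grads.insert(grads.begin(), torch::Tensor());
    return grads;
  }
};
```

In the supported 1- or 2-operand case, the einsum override would then call EinsumAutogradFunction::apply(equation, tensors) instead of at::native::einsum.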
Ah, so it's working in the case where you call at::native::einsum, but it's losing the requires_grad in the case where it calls the custom autograd function?
If it's an issue with the autograd function, what I would check first is:
(1) Did you implement the autograd function the same way that the existing ones in pytorch_xla are implemented (e.g. max_pool2d)?
(2) If so, can you repro the issue with max_pool2d too?
As far as I know, this einsum autograd function is implemented the same way as the max_pool functions. However, I cannot reproduce the issue with max_pool2d.
Max pool 2d was overridden as an autograd function in #2236. Since then it looks like we've removed scripts/gen.py, but I believe all I have to do to set up the code generation now is add einsum to the autograd section of xla_native_functions.yaml. Could there be additional setup required, either in this repo or the parent repo? CC @JackCaoG
It doesn't look like we explicitly do anything in the max_pool implementations to forward requires_grad to the output. Where is that happening?
@bdhirsh I think one difference is that pytorch does not have a backward formula for einsum; it will always get dispatched to smaller ops. Not sure if this will affect the behavior of requires_grad.
Right now, there's an issue with the result of einsum not requiring grad, even when the inputs require grad. I believe this may be because there is no … We can also recreate this with cpp tests like …
@bdhirsh Can you help on #3843 (comment)? I am actually not sure where the …
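For context on the symptom being discussed (this is not the cpp test referenced above, which is elided in the thread), a minimal, self-contained repro might look like the following; substituting an XLA device is assumed to be where the requires_grad gets lost:

```cpp
// Hypothetical repro sketch: both inputs require grad, so the einsum result is
// expected to carry a grad_fn. On CPU this prints `true`; the bug under
// discussion is that the XLA einsum path was returning a detached result.
#include <iostream>
#include <torch/torch.h>

int main() {
  // Replace with torch::Device("xla:0") when running against torch_xla.
  torch::Device device(torch::kCPU);

  torch::Tensor a = torch::rand(
      {3, 4}, torch::TensorOptions().device(device).requires_grad(true));
  torch::Tensor b = torch::rand(
      {4, 5}, torch::TensorOptions().device(device).requires_grad(true));

  torch::Tensor out = torch::einsum("ij,jk->ik", {a, b});
  std::cout << std::boolalpha << out.requires_grad() << std::endl;
  return 0;
}
```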
Force-pushed from 4cb81ab to 5fdf370 (compare)
@ezyang I am wondering if you have any idea regarding #3843 (comment). This is blocking one of our benchmark experiments.
I'll take a look, but what I expected is for you to override AutogradXLA and use the C++ custom autograd function API to set up your derivative. Do you have all of this?
Yup. We …
I was trying to look up how … The issue right now is that after we do …
I am able to confirm that with this PR, the forward pass computes the correct result but does not have grad_fn.
If I do the same thing with maxpool2d, which is another function whose backward we overwrite using …
I think the problem is that torch::autograd::Function doesn't support variable list inputs. You need to add support for this in torch/csrc/autograd/custom_function.h.
Force-pushed from 5fdf370 to 674f2cd (compare)
After adding support for … Edit: It looks like …
Force-pushed from 5b325ce to c998b82 (compare)
  output_shape = shape_one;
}

return output_shape;
Can't we just call BuildEinsumBackward once and then walk through the return vector and decide if we need a tuple shape?
Done, but I needed to add an InferOutputShapes helper to handle the std::vector<xla::XlaOp> output from BuildEinsumBackward.
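For illustration, a sketch of what that could look like. The BuildEinsumBackward signature, the StatusOr handling, and the function name are assumptions here, not the PR's exact code, and the surrounding xla client headers are assumed to be available:

```cpp
// Hedged sketch: lower the backward once, collect every gradient's shape, and
// only build a tuple shape when there is more than one gradient.
xla::Shape InferEinsumBackwardOutputShape(xla::XlaBuilder* builder,
                                          xla::XlaOp grad_output,
                                          absl::Span<const xla::XlaOp> inputs,
                                          const std::string& equation) {
  std::vector<xla::XlaOp> grads =
      BuildEinsumBackward(grad_output, inputs, equation);  // signature assumed

  std::vector<xla::Shape> shapes;
  shapes.reserve(grads.size());
  for (const xla::XlaOp& grad : grads) {
    shapes.push_back(builder->GetShape(grad).value());
  }

  // One gradient (single-operand einsum) keeps a plain array shape; two
  // gradients are exposed as a tuple shape on the IR node.
  return shapes.size() == 1 ? shapes.front()
                            : xla::ShapeUtil::MakeTupleShape(shapes);
}
```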
Force-pushed from c998b82 to e449a12 (compare)
Force-pushed from e449a12 to 36c3d10 (compare)
Force-pushed from 36c3d10 to 3369eed (compare)
Thanks @steventk-g!
This is required to unblock pytorch/xla#3843, which lowers the einsum op for pytorch/xla. Because one method input parameter is a TensorList, we need to support TensorLists here so that we can support einsum gradients. Pull Request resolved: #84583 Approved by: https://github.com/soulitzer
Implements op lowering for einsum and einsum backward when (1) there are at most 2 inputs and (2) there are no equations (forward or backward) that have an index in one element (input or output) which is absent from all other elements. When these conditions are not met, we fall back to the at::native implementation, which will break the einsum op down into its constituent operations.
If we want to relax condition (2), then we need a change in XLA to support those kinds of einsum equations; currently, such equations lead to an INVALID_ARGUMENT status when trying to get the shape of the output. Likewise, if we want to relax condition (1), we either need a change in XLA or a change upstream to break down einsums with 3 or more inputs.
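To illustrate condition (2): an index like j in "ij->i" appears in only one element (the single input) and nowhere else, so that equation would take the at::native path, while "ij,jk->ik" is supported because every index appears in at least two elements. Below is a hedged, self-contained sketch of what such a check might look like; the helper name and the parsing are hypothetical, not the PR's actual code, and ellipsis/implicit-output equations are ignored:

```cpp
// Hypothetical helper: returns true if the equation has at most two inputs and
// every index appears in at least two elements (inputs or output), i.e. the
// cases the XLA lowering supports per the PR description.
#include <string>
#include <unordered_map>
#include <vector>

bool EinsumSupportedByXla(const std::string& equation) {
  // Split "lhs1,lhs2->rhs" into its elements. (No '...' or implicit-output
  // handling here; this only illustrates conditions (1) and (2).)
  std::vector<std::string> elements;
  std::string current;
  for (size_t i = 0; i < equation.size(); ++i) {
    char c = equation[i];
    if (c == ',') {
      elements.push_back(current);
      current.clear();
    } else if (c == '-' && i + 1 < equation.size() && equation[i + 1] == '>') {
      elements.push_back(current);
      current.clear();
      ++i;  // skip '>'
    } else if (c != ' ') {
      current.push_back(c);
    }
  }
  elements.push_back(current);

  // Condition (1): at most two inputs (the last element is the output).
  if (elements.size() > 3) {
    return false;
  }

  // Condition (2): count how many elements each index appears in.
  std::unordered_map<char, int> element_count;
  for (const std::string& element : elements) {
    std::string seen;
    for (char c : element) {
      if (seen.find(c) == std::string::npos) {
        seen.push_back(c);
        ++element_count[c];
      }
    }
  }
  for (const auto& kv : element_count) {
    if (kv.second < 2) {
      return false;  // e.g. 'j' in "ij->i" appears in only one element.
    }
  }
  return true;
}
```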