
Make einsum a leaf #1364

Open
ailzhang opened this issue Nov 14, 2019 · 11 comments
Labels: enhancement (New feature or request), nostale (Do not consider for staleness)

Comments

@ailzhang
Contributor

#1225
XLA has an optimized einsum implementation that we can use. This requires a change upstream.

@ailzhang ailzhang added the enhancement New feature or request label Nov 14, 2019
@ailzhang ailzhang self-assigned this Nov 14, 2019
@stale stale bot added the stale Has not had recent activity label Dec 14, 2019
@jysohn23 jysohn23 added the nostale Do not consider for staleness label Dec 14, 2019
@stale stale bot removed the stale Has not had recent activity label Dec 14, 2019
@pytorch pytorch deleted a comment from stale bot Dec 14, 2019
@JackCaoG
Collaborator

JackCaoG commented Jul 8, 2022

@bdhirsh Do you think this is possible?

@ronghanghu
Collaborator

@JackCaoG This is the einsum op issue we just mentioned. From our earlier profiling, lowering it could potentially bring a 5%+ speedup to several models. However, as mentioned in this issue and #2385, it requires an upstream change in PyTorch to dispatch it.

@ronghanghu
Collaborator

(Oh, it seems that we raced on commenting and you're already on this thread)

@JackCaoG
Collaborator

@ezyang in case you have some insight 😄

@ezyang
Collaborator

ezyang commented Jul 14, 2022

We need to write a backward kernel for einsum. Do you have an einsum_backward op? We could stub one in and just not have an implementation in PT proper.

@JackCaoG
Collaborator

I didn't find anything in a quick search; I'll check with the XLA team regarding an einsum backward.

@JackCaoG
Collaborator

JackCaoG commented Jul 14, 2022

So there isn't an einsum_backward in XLA natively, but after talking with Blake I think we can implement it using einsum itself. In Blake's words:

it is easy enough to do in pytorch xla or in xla client
you just need to parse the comma string
and swap the operand string and output string
for the operand you want to take the derivative with respect to
so "...a,...ab->...b" would go to "...a,...b->...ab"
and "...b,...ab->...a"
to get operand 1 and operand 0 gradients respectively

@ezyang If you could make einsum and einsum_backward leaf nodes, I will try to lower them using xla::Einsum and test them for pytorch/xla.
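Blake's string-swapping recipe can be sketched in a few lines. The snippet below is a hypothetical numpy illustration (einsum_backward is a made-up helper, not an existing PyTorch/XLA API); it only covers the two-operand case where every index of an operand also appears in the other operand or in the output:

```python
import numpy as np

def einsum_backward(eq, a, b, grad_out):
    # Hypothetical helper illustrating the operand/output string swap
    # for a two-operand einsum; not an actual PyTorch/XLA API.
    lhs, out_spec = eq.split("->")
    a_spec, b_spec = lhs.split(",")
    # Gradient w.r.t. a: swap a's spec with the output spec.
    grad_a = np.einsum(f"{out_spec},{b_spec}->{a_spec}", grad_out, b)
    # Gradient w.r.t. b: swap b's spec with the output spec.
    grad_b = np.einsum(f"{a_spec},{out_spec}->{b_spec}", a, grad_out)
    return grad_a, grad_b

# Matrix multiply as a sanity check: for out = a @ b and upstream
# gradient g, the gradients are g @ b.T and a.T @ g respectively.
a = np.random.rand(2, 3)
b = np.random.rand(3, 4)
g = np.ones((2, 4))
grad_a, grad_b = einsum_backward("ij,jk->ik", a, b, g)
assert np.allclose(grad_a, g @ b.T)
assert np.allclose(grad_b, a.T @ g)
```

Note how the swap matches the thread's example: for "...a,...ab->...b", the two backward equations are "...b,...ab->...a" (operand 0) and "...a,...b->...ab" (operand 1).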

@ezyang
Collaborator

ezyang commented Jul 21, 2022

I don't think we need to do anything in PyTorch; just add einsum to the autograd list in xla_native_functions.yaml and then implement the custom autograd function the same as the other ops. We could upstream them but this is probably easiest.
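For context, registering an op under the autograd list in xla_native_functions.yaml looks roughly like the fragment below (an illustrative sketch, not the actual diff; surrounding keys and op names are assumed):

```yaml
# Hypothetical fragment of xla_native_functions.yaml: listing einsum
# under the autograd key routes it through the backend's custom
# autograd function instead of the decomposed upstream kernels.
autograd:
  - einsum
```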

@JackCaoG
Collaborator

Oh OK. It seems like einsum is already something we can lower https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L1892. I will try that. Thanks Ed!

@JackCaoG
Collaborator

@steventk-g You can use https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_autograd_ops.h#L10 as an example to write both the forward and backward parts for einsum. You need to put einsum under here

@steventk-g
Collaborator

After #3843, we will need changes to support (1) einsum on more than 2 inputs and (2) einsum on equations like ijj,k->ik, where one input or output has an index that none of the other inputs or outputs have. For now, we fall back to the at::native implementation in these cases.
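To make the tricky case concrete, here is a small numpy illustration (my own example, not from the thread) of what an equation like ijj,k->ik computes: the repeated j within one operand takes a diagonal, so j appears in neither the output nor the other operand, and the plain operand/output string swap cannot recover its gradient:

```python
import numpy as np

# In "ijj,k->ik" the repeated j takes the diagonal of a's last two axes
# and sums over it; since j shows up only inside one operand, swapping
# operand and output strings cannot reconstruct it in the backward pass.
a = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)
k = np.array([1.0, 2.0])
out = np.einsum("ijj,k->ik", a, k)

# Equivalent computation spelled out:
trace = a.diagonal(axis1=1, axis2=2).sum(axis=1)  # sum_j a[i, j, j]
expected = np.outer(trace, k)
assert np.allclose(out, expected)  # out == [[12., 24.], [39., 78.]]
```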

6 participants