
Conversation

aseyboldt
Member

@aseyboldt aseyboldt commented Sep 16, 2025

revisit #1120, which seems abandoned.

@jdehning I hope it is ok if I continue this PR?


📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/

@aseyboldt aseyboldt force-pushed the as-jax-opt2 branch 3 times, most recently from 10dfa2e to ead4ac7 on September 16, 2025 at 15:08

@Copilot Copilot AI left a comment


Pull Request Overview

This PR implements the as_jax_op decorator which allows JAX functions to be used within PyTensor graphs. The decorator wraps JAX functions to make them compatible with PyTensor's variable system while preserving gradient computation capabilities.

  • Implements JAXOp class for wrapping JAX functions as PyTensor operations
  • Creates as_jax_op decorator for easy conversion of JAX functions to PyTensor-compatible operations
  • Adds comprehensive test coverage for various input/output patterns and data types
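
For orientation, a rough usage sketch of the decorator under review (hedged: the exact import path, decorator signature, and static-shape requirement follow the discussion later in this thread, not a finalized API):

import jax.numpy as jnp
import pytensor.tensor as pt
from pytensor import as_jax_op, function

@as_jax_op
def jax_softplus(x):
    # plain JAX code; the decorator makes it callable on PyTensor variables
    return jnp.log1p(jnp.exp(x))

x = pt.vector("x", shape=(3,))  # static shape, see the shape discussion below
y = jax_softplus(x)             # a PyTensor variable; gradients go through JAX's VJP
f = function([x], y)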

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file

  • pytensor/link/jax/ops.py: Core implementation of the JAXOp class and the as_jax_op decorator
  • tests/link/jax/test_as_jax_op.py: Comprehensive test suite covering various use cases and data types
  • pytensor/__init__.py: Exports as_jax_op with a fallback for missing dependencies
  • pytensor/link/jax/dispatch/basic.py: JAX dispatch registration for the new JAXOp
  • doc/library/index.rst: Documentation entry for the new functionality
  • doc/environment.yml: Updates the documentation environment to include JAX dependencies
  • doc/conf.py: Adds Equinox to the intersphinx mapping
  • .github/workflows/test.yml: Updates CI to install the Equinox dependency



codecov bot commented Sep 16, 2025

Codecov Report

❌ Patch coverage is 89.44444% with 19 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.66%. Comparing base (1dc982c) to head (fceed2c).
⚠️ Report is 18 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/jax/ops.py | 88.95% | 10 Missing and 9 partials ⚠️

❌ Your patch check has failed because the patch coverage (89.44%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1614      +/-   ##
==========================================
- Coverage   81.69%   81.66%   -0.04%     
==========================================
  Files         230      232       +2     
  Lines       52950    53177     +227     
  Branches     9404     9417      +13     
==========================================
+ Hits        43260    43428     +168     
- Misses       7256     7292      +36     
- Partials     2434     2457      +23     
Files with missing lines | Coverage Δ
pytensor/compile/ops.py | 83.91% <100.00%> (+0.46%) ⬆️
pytensor/link/jax/dispatch/basic.py | 83.52% <100.00%> (+0.81%) ⬆️
pytensor/link/jax/ops.py | 88.95% <88.95%> (ø)

... and 37 files with indirect coverage changes


if any(s is None for s in shape):
    _, shape = pt.basic.infer_static_shape(var.shape)
    if any(s is None for s in shape):
        raise ValueError(
Member


Can we use this instead? https://docs.jax.dev/en/latest/export/shape_poly.html#shape-polymorphism

PyTensor only needs to know the dtype and rank of the outputs
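
For reference, a minimal sketch of the shape-polymorphism API being referred to (assuming the current jax.export interface): a single symbolic dimension shared by both inputs exports fine, which is what the follow-up comments probe.

import numpy as np
import jax
from jax import export

# One symbolic dimension "n" used by both inputs: export succeeds and the
# output aval keeps the symbolic shape (n,).
(n,) = export.symbolic_shape("n")
x = jax.ShapeDtypeStruct((n,), dtype=np.int32)
y = jax.ShapeDtypeStruct((n,), dtype=np.int32)

exported = export.export(jax.jit(lambda a, b: a + b))(x, y)
print(exported.out_avals)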

Member Author


I don't think that would be reliable.
I think jax will throw an error if the code tries to broadcast arrays when it cannot prove that they have compatible shapes.
If we have dims, we could use those to generate jax symbolic shapes. (But only with object dims, not with string ones I think?).

Member Author


The best we could do right now is create a new shape variable for every input dimension that is not statically known.
But then it would fail as soon as you even add two of those together.

Member


Does it fail or does it infer they must match?

Member Author


It fails:

import numpy as np
import jax
from jax import export

u, v = export.symbolic_shape("u, v")

x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((v,), dtype=np.int32)

export.export(jax.jit(lambda x, y: x + y))(x1, x2)
# add got incompatible shapes for broadcasting: (u,), (v,).

@aseyboldt
Member Author

I kept most of what was in the original PR, but made a few changes:

  • There is no longer a separate op for the gradient; that is just another JAXOp.
  • I kept support for jax tree inputs and outputs; I think those are quite valuable. For instance, when we have a neural network in a model, or if we want to solve an ODE, it is much nicer if we don't have to take apart all jax trees everywhere by hand. I did remove the wrapping of returned functions though. That led to some trouble if the jax trees contain callables that should not be wrapped, and it seems overall a bit hackish to me. I also can't think of a use case where we would really need it. If one does come along, maybe we can revisit this idea.

I think the biggest problem right now is that the as_jax_op wrapper needs pytensor inputs with static shapes.

If we want to avoid having the user specify the output types, we need to call the function at least once. We can do that with jax.eval_shape, but that still needs static shape info.
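
For illustration, a minimal sketch of that inference step with jax.eval_shape; note that the abstract inputs themselves already need concrete shapes, which is exactly the limitation described here:

import numpy as np
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x, y)

# eval_shape traces f abstractly, without computing anything, but the
# ShapeDtypeStructs must already carry static shapes.
x_spec = jax.ShapeDtypeStruct((3, 4), np.float32)
y_spec = jax.ShapeDtypeStruct((4, 2), np.float32)
out = jax.eval_shape(f, x_spec, y_spec)
print(out.shape, out.dtype)  # (3, 2) float32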

I'm not sure right now what the best way to handle this is.

@jdehning

revisit #1120, which seems abandoned.

@jdehning I hope it is ok if I continue this PR?

📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/

Yes, sure. Sorry that I dropped the ball.

@jdehning

I think the biggest problem right now is that the as_jax_op wrapper needs pytensor inputs with static shapes.

I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR. It worked for me for PyMC models where the initial static shapes are lost because of a pt.cumsum. However, if I remember correctly, I didn't test whether it works with pm.Data, i.e. shared variables in the graph, or what happens when the shape of a shared variable is changed between runs by setting new values.
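
A rough sketch of the kind of recovery being described, using the infer_static_shape helper that this PR already calls (whether the static shape is actually lost depends on the op and PyTensor version; the cumsum case is the one reported above):

import pytensor.tensor as pt

x = pt.vector("x", shape=(5,))
y = pt.cumsum(x)                 # reportedly drops the static length here
print(y.type.shape)              # may be (None,)

# Static shape inference on the shape graph can often recover it:
_, static_shape = pt.basic.infer_static_shape(y.shape)
print(static_shape)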

@jdehning

I did remove the wrapping of returned functions though. That led to some trouble if the jax trees contain callables that should not be wrapped, and it seems overall a bit hackish to me. I also can't think of a use case where we would really need it. If one does come along, maybe we can revisit this idea.

I wrote it for ODEs that depend on time-dependent parameters; we need a function that takes a time point and returns some time-changing variables that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve it, as it allows defining the interpolation function and ODE solver separately. However, I agree it was somewhat hackish and not easily parsable. And its usage can be reasonably well avoided if both the interpolation function and the ODE solver are defined in a single function.

@aseyboldt
Member Author

Yes, sure. Sorry that I dropped the ball.

No worries :-)

I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR.

That works for cases where the shape is in principle known, but pytensor missed it during the static shape inference.
It does not help if the shape really is dynamic, which is the case if you use dimensions in a pymc model.

In the case where the input shapes only depend on random variables (which will usually be the case in pymc models), Ricardo realized that we can just eval the shapes once. It should also be fine if we just find some valid input shapes; if the shapes change later, that shouldn't be a problem. We only have to set the output shapes to None in that case, because we don't actually know what the output shapes should be.
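
A rough sketch of that idea (hypothetical example, not code from this PR): when the only reason a shape is unknown is that it depends on a random variable, the shape graph can be evaluated once, without user-provided inputs, to obtain one valid concrete shape for tracing the JAX function.

from pytensor.tensor.random import normal, poisson

# The length of x is itself a draw, so there is no static shape, but the
# shape graph can still be evaluated once using the default RNG state.
k = poisson(10)
x = normal(size=(k,))
concrete_shape = tuple(int(s) for s in x.shape.eval())
print(concrete_shape)  # one valid shape, used only for tracing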

I wrote it for ODEs that depend on time-dependent parameters; we need a function that takes a time point and returns some time-changing variables that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve it, as it allows defining the interpolation function and ODE solver separately. However, I agree it was somewhat hackish and not easily parsable. And its usage can be reasonably well avoided if both the interpolation function and the ODE solver are defined in a single function.

I thought it might have been something like that. I think that makes sense, but for now I'd rather get this in as-is. We can always revisit this later if we think we need it.

@aseyboldt
Member Author

We also should decide on a name: I think wrap_jax is maybe better than as_jax_op?
Or possibly jax_to_pytensor or jax_bridge?

@aseyboldt aseyboldt marked this pull request as ready for review September 16, 2025 18:47
)

# Create VJP operation
vjp_op = JAXOp(
Member


PyTensor won't be able to fuse duplicate applications of the grad (if they somehow exist), as it will create a new function under the hood and that is what is used for Op equality (whatever is in __props__ is used).

Not a blocker, just laying it out. Alternatively, if the grad were an Op parametrized by the original Op/func and the connected gradients, it could be merged.

However, if running on the JAX backend, JAX itself may be able to avoid the duplication.

I would err on keeping it simple for now, like you did.
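
For context, a toy illustration of the __props__-based equality being referred to (a made-up op, nothing from this PR): two instances with equal props compare equal and can be merged, while a prop holding a freshly created function never will.

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class ScaleOp(Op):
    """Toy op whose identity is fully determined by `factor`."""

    __props__ = ("factor",)

    def __init__(self, factor):
        self.factor = factor

    def make_node(self, x):
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0] * self.factor

assert ScaleOp(2) == ScaleOp(2)  # equal props, so PyTensor can merge them
assert ScaleOp(2) != ScaleOp(3)
# A prop holding a fresh callable (like a newly built VJP function) is never
# equal across instances:
assert (lambda: 0) != (lambda: 0)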

[],
resolved_input_shapes,
on_unused_input="ignore",
mode="FAST_COMPILE",
Member


Suggested change:
-mode="FAST_COMPILE",
+mode=Mode(linker="py", optimizer="fast_compile"),

No reason for C code I guess.

@ricardoV94
Member

I like as_jax_op in line with as_op?

They do the same except in this case we also know how to generate gradients and dispatch to jax?

@ricardoV94
Member

Alternatively, jax_as_op?

@aseyboldt
Member Author

as_jax_op sounds like it should do the opposite of what it does: export something to jax. jax_as_op solves that problem...

I don't think Op is a good name to put into the public interface that much. Most people who use pymc won't even know what we mean by it. The function also doesn't return an op, it just happens to create one internally, and doesn't even show that to the user.

I see the argument of it being like as_op, but I also don't quite like that name for the same reasons, and I don't think many people are using it in the first place.

Other options: pytensorize_jax or jax_to_pytensor?

I think something that describes what it does, instead of how it does it, would be much friendlier?

@jdehning

+1 for jax_to_pytensor. It is, for me, the most easily understandable.

Member

@ricardoV94 ricardoV94 left a comment


Small comments, besides that we should decide on the name. Don't love jax/ops.py filename either.

Comment on lines 117 to 118
for i, result in enumerate(results):
    outputs[i][0] = np.array(result, dtype=self.output_types[i].dtype)
Member


Small opt

Suggested change:
-for i, result in enumerate(results):
-    outputs[i][0] = np.array(result, dtype=self.output_types[i].dtype)
+for out_container, result, out_type in zip(outputs, results, self.output_types):
+    out_container[0] = np.array(result, dtype=out_type.dtype)


def make_node(self, *inputs: Variable) -> Apply:
    """Create an Apply node with the given inputs and inferred outputs."""
    outputs = [output_type() for output_type in self.output_types]
Member


assert the number of inputs is as expected?
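
A minimal sketch of the suggested check, assuming the op stores its input types on self.input_types as elsewhere in this PR (the Apply construction below is an assumption about the part of the method not shown above):

from pytensor.graph.basic import Apply, Variable

def make_node(self, *inputs: Variable) -> Apply:
    """Create an Apply node with the given inputs and inferred outputs."""
    if len(inputs) != len(self.input_types):
        raise ValueError(
            f"{self} expected {len(self.input_types)} inputs, got {len(inputs)}"
        )
    outputs = [output_type() for output_type in self.output_types]
    return Apply(self, list(inputs), outputs)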

Comment on lines 95 to 102
def __repr__(self):
    base = self.__class__.__name__
    if self.name is not None:
        base = f"{base}{self.name}"
    props = list(self.__props__)
    props.remove("name")
    props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props)
    return f"{base}({props})"
Member

@ricardoV94 ricardoV94 Sep 23, 2025


Stick with the default __repr__ that already exists based on props?

Member


The getattr(self, prop, '?') shouldn't be needed. If a prop is missing from an Op, we have bigger issues to worry about (equality/hashing).

Member Author


I thought it would be good if the name of the jax function appeared in the node name somewhere. Otherwise it can be tricky to figure out which JAXOp is which?

>>>
>>> # Create the input and output types, input has a dynamic shape.
>>> input_type = TensorType("float32", shape=(None,))
>>> output_type = TensorType("float32", shape=(1,))
Member


Isn't the output_type a scalar?

Member Author


bad llm ;-)


def make_node(self, *inputs: Variable) -> Apply:
    """Create an Apply node with the given inputs and inferred outputs."""
    outputs = [output_type() for output_type in self.output_types]
Member


make_node should convert the inputs to tensors:

inputs = [inp_type.filter_variable(inp) for inp, inp_type in zip(inputs, self.input_types)]

Comment on lines 132 to 135
connected_output_indices = []
for i, output_grad in enumerate(output_gradients):
    if not isinstance(output_grad.type, DisconnectedType):
        connected_output_indices.append(i)
Member


list_comp?
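
i.e. something along these lines (same behavior, written as a comprehension):

connected_output_indices = [
    i
    for i, output_grad in enumerate(output_gradients)
    if not isinstance(output_grad.type, DisconnectedType)
]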

Comment on lines 175 to 180
gradient_outputs = vjp_op(
    *[*inputs, *[output_gradients[i] for i in connected_output_indices]]
)
if not isinstance(gradient_outputs, Sequence):
    gradient_outputs = [gradient_outputs]
return gradient_outputs
Member


Suggested change:
-gradient_outputs = vjp_op(
-    *[*inputs, *[output_gradients[i] for i in connected_output_indices]]
-)
-if not isinstance(gradient_outputs, Sequence):
-    gradient_outputs = [gradient_outputs]
-return gradient_outputs
+return vjp_op(
+    *[*inputs, *[output_gradients[i] for i in connected_output_indices]],
+    return_list=True,
+)

return gradient_outputs


def as_jax_op(jax_function=None, *, allow_eval=True):
Member


We should open an issue to allow itypes/otypes like as_op already does (unless you want to implement it yourself).


Comment on lines +228 to +230
...     return {
...         "sum": jnp.add(x, y) * scale,
...     }
Member

@ricardoV94 ricardoV94 Sep 23, 2025


Is this example still correct? Do we wrap things back into a pytree after the fact? It seems we do.

But we should be running these as part of doctest, and we're not: https://github.com/pymc-devs/pytensor/actions/runs/17780488500/job/50538346433?pr=1614

Member


I guess it doesn't run because as_jax_op is not imported by default?

Member Author


Do you have an idea about how to fix this? I never really looked into how the doctests are executed.

Comment on lines 248 to 257
Notes
-----
The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt,
available at
`pymc-labs.io <https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick
-examples/>`__.
To accept functions and non-PyTensor variables as input, the function uses
:func:`equinox.partition` and :func:`equinox.combine` to split and combine the
variables. Shapes are inferred using
:func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`.
Member


Remove?

@aseyboldt
Member Author

@ricardoV94 I (hopefully) addressed all comments I didn't reply to directly.

On top of that, I also removed the equinox dependency. We only needed two small functions from it, so I just copied those over with a note in the source. It's Apache 2.0, so I think that's ok? (Also, it's just a few lines anyway...).
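
For reference, a rough sketch of what those two helpers do (shown here with equinox itself, purely for illustration; the PR now carries local copies):

import equinox as eqx
import jax.numpy as jnp

# partition splits a pytree into (array leaves, everything else) according to
# a filter; combine merges the two halves back into the original structure.
tree = {"weights": jnp.ones(3), "activation": jnp.tanh}
arrays, static = eqx.partition(tree, eqx.is_array)
restored = eqx.combine(arrays, static)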

I still like the names wrap_jax and wrap_py. Not sure what to do about this. If you prefer as_jax_op, we can also stick with that; it should be easy to just undo those two renaming commits.
