Implement as_jax_op #1614
Conversation
Pull Request Overview
This PR implements the as_jax_op decorator, which allows JAX functions to be used within PyTensor graphs. The decorator wraps JAX functions to make them compatible with PyTensor's variable system while preserving gradient computation capabilities (see the usage sketch after the summary list below).
- Implements a JAXOp class for wrapping JAX functions as PyTensor operations
- Creates the as_jax_op decorator for easy conversion of JAX functions to PyTensor-compatible operations
- Adds comprehensive test coverage for various input/output patterns and data types
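A minimal usage sketch of the decorator described above (hedged: the call pattern follows the PR description and tests, the exact API may differ):

```python
import jax.numpy as jnp
import pytensor
import pytensor.tensor as pt
from pytensor import as_jax_op


# Any JAX function; the wrapper infers the output types for PyTensor.
@as_jax_op
def jax_logsumexp(x, y):
    return jnp.log(jnp.sum(jnp.exp(x + y)))


x = pt.vector("x")
y = pt.vector("y")
out = jax_logsumexp(x, y)            # a regular PyTensor variable
grads = pytensor.grad(out, [x, y])   # gradients flow through the wrapped JAX code
f = pytensor.function([x, y], [out, *grads])
```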
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
File | Description
---|---
pytensor/link/jax/ops.py | Core implementation of the JAXOp class and as_jax_op decorator
tests/link/jax/test_as_jax_op.py | Comprehensive test suite covering various use cases and data types
pytensor/__init__.py | Exports as_jax_op with a fallback for missing dependencies
pytensor/link/jax/dispatch/basic.py | JAX dispatch registration for the new JAXOp
doc/library/index.rst | Documentation entry for the new functionality
doc/environment.yml | Updates the documentation environment to include JAX dependencies
doc/conf.py | Adds Equinox to the intersphinx mapping
.github/workflows/test.yml | Updates CI to install the Equinox dependency
Codecov Report
❌ Patch coverage is 89.44%. Your patch check has failed because the patch coverage (89.44%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1614      +/-   ##
==========================================
- Coverage   81.69%   81.66%   -0.04%
==========================================
  Files         230      232       +2
  Lines       52950    53177     +227
  Branches     9404     9417      +13
==========================================
+ Hits        43260    43428     +168
- Misses       7256     7292      +36
- Partials     2434     2457      +23
pytensor/link/jax/ops.py (Outdated)

```python
if any(s is None for s in shape):
    _, shape = pt.basic.infer_static_shape(var.shape)
    if any(s is None for s in shape):
        raise ValueError(
```
Can we use this instead? https://docs.jax.dev/en/latest/export/shape_poly.html#shape-polymorphism
PyTensor only needs to know the dtype and rank of the outputs
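For reference, a hedged sketch of that approach using the jax.export API: tracing with symbolic dimensions pins only the rank and dtype of the outputs, not concrete sizes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# A single symbolic batch dimension; the concrete size is left open.
(b,) = export.symbolic_shape("b")
spec = jax.ShapeDtypeStruct((b, 3), dtype=np.float32)

exported = export.export(jax.jit(lambda x: jnp.sum(x, axis=-1)))(spec)
print(exported.out_avals)  # abstract outputs: float32[b] -> rank and dtype are known
```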
I don't think that would be reliable.
I think jax will throw an error if the code tries to broadcast arrays when it cannot prove that they have compatible shapes.
If we have dims, we could use those to generate jax symbolic shapes. (But only with object dims, not with string ones I think?).
The best we could do right now is create a new shape variable for every input dimension that is not statically known.
But then it would fail as soon as you even add two of those together.
Does it fail or does it infer they must match?
It fails:

```python
import jax
import numpy as np
from jax import export

u, v = export.symbolic_shape("u, v")
x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((v,), dtype=np.int32)
export.export(jax.jit(lambda x, y: x + y))(x1, x2)
# add got incompatible shapes for broadcasting: (u,), (v,).
```
I kept most of what was in the original PR, but made a few changes. I think the biggest problem right now is that, if we want to avoid having the user specify the output types, we need to call the function at least once. We can do that with …, but I'm not sure right now what the best way to handle this is.
Yes, sure. Sorry that I dropped the ball.
I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR. It did work for me for PyMC models where initial static shapes are lost because of a pt.cumsum. However, if I remember correctly, I didn't test whether it works with pm.Data, i.e. shared variables in the graph, and what happens when the shape of shared variables is changed between runs by setting new values.
I wrote it for ODEs that depend on time-dependent parameters; we need a function that takes a time point and returns some time-varying variables that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve this, as it allows defining the interpolation function and the ODE solver separately. However, I agree it was somewhat hackish and not easily parsable, and its usage can be reasonably well avoided if both the interpolation function and the ODE solver are defined in a single function.
No worries :-)
That works for cases where the shape is in principle known, but PyTensor missed it during static shape inference. In the case where the input shapes only depend on random variables (which will usually be the case in PyMC models), we can't recover them that way. Ricardo realized that we can just eval the shapes once. It should also be fine if we find just some valid input shapes; if the shapes change later, that shouldn't be a problem. We only have to set the output shapes to None in that case, because we don't actually know what the output shapes should be.
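A hedged sketch of that "eval the shapes once" idea (not the PR's actual code): if the static shape is incomplete, evaluate the symbolic shape a single time to obtain some valid concrete shape for tracing the JAX function, and keep the Op's output types at shape=None.

```python
def resolve_trace_shape(var):
    """Return a concrete shape that the JAX function can be traced with."""
    shape = var.type.shape  # tuple with None for statically unknown dims
    if not any(s is None for s in shape):
        return shape
    # Evaluating the symbolic shape works when it only depends on constants,
    # shared variables, or random variables with default RNGs (the usual
    # PyMC case); the result is just *a* valid shape, not a fixed one.
    return tuple(int(s) for s in var.shape.eval())
```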
I thought it might have been something like that. I think that makes sense, but for now I'd rather get this in as-is. We can always revisit this later if we think we need it.
We also should decide on a name: I think …
```python
)

# Create VJP operation
vjp_op = JAXOp(
```
PyTensor won't be able to fuse duplicate applications of the grad (if they somehow exist), as it will create a new function under the hood and that's used for the Op equality (whatever is in __props__ is used).
Not a blocker, just laying it out. Alternatively, if the grad were an Op parametrized by the original Op/func and the connected gradients, it could be merged.
However, if running on the JAX backend, JAX itself may be able to avoid duplication.
I would err on keeping it simple for now, like you did.
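For illustration, a small sketch (hypothetical Op, not from this PR) of how __props__-based equality is what the merge optimization keys on:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class ScaleOp(Op):
    __props__ = ("factor",)

    def __init__(self, factor):
        self.factor = factor

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = np.asarray(
            inputs[0] * self.factor, dtype=node.outputs[0].dtype
        )


# Equal props -> equal Ops -> duplicate applications can be merged.
assert ScaleOp(2.0) == ScaleOp(2.0)
# A freshly compiled VJP callable stored in __props__ would break that equality.
```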
pytensor/link/jax/ops.py (Outdated)

```python
[],
resolved_input_shapes,
on_unused_input="ignore",
mode="FAST_COMPILE",
```
mode="FAST_COMPILE", | |
mode=Mode(linker="py", optimizer="fast_compile"), |
No reason for C code I guess
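(Assuming the current import path, Mode in that suggestion comes from PyTensor's compile module.)

```python
from pytensor.compile.mode import Mode
```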
I like …. They do the same, except in this case we also know how to generate gradients and dispatch to JAX?
Alternatively, …
I don't think …. I see the argument of it being like …. Other options: …. I think something that describes what it does, instead of how it does it, would be much friendlier?
+1 for …
Small comments; besides that, we should decide on the name. I don't love the jax/ops.py filename either.
pytensor/link/jax/ops.py (Outdated)

```python
for i, result in enumerate(results):
    outputs[i][0] = np.array(result, dtype=self.output_types[i].dtype)
```
Small opt:

```diff
-for i, result in enumerate(results):
-    outputs[i][0] = np.array(result, dtype=self.output_types[i].dtype)
+for out_container, result, out_type in zip(outputs, results, self.output_types):
+    out_container[0] = np.array(result, dtype=out_type.dtype)
```
```python
def make_node(self, *inputs: Variable) -> Apply:
    """Create an Apply node with the given inputs and inferred outputs."""
    outputs = [output_type() for output_type in self.output_types]
```
assert the number of inputs is as expected?
```python
def __repr__(self):
    base = self.__class__.__name__
    if self.name is not None:
        base = f"{base}{self.name}"
    props = list(self.__props__)
    props.remove("name")
    props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props)
    return f"{base}({props})"
```
Stick with the default __repr__ that already exists based on props?
The getattr(self, prop, '?') shouldn't be needed. If a prop is missing from an Op, we have bigger issues to worry about (equality/hashing).
I thought it would be good if the name of the JAX function appears somewhere in the node name. Otherwise it can be tricky to figure out which JAXOp is which?
pytensor/link/jax/ops.py (Outdated)

```python
>>>
>>> # Create the input and output types, input has a dynamic shape.
>>> input_type = TensorType("float32", shape=(None,))
>>> output_type = TensorType("float32", shape=(1,))
```
Isn't the output_type a scalar?
bad llm ;-)
```python
def make_node(self, *inputs: Variable) -> Apply:
    """Create an Apply node with the given inputs and inferred outputs."""
    outputs = [output_type() for output_type in self.output_types]
```
make_node should convert the inputs to tensors:

```python
inputs = [inp_type.filter_variable(inp) for inp, inp_type in zip(inputs, self.input_types)]
```
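A hedged sketch of that suggestion in the context of the make_node shown above (JAXOp internals assumed from this PR; zip(..., strict=True) also guards the input count raised earlier):

```python
from pytensor.graph.basic import Apply, Variable


def make_node(self, *inputs: Variable) -> Apply:
    """Create an Apply node with the given inputs and inferred outputs."""
    # Convert/validate each input against the declared input type.
    inputs = [
        inp_type.filter_variable(inp)
        for inp, inp_type in zip(inputs, self.input_types, strict=True)
    ]
    outputs = [output_type() for output_type in self.output_types]
    return Apply(self, inputs, outputs)
```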
pytensor/link/jax/ops.py (Outdated)

```python
connected_output_indices = []
for i, output_grad in enumerate(output_gradients):
    if not isinstance(output_grad.type, DisconnectedType):
        connected_output_indices.append(i)
```
list_comp?
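That is, the loop quoted above is equivalent to:

```python
connected_output_indices = [
    i
    for i, output_grad in enumerate(output_gradients)
    if not isinstance(output_grad.type, DisconnectedType)
]
```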
pytensor/link/jax/ops.py (Outdated)

```python
gradient_outputs = vjp_op(
    *[*inputs, *[output_gradients[i] for i in connected_output_indices]]
)
if not isinstance(gradient_outputs, Sequence):
    gradient_outputs = [gradient_outputs]
return gradient_outputs
```
Suggested change:

```diff
-gradient_outputs = vjp_op(
-    *[*inputs, *[output_gradients[i] for i in connected_output_indices]]
-)
-if not isinstance(gradient_outputs, Sequence):
-    gradient_outputs = [gradient_outputs]
-return gradient_outputs
+return vjp_op(
+    *[*inputs, *[output_gradients[i] for i in connected_output_indices]],
+    return_list=True,
+)
```
pytensor/link/jax/ops.py (Outdated)

```python
    return gradient_outputs


def as_jax_op(jax_function=None, *, allow_eval=True):
```
We should open an issue to allow itypes/otypes like as_op already does (unless you want to implement it yourself).
```python
...     return {
...         "sum": jnp.add(x, y) * scale,
...     }
```
Is this example still correct? We wrap things back in pytree after the fact? It seems we do.
But we should be running these as part of doctest, and we're not: https://github.com/pymc-devs/pytensor/actions/runs/17780488500/job/50538346433?pr=1614
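For context, a hedged reconstruction of the pattern that docstring example exercises (names taken from the quoted fragment, otherwise illustrative):

```python
import jax.numpy as jnp
import pytensor.tensor as pt
from pytensor import as_jax_op


@as_jax_op
def f(x, y, scale=2.0):
    # The wrapped function returns a pytree (here a dict)...
    return {"sum": jnp.add(x, y) * scale}


# ...and the decorator rebuilds the same pytree, now holding PyTensor variables.
out = f(pt.vector("x"), pt.vector("y"))
out["sum"]
```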
I guess it doesn't run because as_jax_op is not imported by default?
Do you have an idea about how to fix this? I never really looked into how the doctests are executed.
pytensor/link/jax/ops.py (Outdated)

```
Notes
-----
The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt,
available at
`pymc-labs.io <https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/>`__.
To accept functions and non-PyTensor variables as input, the function uses
:func:`equinox.partition` and :func:`equinox.combine` to split and combine the
variables. Shapes are inferred using
:func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`.
```
Remove?
@ricardoV94 I (hopefully) addressed all comments I didn't reply to directly. On top of that, I also removed the equinox dependency. We only needed two small functions from it, so I just copied those over with a note in the source. It's Apache 2.0, so I think that's ok? (Also, it's just a few lines anyway...). I still like the name …
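For reference, a rough illustration (hedged) of what those two vendored helpers do, mirroring equinox.partition and equinox.combine:

```python
import equinox as eqx
import jax.numpy as jnp

pytree = {"weights": jnp.ones(3), "activation": jnp.tanh}
# Split the pytree into array leaves and everything else (non-matching leaves become None).
arrays, static = eqx.partition(pytree, eqx.is_array)
# Merge the two halves back into the original pytree.
restored = eqx.combine(arrays, static)
```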
revisit #1120, which seems abandoned.
@jdehning I hope it is ok if I continue this PR?
📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/