diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0e6c9ee0f2..664ba0fdc5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -208,7 +208,7 @@ jobs: micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi - if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi + if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi diff --git a/doc/conf.py b/doc/conf.py index e10dcffb90..48d81730ba 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -38,6 +38,7 @@ "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), + "equinox": ("https://docs.kidger.site/equinox/", None), } needs_sphinx = "3" diff --git a/doc/environment.yml b/doc/environment.yml index 7b564e8fb0..5b1f8790dc 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -25,4 +25,4 @@ dependencies: - ablog - pip - pip: - - -e .. + - -e ..[jax] diff --git a/doc/extending/creating_an_op.rst b/doc/extending/creating_an_op.rst index b9aa77f81f..3da251155f 100644 --- a/doc/extending/creating_an_op.rst +++ b/doc/extending/creating_an_op.rst @@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d :download:`Solution` -:func:`as_op` +:func:`wrap_py` ------------- -:func:`as_op` is a Python decorator that converts a Python function into a +:func:`wrap_py` is a Python decorator that converts a Python function into a basic PyTensor :class:`Op` that will call the supplied function during execution. This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation. @@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature inputs PyTensor variables that were declared. .. note:: - The python function wrapped by the :func:`as_op` decorator needs to return a new + The python function wrapped by the :func:`wrap_py` decorator needs to return a new data allocation, no views or in place modification of the input. -:func:`as_op` Example +:func:`wrap_py` Example ^^^^^^^^^^^^^^^^^^^^^ .. 
testcode:: asop @@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature import pytensor.tensor as pt import numpy as np from pytensor import function - from pytensor.compile.ops import as_op + from pytensor.compile.ops import wrap_py def infer_shape_numpy_dot(fgraph, node, input_shapes): ashp, bshp = input_shapes return [ashp[:-1] + bshp[-1:]] - @as_op( + @wrap_py( itypes=[pt.dmatrix, pt.dmatrix], otypes=[pt.dmatrix], infer_shape=infer_shape_numpy_dot, diff --git a/doc/extending/extending_pytensor_solution_1.py b/doc/extending/extending_pytensor_solution_1.py index ff470ec420..c5a2f8b4e5 100644 --- a/doc/extending/extending_pytensor_solution_1.py +++ b/doc/extending/extending_pytensor_solution_1.py @@ -167,9 +167,9 @@ def test_infer_shape(self): import numpy as np -# as_op exercice +# wrap_py exercice import pytensor -from pytensor.compile.ops import as_op +from pytensor.compile.ops import wrap_py def infer_shape_numpy_dot(fgraph, node, input_shapes): @@ -177,7 +177,7 @@ def infer_shape_numpy_dot(fgraph, node, input_shapes): return [ashp[:-1] + bshp[-1:]] -@as_op( +@wrap_py( itypes=[pt.fmatrix, pt.fmatrix], otypes=[pt.fmatrix], infer_shape=infer_shape_numpy_dot, @@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes): return [ashp[0]] -@as_op( +@wrap_py( itypes=[pt.fmatrix, pt.fmatrix], otypes=[pt.fmatrix], infer_shape=infer_shape_numpy_add_sub, @@ -201,7 +201,7 @@ def numpy_add(a, b): return np.add(a, b) -@as_op( +@wrap_py( itypes=[pt.fmatrix, pt.fmatrix], otypes=[pt.fmatrix], infer_shape=infer_shape_numpy_add_sub, diff --git a/doc/library/index.rst b/doc/library/index.rst index e9b362f8db..63cf7572a6 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -61,10 +61,16 @@ Convert to Variable .. autofunction:: pytensor.as_symbolic(...) +Wrap JAX functions +================== + +.. autofunction:: wrap_jax(...) + + Alias for :func:`pytensor.link.jax.ops.wrap_jax` + Debug ===== .. autofunction:: pytensor.dprint(...) Alias for :func:`pytensor.printing.debugprint` - diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 3c925ac2f2..12f67c9a37 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -166,7 +166,7 @@ def get_underlying_scalar_constant(v): from pytensor.scan.basic import scan from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph - +from pytensor.link.jax.ops import wrap_jax # isort: on diff --git a/pytensor/compile/__init__.py b/pytensor/compile/__init__.py index f6a95fe163..8c7fe5f396 100644 --- a/pytensor/compile/__init__.py +++ b/pytensor/compile/__init__.py @@ -56,6 +56,7 @@ register_deep_copy_op_c_code, register_view_op_c_code, view_op, + wrap_py, ) from pytensor.compile.profiling import ProfileStats from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index a4eba4079f..72b1447b32 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -1,6 +1,6 @@ """ This file contains auxiliary Ops, used during the compilation phase and Ops -building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that +building class (:class:`FromFunctionOp`) and decorator (:func:`wrap_py`) that help make new Ops more rapidly. 
""" @@ -268,12 +268,12 @@ def __reduce__(self): obj = load_back(mod, name) except (ImportError, KeyError, AttributeError): raise pickle.PicklingError( - f"Can't pickle as_op(), not found as {mod}.{name}" + f"Can't pickle wrap_py(), not found as {mod}.{name}" ) else: if obj is not self: raise pickle.PicklingError( - f"Can't pickle as_op(), not the object at {mod}.{name}" + f"Can't pickle wrap_py(), not the object at {mod}.{name}" ) return load_back, (mod, name) @@ -282,6 +282,18 @@ def _infer_shape(self, fgraph, node, input_shapes): def as_op(itypes, otypes, infer_shape=None): + import warnings + + warnings.warn( + "pytensor.as_op is deprecated and will be removed in a future release. " + "Please use pytensor.wrap_py instead.", + DeprecationWarning, + stacklevel=2, + ) + return wrap_py(itypes, otypes, infer_shape) + + +def wrap_py(itypes, otypes, infer_shape=None): """ Decorator that converts a function into a basic PyTensor op that will call the supplied function as its implementation. @@ -301,8 +313,8 @@ def infer_shape(fgraph, node, input_shapes): Examples -------- - @as_op(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix], - otypes=[pytensor.tensor.fmatrix]) + @wrap_py(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix], + otypes=[pytensor.tensor.fmatrix]) def numpy_dot(a, b): return numpy.dot(a, b) diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index 66eb647cca..4735f9aa98 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -13,6 +13,7 @@ from pytensor.graph import Constant from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse +from pytensor.link.jax.ops import JAXOp from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise @@ -142,3 +143,8 @@ def opfromgraph(*inputs): return fgraph_fn(*inputs) return opfromgraph + + +@jax_funcify.register(JAXOp) +def jax_op_funcify(op, **kwargs): + return op.perform_jax diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py new file mode 100644 index 0000000000..dde60f8e57 --- /dev/null +++ b/pytensor/link/jax/ops.py @@ -0,0 +1,520 @@ +"""Convert a jax function to a pytensor compatible function.""" + +from collections.abc import Sequence +from functools import wraps + +import numpy as np + +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.gradient import DisconnectedType +from pytensor.graph import Apply, Op, Variable +from pytensor.tensor.basic import infer_static_shape +from pytensor.tensor.type import TensorType + + +class JAXOp(Op): + """ + JAXOp is a PyTensor Op that wraps a JAX function, providing both forward + computation and reverse-mode differentiation (via VJP). + + Parameters + ---------- + input_types : list + A list of PyTensor types for each input variable. + output_types : list + A list of PyTensor types for each output variable. + jax_function : callable + The JAX function that computes outputs from inputs. It should + always return a tuple of outputs, even if there is only one output. + name : str, optional + A custom name for the Op instance. If provided, the class name will be + updated accordingly. + + Example + ------- + This example defines a simple function that sums the input array with a dynamic shape. + + >>> import numpy as np + >>> import jax + >>> import jax.numpy as jnp + >>> from pytensor.tensor import TensorType + >>> + >>> # Create the jax function that sums the input array. 
+ >>> def sum_function(x, y): + ... return jnp.sum(x + y) + >>> + >>> # Create the input and output types, input has a dynamic shape. + >>> input_type = TensorType("float32", shape=(None,)) + >>> output_type = TensorType("float32", shape=()) + >>> + >>> # Instantiate a JAXOp + >>> op = JAXOp( + ... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp" + ... ) + >>> # Define symbolic input variables. + >>> x = pt.tensor("x", dtype="float32", shape=(2,)) + >>> y = pt.tensor("y", dtype="float32", shape=(2,)) + >>> # Compile a PyTensor function. + >>> result = op(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print( + ... f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array(14., dtype=float32)] + >>> + >>> # Compute the gradient of op(x, y) with respect to x. + >>> g = pt.grad(result[0], x) + >>> grad_f = pytensor.function([x, y], [g]) + >>> print( + ... grad_f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array([1., 1.], dtype=float32)] + """ + + __props__ = ("input_types", "output_types", "jax_func") + + def __init__(self, input_types, output_types, jax_function, name=None): + import jax + + self.input_types = tuple(input_types) + self.output_types = tuple(output_types) + self.jax_func = jax_function + self.jitted_func = jax.jit(jax_function) + self.name = name + super().__init__() + + def __repr__(self): + base = self.__class__.__name__ + props = list(self.__props__) + if self.name is not None: + props.insert(0, "name") + props = ", ".join(f"{prop}={getattr(self, prop)}" for prop in props) + return f"{base}({props})" + + def make_node(self, *inputs: Variable) -> Apply: + """Create an Apply node with the given inputs and inferred outputs.""" + if len(inputs) != len(self.input_types): + raise ValueError( + f"Op {self} expected {len(self.input_types)} inputs, got {len(inputs)}" + ) + filtered_inputs = [ + inp_type.filter_variable(inp) + for inp, inp_type in zip(inputs, self.input_types) + ] + outputs = [output_type() for output_type in self.output_types] + return Apply(self, filtered_inputs, outputs) + + def perform(self, node, inputs, outputs): + """Execute the JAX function and store results in output storage.""" + results = self.jitted_func(*inputs) + if not isinstance(results, tuple): + raise TypeError("JAX function must return a tuple of outputs.") + if len(results) != len(outputs): + raise ValueError( + f"JAX function returned {len(results)} outputs, but " + f"{len(outputs)} were expected." + ) + for output_container, result, out_type in zip( + outputs, results, self.output_types + ): + output_container[0] = np.array(result, dtype=out_type.dtype) + + def perform_jax(self, *inputs): + """Execute the JAX function directly, returning JAX arrays.""" + outputs = self.jitted_func(*inputs) + if not isinstance(outputs, tuple): + raise TypeError("JAX function must return a tuple of outputs.") + if len(outputs) == 1: + return outputs[0] + return outputs + + def grad(self, inputs, output_gradients): + """Compute gradients using JAX's vector-Jacobian product (VJP).""" + import jax + + # Find indices of outputs that need gradients + connected_output_indices = [ + i + for i, output_grad in enumerate(output_gradients) + if not isinstance(output_grad.type, DisconnectedType) + ] + + num_inputs = len(inputs) + + def vjp_operation(*args): + """VJP operation that computes gradients w.r.t. 
inputs.""" + input_values = args[:num_inputs] + cotangent_vectors = args[num_inputs:] + assert len(cotangent_vectors) == len(connected_output_indices) + + def restricted_function(*input_values): + """Restricted function that only returns connected outputs.""" + outputs = self.jax_func(*input_values) + return [ + outputs[i].astype(self.output_types[i].dtype) + for i in connected_output_indices + ] + + _primals, vjp_function = jax.vjp(restricted_function, *input_values) + output_dtypes = [ + self.output_types[i].dtype for i in connected_output_indices + ] + return vjp_function( + [ + cotangent.astype(dtype) + for cotangent, dtype in zip( + cotangent_vectors, output_dtypes, strict=True + ) + ] + ) + + if self.name is not None: + name = "vjp_" + self.name + else: + name = "vjp_jax_op" + + # Create VJP operation + vjp_op = JAXOp( + self.input_types + + tuple(self.output_types[i] for i in connected_output_indices), + [self.input_types[i] for i in range(num_inputs)], + vjp_operation, + name=name, + ) + + return vjp_op( + *[*inputs, *[output_gradients[i] for i in connected_output_indices]], + return_list=True, + ) + + +def wrap_jax(jax_function=None, *, allow_eval=True): + """Return a PyTensor-compatible function from a JAX jittable function. + + This decorator wraps a JAX function so that it accepts and returns + `pytensor.Variable` objects. The JAX-jittable function can accept any + nested Python structure (a `Pytree + `_) as input, and might + return any nested Python structure. + + Parameters + ---------- + jax_function : Callable, optional + A JAX function to be wrapped. If None, returns a decorator function. + allow_eval : bool, default=True + Whether to allow evaluation of symbolic shapes when input shapes are + not fully determined. + + Returns + ------- + Callable + A function that wraps the given JAX function so that it can be called with + pytensor.Variable inputs and returns pytensor.Variable outputs. + + Examples + -------- + + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @wrap_jax + ... def add(x, y): + ... return jnp.add(x, y) + >>> x = pt.scalar("x") + >>> y = pt.scalar("y") + >>> result = add(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print(f(1, 2)) + [array(3.)] + + We can also pass arbitrary jax pytree structures as inputs and outputs: + + >>> import jax + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @wrap_jax + ... def complex_function(x, y, scale=1.0): + ... return { + ... "sum": jnp.add(x, y) * scale, + ... } + >>> x = pt.vector("x") + >>> y = pt.vector("y") + >>> result = complex_function(x, y, scale=2.0) + >>> f = pytensor.function([x, y], [result["sum"]]) + + Or Equinox modules: + + >>> x = pt.tensor("x", shape=(3,)) # doctest +SKIP + >>> y = pt.tensor("y", shape=(3,)) # doctest +SKIP + >>> import equinox as eqx # doctest +SKIP + >>> mlp = eqx.nn.MLP( + ... 3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0) + ... ) # doctest +SKIP + >>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) # doctest +SKIP + >>> @wrap_jax # doctest +SKIP + ... def neural_network(x, mlp): # doctest +SKIP + ... return mlp(x) # doctest +SKIP + >>> out = neural_network(x, mlp) # doctest +SKIP + """ + + def decorator(func): + name = func.__name__ + + try: + import jax + except ImportError as e: + raise ImportError( + "The wrap_jax decorator requires jax to be installed." + ) from e + + @wraps(func) + def wrapper(*args, **kwargs): + # Partition inputs into dynamic PyTensor variables and static variables. 
+ # Static variables don't participate in the computational graph. + pytensor_variables, static_values = _eqx_partition( + (args, kwargs), lambda x: isinstance(x, Variable) + ) + + # Flatten the PyTensor variables for processing + variables_flat, variables_treedef = jax.tree.flatten(pytensor_variables) + input_types = [var.type for var in variables_flat] + + # Determine output types by calling the function through jax.eval_shape + output_types, output_treedef, output_static = _find_output_types( + func, + variables_flat, + variables_treedef, + static_values, + allow_eval=allow_eval, + ) + + def flattened_function(*flat_variables): + """Execute the original function with flattened inputs.""" + variables = jax.tree.unflatten(variables_treedef, flat_variables) + reconstructed_args, reconstructed_kwargs = _eqx_combine( + variables, static_values + ) + function_outputs = func(*reconstructed_args, **reconstructed_kwargs) + array_outputs, _ = _eqx_partition(function_outputs, _is_array) + flattened_outputs, _ = jax.tree.flatten(array_outputs) + return tuple(flattened_outputs) + + # Create the JAX operation + jax_op_instance = JAXOp( + input_types, + output_types, + flattened_function, + name=name, + ) + + # Execute the operation and reconstruct the output structure + flattened_results = jax_op_instance(*variables_flat) + if not isinstance(flattened_results, Sequence): + flattened_results = [flattened_results] + + output_variables = jax.tree.unflatten(output_treedef, flattened_results) + final_outputs = _eqx_combine(output_variables, output_static) + + return final_outputs + + return wrapper + + if jax_function is None: + return decorator + else: + return decorator(jax_function) + + +def _find_output_types( + jax_function, inputs_flat, input_treedef, static_input, *, allow_eval=True +): + """Determine output types with jax.eval_shape on dummy inputs.""" + import jax + import jax.numpy as jnp + + resolved_input_shapes = [] + requires_shape_evaluation = False + + for variable in inputs_flat: + # If shape is already fully determined, use it directly + if not any(dimension is None for dimension in variable.type.shape): + resolved_input_shapes.append(variable.type.shape) + continue + + # Try to infer static shape + _, inferred_shape = infer_static_shape(variable.shape) + if not any(dimension is None for dimension in inferred_shape): + resolved_input_shapes.append(inferred_shape) + continue + + # Shape still has undetermined dimensions + if not allow_eval: + raise ValueError( + f"Input variable {variable} has undetermined shape dimensions. " + "Please provide inputs with fully determined shapes by calling " + "pt.specify_shape." + ) + requires_shape_evaluation = True + resolved_input_shapes.append(variable.shape) + + if requires_shape_evaluation: + try: + shape_evaluation_function = function( + [], + resolved_input_shapes, + on_unused_input="ignore", + mode=Mode(linker="py", optimizer="fast_compile"), + ) + except Exception as e: + raise ValueError( + "Could not compile a function to infer example shapes. " + "Please provide inputs with fully determined shapes by " + "calling pt.specify_shape." 
+ ) from e + resolved_input_shapes = shape_evaluation_function() + + # Determine output types using jax.eval_shape with dummy inputs + output_metadata_storage = {} + + dummy_input_arrays = [ + jnp.ones(shape, dtype=variable.type.dtype) + for variable, shape in zip(inputs_flat, resolved_input_shapes, strict=True) + ] + + def wrapped_jax_function(input_arrays): + """Wrapper to extract output metadata during shape evaluation.""" + variables = jax.tree.unflatten(input_treedef, input_arrays) + reconstructed_args, reconstructed_kwargs = _eqx_combine(variables, static_input) + function_outputs = jax_function(*reconstructed_args, **reconstructed_kwargs) + array_outputs, static_outputs = _eqx_partition(function_outputs, _is_array) + + # Store metadata for later use + output_metadata_storage["output_static"] = static_outputs + flattened_outputs, output_structure = jax.tree.flatten(array_outputs) + output_metadata_storage["output_treedef"] = output_structure + return flattened_outputs + + output_shapes_flat = jax.eval_shape(wrapped_jax_function, dummy_input_arrays) + output_treedef = output_metadata_storage["output_treedef"] + output_static = output_metadata_storage["output_static"] + + # If we used shape evaluation, set all output shapes to unknown + if requires_shape_evaluation: + output_types = [ + TensorType( + dtype=output_shape.dtype, shape=tuple(None for _ in output_shape.shape) + ) + for output_shape in output_shapes_flat + ] + else: + output_types = [ + TensorType(dtype=output_shape.dtype, shape=output_shape.shape) + for output_shape in output_shapes_flat + ] + + return output_types, output_treedef, output_static + + +# From the equinox library, licensed under Apache 2.0 +# https://github.com/patrick-kidger/equinox +# +# Copied here to avoid a dependency on equinox just these functions. +def _eqx_combine(*pytrees, is_leaf=None): + """Combines multiple PyTrees into one PyTree, by replacing `None` leaves. + + !!! example + + ```python + pytree1 = [None, 1, 2] + pytree2 = [0, None, None] + equinox.combine(pytree1, pytree2) # [0, 1, 2] + ``` + + !!! tip + + The idea is that `equinox.combine` should be used to undo a call to + [`equinox.filter`][] or [`equinox.partition`][]. + + **Arguments:** + + - `*pytrees`: a sequence of PyTrees all with the same structure. + - `is_leaf`: As [`equinox.partition`][]. + + **Returns:** + + A PyTree with the same structure as its inputs. Each leaf will be the first + non-`None` leaf found in the corresponding leaves of `pytrees` as they are + iterated over. + """ + import jax + + if is_leaf is None: + _is_leaf = _is_none + else: + _is_leaf = lambda x: _is_none(x) or is_leaf(x) # noqa: E731 + + return jax.tree.map(_combine, *pytrees, is_leaf=_is_leaf) + + +def _eqx_partition( + pytree, + filter_spec, + replace=None, + is_leaf=None, +): + """Splits a PyTree into two pieces. Equivalent to + `filter(...), filter(..., inverse=True)`, but slightly more efficient. + + !!! info + + See also [`equinox.combine`][] to reconstitute the PyTree again. 
+ """ + import jax + + filter_tree = jax.tree.map(_make_filter_tree(is_leaf), filter_spec, pytree) + left = jax.tree.map(lambda mask, x: x if mask else replace, filter_tree, pytree) + right = jax.tree.map(lambda mask, x: replace if mask else x, filter_tree, pytree) + return left, right + + +def _make_filter_tree(is_leaf): + import jax + import jax.core + + def _filter_tree(mask, arg): + if isinstance(mask, jax.core.Tracer): + raise ValueError("`filter_spec` leaf values cannot be traced arrays.") + if isinstance(mask, bool): + return jax.tree.map(lambda _: mask, arg, is_leaf=is_leaf) + elif callable(mask): + return jax.tree.map(mask, arg, is_leaf=is_leaf) + else: + raise ValueError( + "`filter_spec` must consist of booleans and callables only." + ) + + return _filter_tree + + +def _is_array(element) -> bool: + """Returns `True` if `element` is a JAX array or NumPy array.""" + import jax + + return isinstance(element, np.ndarray | np.generic | jax.Array) + + +def _combine(*args): + for arg in args: + if arg is not None: + return arg + return None + + +def _is_none(x): + return x is None diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index 461c7793ad..5b7a5ea24a 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -1,14 +1,15 @@ import pickle import numpy as np +import pytest from pytensor import function -from pytensor.compile.ops import as_op +from pytensor.compile.ops import as_op, wrap_py from pytensor.tensor.type import dmatrix, dvector from tests import unittest_tools as utt -@as_op([dmatrix, dmatrix], dmatrix) +@wrap_py([dmatrix, dmatrix], dmatrix) def mul(a, b): """ This is for test_pickle, since the function still has to be @@ -21,7 +22,7 @@ class TestOpDecorator(utt.InferShapeTester): def test_1arg(self): x = dmatrix("x") - @as_op(dmatrix, dvector) + @wrap_py(dmatrix, dvector) def cumprod(x): return np.cumprod(x) @@ -31,13 +32,28 @@ def cumprod(x): assert np.allclose(r, r0), (r, r0) + def test_deprecation(self): + x = dmatrix("x") + + with pytest.warns(DeprecationWarning): + + @as_op(dmatrix, dvector) + def cumprod(x): + return np.cumprod(x) + + fn = function([x], cumprod(x)) + r = fn([[1.5, 5], [2, 2]]) + r0 = np.array([1.5, 7.5, 15.0, 30.0]) + + assert np.allclose(r, r0), (r, r0) + def test_2arg(self): x = dmatrix("x") x.tag.test_value = np.zeros((2, 2)) y = dvector("y") y.tag.test_value = [0, 0, 0, 0] - @as_op([dmatrix, dvector], dvector) + @wrap_py([dmatrix, dvector], dvector) def cumprod_plus(x, y): return np.cumprod(x) + y @@ -57,7 +73,7 @@ def infer_shape(fgraph, node, shapes): x, y = shapes return [y] - @as_op([dmatrix, dvector], dvector, infer_shape) + @wrap_py([dmatrix, dvector], dvector, infer_shape) def cumprod_plus(x, y): return np.cumprod(x) + y diff --git a/tests/link/jax/test_wrap_jax.py b/tests/link/jax/test_wrap_jax.py new file mode 100644 index 0000000000..2052b5f4db --- /dev/null +++ b/tests/link/jax/test_wrap_jax.py @@ -0,0 +1,561 @@ +import numpy as np +import pytest + +from pytensor import config, grad, wrap_jax +from pytensor.compile.sharedvalue import shared +from pytensor.link.jax.ops import JAXOp +from pytensor.scalar import all_types +from pytensor.tensor import TensorType, tensor +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") + + +def test_two_inputs_single_output(): + rng = np.random.default_rng(1) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def 
f(x, y): + return jax.nn.sigmoid(x + y) + + # Test with wrap_jax decorator + out = wrap_jax(f)(x, y) + grad_out = grad(out.sum(), [x, y]) + + compare_jax_and_py([x, y], [out, *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + def f(x, y): + return (jax.nn.sigmoid(x + y),) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + out = jax_op(x, y) + grad_out = grad(out.sum(), [x, y]) + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + +def test_op_returns_list(): + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + + test_values = [np.ones((2,)).astype(config.floatX) for inp in (x, y)] + + def f(x, y): + return jax.nn.sigmoid(x + y) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + + with pytest.raises(TypeError, match="tuple of outputs"): + out = jax_op(x, y) + grad_out = grad(out.sum(), [x, y]) + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + +def test_two_inputs_tuple_output(): + rng = np.random.default_rng(2) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return jax.nn.sigmoid(x + y), y * 2 + + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x, y) + grad_out = grad((out1 + out2).sum(), [x, y]) + + compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values) + with jax.disable_jit(): + # must_be_device_array is False, because the with disabled jit compilation, + # inputs are not automatically transformed to jax.Array anymore + compare_jax_and_py( + [x, y], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x, y) + grad_out = grad((out1 + out2).sum(), [x, y]) + compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values) + + +def test_two_inputs_list_output_one_unused_output(): + # One output is unused, to test whether the wrapper can handle DisconnectedType + rng = np.random.default_rng(3) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return (jax.nn.sigmoid(x + y), y * 2) + + # Test with wrap_jax decorator + out, _ = wrap_jax(f)(x, y) + grad_out = grad(out.sum(), [x, y]) + + compare_jax_and_py([x, y], [out, *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out, _ = jax_op(x, y) + grad_out = grad(out.sum(), [x, y]) + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + +def test_single_input_tuple_output(): + rng = np.random.default_rng(4) + x = tensor("x", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return jax.nn.sigmoid(x), x * 2 + + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x) + grad_out = grad(out1.sum(), [x]) + + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py( + [x], [out1, out2, *grad_out], test_values, 
must_be_device_array=False + ) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(out1.sum(), [x]) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) + + +def test_scalar_input_tuple_output(): + rng = np.random.default_rng(5) + x = tensor("x", shape=()) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return jax.nn.sigmoid(x), x + + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x) + grad_out = grad(out1.sum(), [x]) + + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py( + [x], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(out1.sum(), [x]) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) + + +def test_single_input_list_output(): + rng = np.random.default_rng(6) + x = tensor("x", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return (jax.nn.sigmoid(x), 2 * x) + + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x) + grad_out = grad(out1.sum(), [x]) + + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py( + [x], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) + + # Test direct JAXOp usage, with unspecified output shapes + jax_op = JAXOp( + [x.type], + [ + TensorType(config.floatX, shape=(None,)), + TensorType(config.floatX, shape=(None,)), + ], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(out1.sum(), [x]) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) + + +def test_pytree_input_tuple_output(): + rng = np.random.default_rng(7) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + y_tmp = {"y": y, "y2": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @wrap_jax + def f(x, y): + return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + + # Test with wrap_jax decorator + out = f(x, y_tmp) + grad_out = grad(out[1].sum(), [x, y]) + + compare_jax_and_py([x, y], [out[0], out[1], *grad_out], test_values) + + with jax.disable_jit(): + compare_jax_and_py( + [x, y], [out[0], out[1], *grad_out], test_values, must_be_device_array=False + ) + + +def test_pytree_input_pytree_output(): + rng = np.random.default_rng(8) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @wrap_jax + def f(x, y): + return x, jax.tree_util.tree_map(lambda x: jax.numpy.exp(x), y) + + # Test with wrap_jax decorator + out = f(x, y_tmp) + grad_out = grad(out[1]["b"][0].sum(), [x, y]) + + compare_jax_and_py([x, y], [out[0], out[1]["a"], *grad_out], test_values) + + with jax.disable_jit(): + compare_jax_and_py( + [x, y], + [out[0], out[1]["a"], *grad_out], + test_values, + must_be_device_array=False, + ) + + +def test_pytree_input_with_non_graph_args(): + rng = np.random.default_rng(9) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @wrap_jax 
+ def f(x, y, depth, which_variable): + if which_variable == "x": + var = x + elif which_variable == "y": + var = y["a"] + y["b"][0] + else: + return "Unsupported argument" + for _ in range(depth): + var = jax.nn.sigmoid(var) + return var + + # Test with wrap_jax decorator + # arguments depth and which_variable are not part of the graph + out = f(x, y_tmp, depth=3, which_variable="x") + grad_out = grad(out.sum(), [x]) + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) + + out = f(x, y_tmp, depth=7, which_variable="y") + grad_out = grad(out.sum(), [x]) + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) + with jax.disable_jit(): + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) + + out = f(x, y_tmp, depth=10, which_variable="z") + assert out == "Unsupported argument" + + +def test_unused_matrix_product(): + # A matrix output is unused, to test whether the wrapper can handle a + # DisconnectedType with a larger dimension. + + rng = np.random.default_rng(10) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return x[:, None] @ y[None], jax.numpy.exp(x) + + # Test with wrap_jax decorator + out = wrap_jax(f)(x, y) + grad_out = grad(out[1].sum(), [x]) + + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) + + with jax.disable_jit(): + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [ + TensorType(config.floatX, shape=(3, 3)), + TensorType(config.floatX, shape=(3,)), + ], + f, + ) + out = jax_op(x, y) + grad_out = grad(out[1].sum(), [x]) + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) + + +def test_unknown_static_shape(): + rng = np.random.default_rng(11) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + x_cumsum = x.cumsum() # Now x_cumsum has an unknown shape + + def f(x, y): + return (x * jax.numpy.ones(3),) + + (out,) = wrap_jax(f)(x_cumsum, y) + grad_out = grad(out.sum(), [x]) + + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + with jax.disable_jit(): + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(None,))], + f, + ) + out = jax_op(x_cumsum, y) + grad_out = grad(out.sum(), [x]) + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + +def test_nn(): + eqx = pytest.importorskip("equinox") + nn = pytest.importorskip("equinox.nn") + + rng = np.random.default_rng(13) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + mlp = nn.MLP(3, 3, 3, depth=2, activation=jax.numpy.tanh, key=jax.random.key(0)) + mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) + + @wrap_jax + def f(x, mlp): + return mlp(x) + + out = f(x, mlp) + grad_out = grad(out.sum(), [x]) + + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + with jax.disable_jit(): + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + +def test_no_inputs(): + def f(): + return jax.numpy.array(42.0) + + out = wrap_jax(f)() + assert out.eval() == 42.0 + + +def 
test_unknown_shape(): + x = tensor("x", shape=(None,)) + + def f(x): + return x * 2 + + with pytest.raises(ValueError, match="Please provide inputs"): + wrap_jax(f)(x) + + +def test_unknown_shape_with_eval(): + x = shared(np.ones(3)) + assert x.type.shape == (None,) + + def f(x): + return x * 2 + + out = wrap_jax(f)(x) + grad_out = grad(out.sum(), [x]) + + compare_jax_and_py([], [out, *grad_out], []) + + with jax.disable_jit(): + compare_jax_and_py([], [out, *grad_out], [], must_be_device_array=False) + + with pytest.raises(ValueError, match="Please provide inputs"): + wrap_jax(f, allow_eval=False)(x) + + +def test_decorator_forms(): + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + + @wrap_jax + def the_name1(x, y): + return (x + y).sum() + + @wrap_jax(allow_eval=True) + def the_name2(x, y): + return (x + y).sum() + + the_name1(x, y) + the_name2(x, y) + + +def test_repr(): + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + + def the_name(x, y): + return (x + y).sum() + + jax_op = wrap_jax(the_name) + assert "the_name" in repr(jax_op(x, y).owner.op) + + (grad_x, _) = grad(jax_op(x, y), [x, y]) + assert "vjp_the_name" in repr(grad_x.owner.op) + + +class TestDtypes: + @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) + def test_different_in_output(self, in_dtype, out_dtype): + x = tensor("x", shape=(3,), dtype=in_dtype) + y = tensor("y", shape=(3,), dtype=in_dtype) + + if "int" in in_dtype: + test_values = [ + np.random.randint(0, 10, size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + else: + test_values = [ + np.random.normal(size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + + @wrap_jax + def f(x, y): + out = jax.numpy.add(x, y) + return jax.numpy.real(out).astype(out_dtype) + + out = f(x, y) + assert out.dtype == out_dtype + + if "float" in in_dtype and "float" in out_dtype: + grad_out = grad(out[0], [x, y]) + assert grad_out[0].dtype == in_dtype + compare_jax_and_py([x, y], [out, *grad_out], test_values) + else: + compare_jax_and_py([x, y], [out], test_values) + + with jax.disable_jit(): + if "float" in in_dtype and "float" in out_dtype: + compare_jax_and_py([x, y], [out, *grad_out], test_values) + else: + compare_jax_and_py([x, y], [out], test_values) + + @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) + def test_test_different_inputs(self, in1_dtype, in2_dtype): + x = tensor("x", shape=(3,), dtype=in1_dtype) + y = tensor("y", shape=(3,), dtype=in2_dtype) + + if "int" in in1_dtype: + test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)] + else: + test_values = [np.random.normal(size=(3,)).astype(x.type.dtype)] + if "int" in in2_dtype: + test_values.append(np.random.randint(0, 10, size=(3,)).astype(y.type.dtype)) + else: + test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype)) + + @wrap_jax + def f(x, y): + out = jax.numpy.add(x, y) + return jax.numpy.real(out).astype(in1_dtype) + + out = f(x, y) + assert out.dtype == in1_dtype + + if "float" in in1_dtype and "float" in in2_dtype: + # In principle, the gradient should also be defined if the second input is + # an integer, but it doesn't work for some reason. 
+ grad_out = grad(out[0], [x]) + assert grad_out[0].dtype == in1_dtype + inputs = [x, y] + outputs = [out, *grad_out] + else: + inputs = [x, y] + outputs = [out] + + compare_jax_and_py(inputs, outputs, test_values) + + with jax.disable_jit(): + if "float" in in1_dtype and "float" in in2_dtype: + compare_jax_and_py([x, y], [out, *grad_out], test_values) + else: + compare_jax_and_py([x, y], [out], test_values)
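
For reference, a minimal end-to-end sketch of the `wrap_jax` API introduced by this diff, combining the pieces above. The wrapped `softplus` function and the variable names are illustrative only (not part of the patch), and it assumes `jax` is installed alongside this branch::

    import numpy as np
    import jax.numpy as jnp

    import pytensor
    import pytensor.tensor as pt
    from pytensor import grad, wrap_jax

    # Wrap a JAX-jittable function so it can be applied to PyTensor variables.
    @wrap_jax
    def softplus(x):
        return jnp.logaddexp(0.0, x)

    # A fully determined static shape lets wrap_jax infer the output types
    # directly, without falling back to symbolic shape evaluation.
    x = pt.tensor("x", shape=(3,), dtype="float32")
    y = softplus(x)

    # Reverse-mode gradients flow through the wrapped function via JAX's VJP.
    g = grad(y.sum(), x)

    f = pytensor.function([x], [y, g])
    print(f(np.array([0.0, 1.0, 2.0], dtype="float32")))

The same graph can also be compiled with ``mode="JAX"``, in which case the wrapped function is dispatched through ``jax_funcify`` (see the ``JAXOp`` registration added in ``pytensor/link/jax/dispatch/basic.py``) and operates on JAX arrays directly.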