This repository implements a variety of tools for the differentiable programming library JAX.
Tjax's major components are:
- A dataclass decorator, `dataclass`, that facilitates defining structured JAX objects (so-called "pytrees"). It offers:
  - the ability to mark fields as static (not available in `chex.dataclass`), and
  - a display method that produces formatted text according to the tree structure.
- A shim for the gradient transformation library optax that supports:
  - easy differentiation and vectorization of "gradient transformation" (learning rule) parameters,
  - gradient transformation objects that can be passed dynamically to jitted functions, and
  - generic type annotations.
- A pretty printer, `print_generic`, for aggregate and vector types, including dataclasses. (See display.) It features:
  - support for traced values,
  - colorized tree output for aggregate structures, and
  - formatted tabular output for arrays (or statistics when there's no room for tabular output).
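The idea behind static dataclass fields can be sketched in plain JAX. This is not tjax's `dataclass` decorator, which automates this bookkeeping; it is a minimal hand-written sketch in which a hypothetical `Linear` class is registered with `jax.tree_util.register_pytree_node`. The `weight` and `bias` arrays are pytree leaves, while the `activation` string is static metadata stored in the tree structure, so it is ignored by `jax.grad` and tree maps and can drive ordinary Python control flow inside `jax.jit`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Linear:
    """Hypothetical structured JAX object with two leaves and one static field."""

    weight: jax.Array
    bias: jax.Array
    activation: str = "relu"  # static: stored in the tree structure, not a leaf

    def __call__(self, x: jax.Array) -> jax.Array:
        y = x @ self.weight + self.bias
        # Branching on a static field is ordinary Python, even under jax.jit.
        return jax.nn.relu(y) if self.activation == "relu" else y


def _flatten(obj: Linear):
    # Children (differentiable leaves) first, then the static auxiliary data.
    return (obj.weight, obj.bias), obj.activation


def _unflatten(activation: str, children) -> Linear:
    weight, bias = children
    return Linear(weight, bias, activation)


jax.tree_util.register_pytree_node(Linear, _flatten, _unflatten)

layer = Linear(weight=jnp.ones((2, 3)), bias=jnp.zeros(3))

# Only the array fields are leaves.
print(len(jax.tree_util.tree_leaves(layer)))  # 2

# The whole object can be passed to a jitted function.
out = jax.jit(lambda l, x: l(x))(layer, jnp.array([1.0, 2.0]))

# Gradients come back as a Linear; the static field is carried along unchanged.
grads = jax.grad(lambda l: jnp.sum(l(jnp.array([1.0, 2.0]))))(layer)
print(grads.activation)  # relu
```

Because `activation` lives in the tree structure, changing it produces a different treedef, so `jax.jit` retraces instead of treating it as traced data.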
Tjax also includes:
- Versions of `custom_vjp` and `custom_jvp` that support being used on methods: `custom_vjp_method` and `custom_jvp_method`. (See shims.)
- Tools for working with cotangents. (See cotangent_tools.)
- JAX tree registration for NetworkX graph types. (See graph.)
- Leaky integration, `leaky_integrate`, and Ornstein-Uhlenbeck process iteration, `diffused_leaky_integrate`. (See leaky_integral.)
- An improved version of `jax.tree_util.Partial`. (See partial.)
- A testing function, `assert_tree_allclose`, that automatically produces testing code, and a related function, `tree_allclose`. (See testing.)
- Basic tools like `divide_where`. (See tools.)
- Conventions: PEP8.
- How to run tests: `pytest .`
- How to clean the source:
  - `ruff check .`
  - `pyright`
  - `mypy`
  - `isort .`
  - `pylint tjax tests`