
Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
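
As a rough illustration of that description, here is a minimal sketch that assumes only that `jax` is installed. It writes a loss with the NumPy-like `jax.numpy` API, differentiates it with `jax.grad`, and compiles the gradient with `jax.jit` (XLA). The loss, data, and learning rate are arbitrary placeholders chosen for illustration and are not taken from any library in this list.

```python
import jax
import jax.numpy as jnp

# A least-squares loss written with the NumPy-like jax.numpy API.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a function computing d(loss)/dw (first argument);
# jax.jit compiles that function with XLA for the available backend.
grad_loss = jax.jit(jax.grad(loss))

# Placeholder data: two samples, two features.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# A few hundred steps of plain gradient descent with a fixed step size.
for _ in range(200):
    w = w - 0.05 * grad_loss(w, x, y)
```

The first call to `grad_loss` triggers compilation; later calls with the same input shapes and dtypes reuse the compiled function.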

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!

Libraries

  • Neural Network Libraries
    • Flax - Centered on flexibility and clarity.
    • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax - Has an object-oriented design similar to PyTorch.
    • Elegy - A framework-agnostic Trainer interface for the JAX ecosystem. Supports Flax, Haiku, and Optax.
    • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    • Jraph - Lightweight graph neural network library.
    • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
  • NumPyro - Probabilistic programming based on the Pyro library.
  • Chex - Utilities to write and test reliable JAX code.
  • Optax - Gradient processing and optimization library.
  • RLax - Library for implementing reinforcement learning agents.
  • JAX, M.D. - Accelerated, differentiable molecular dynamics.
  • Coax - Turn RL papers into code, the easy way.
  • SymJAX - Symbolic CPU/GPU/TPU programming.
  • mcx - Express & compile probabilistic programs for performant inference.
  • Distrax - Reimplementation of a subset of TensorFlow Probability, containing probability distributions and bijectors.
  • cvxpylayers - Construct differentiable convex optimization layers.
  • TensorLy - Tensor learning made simple.

New Libraries

This section contains libraries that are well-made and useful but have not necessarily been battle-tested by a large user base yet.

  • Neural Network Libraries
    • FedJAX - Federated learning in JAX, built on Optax and Haiku.
    • Equivariant MLP - Construct equivariant neural network layers.
    • jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
  • jax-unirep - Library implementing the UniRep model for protein machine learning applications.
  • jax-flows - Normalizing flows in JAX.
  • sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
  • jax-cosmo - Differentiable cosmology library.
  • efax - Exponential Families in JAX.
  • mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax - Image augmentations and transformations.
  • FlaxVision - Flax version of TorchVision.
  • Oryx - Probabilistic programming language based on program transformations.
  • Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.

Papers

This section contains papers focused on JAX (e.g., JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models/Projects section.

Contributing

Contributions welcome! Read the contribution guidelines first.
