Mjinx is a library for auto-differentiable numerical inverse kinematics, powered by JAX and MuJoCo MJX. The library was heavily inspired by the Pinocchio-based tool pink and its MuJoCo-based analogue mink.
- Flexibility. Each control problem is assembled from Components, which either enforce desired behaviour or keep the system in a safety set.
- Different solution approaches. JAX (i.e. its efficient sampling and autodifferentiation) makes it possible to implement a variety of solvers, each of which may suit different scenarios better.
- Fully JAX-compatible. Both the optimal control problem and its solver are JAX-compatible: JIT compilation and automatic vectorization are available for the whole problem.
- Convenience. The functionality is wrapped to make interaction with it easier.
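To make the JIT compilation and automatic vectorization point concrete, here is a minimal sketch in plain JAX. It does not use the mjinx API: the two-link planar arm, `LINK_LENGTHS`, `forward_kinematics`, and `batched_jacobian` are hypothetical names chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-link planar arm (not part of mjinx).
LINK_LENGTHS = jnp.array([1.0, 0.8])


def forward_kinematics(q):
    """End-effector position (x, y) for joint angles q of shape (2,)."""
    angles = jnp.cumsum(q)  # absolute angle of each link
    return jnp.array(
        [
            jnp.sum(LINK_LENGTHS * jnp.cos(angles)),
            jnp.sum(LINK_LENGTHS * jnp.sin(angles)),
        ]
    )


# Autodiff yields the Jacobian, vmap batches it over configurations,
# and jit compiles the whole pipeline into one function.
batched_jacobian = jax.jit(jax.vmap(jax.jacfwd(forward_kinematics)))

# Four configurations: first joint at zero, second joint swept from 0 to 1.5 rad.
q_batch = jnp.zeros((4, 2)).at[:, 1].set(jnp.linspace(0.0, 1.5, 4))
jac_batch = batched_jacobian(q_batch)  # shape (4, 2, 2)
```

The same composition of `jax.jit` and `jax.vmap` is what lets a whole batch of configurations be processed in a single compiled call.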
The package is available on PyPI and can be installed via pip:
pip install mjinx
Different installation versions:
- The visualization tool mjinx.visualization.BatchVisualizer is available in mjinx[visual]
- To run the examples, install mjinx[examples]
- To install the development version, install mjinx[dev] (preferably in editable mode)
- To build the docs, install mjinx[docs]
- To install the repository with all dependencies, install mjinx[all]
Here is an example of mjinx usage:
import jax
import mujoco as mj
import numpy as np
from mujoco import mjx

# Import paths below are assumed from the package layout; adjust them to your installed version
from mjinx.components.barriers import JointBarrier
from mjinx.components.tasks import FrameTask
from mjinx.configuration import integrate
from mjinx.problem import Problem
from mjinx.solvers import LocalIKSolver

# Initialize the robot model using MuJoCo
MJCF_PATH = "path_to_mjcf.xml"
mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
mjx_model = mjx.put_model(mj_model)

# Create an instance of the problem
problem = Problem(mjx_model)

# Add tasks to track desired behavior
frame_task = FrameTask("ee_task", cost=1, gain=20, body_name="link7")
problem.add_component(frame_task)

# Add barriers to keep the robot in a safety set
joints_barrier = JointBarrier("jnt_range", gain=10)
problem.add_component(joints_barrier)

# Initialize the solver
solver = LocalIKSolver(mjx_model)

# Set the initial configuration and the integration time step
q = np.zeros(7)
dt = 1e-2

# Initialize solver data
solver_data = solver.init()

# JIT-compile solve and integrate
solve_jit = jax.jit(solver.solve)
integrate_jit = jax.jit(integrate, static_argnames=["dt"])

# === Control loop ===
for t in np.arange(0, 5, dt):
    # Update the target and recompile the problem
    # Target pose: position (x, y, z) followed by a quaternion, assumed scalar-first (w, x, y, z)
    frame_task.target_frame = np.array([0.1 * np.sin(t), 0.1 * np.cos(t), 0.1, 1, 0, 0, 0])
    problem_data = problem.compile()

    # Solve the current instance of the problem
    opt_solution, solver_data = solve_jit(q, solver_data, problem_data)

    # Integrate the optimal velocity to get the next configuration
    q = integrate_jit(
        mjx_model,
        q,
        opt_solution.v_opt,
        dt,
    )
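The loop above alternates between solving for a joint velocity and integrating it. The following sketch mirrors that structure on a toy problem in plain JAX so that it can be run without a robot model. It is not the mjinx solver: the damped least-squares step and the names `forward_kinematics`, `solve`, and `integrate` here are hypothetical stand-ins, and the two-link planar arm is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-link planar arm with unit link lengths (not part of mjinx).
LINK_LENGTHS = jnp.array([1.0, 1.0])


def forward_kinematics(q):
    """End-effector position (x, y) for joint angles q of shape (2,)."""
    angles = jnp.cumsum(q)
    return jnp.stack(
        [
            jnp.sum(LINK_LENGTHS * jnp.cos(angles)),
            jnp.sum(LINK_LENGTHS * jnp.sin(angles)),
        ]
    )


def solve(q, target, gain=5.0, damping=1e-4):
    """Damped least-squares joint velocity that reduces the tracking error."""
    error = target - forward_kinematics(q)
    jac = jax.jacfwd(forward_kinematics)(q)  # Jacobian by autodiff
    # Solve (J^T J + damping * I) v = J^T (gain * error) for v.
    lhs = jac.T @ jac + damping * jnp.eye(q.shape[0])
    return jnp.linalg.solve(lhs, jac.T @ (gain * error))


def integrate(q, v, dt):
    """Explicit Euler step in joint space."""
    return q + dt * v


solve_jit = jax.jit(solve)
# dt is marked static, so a new value of dt triggers recompilation.
integrate_jit = jax.jit(integrate, static_argnames=["dt"])

q = jnp.array([0.3, 0.5])
target = jnp.array([1.2, 0.8])
dt = 1e-2

for _ in range(500):
    v = solve_jit(q, target)
    q = integrate_jit(q, v, dt=dt)

final_error = jnp.linalg.norm(target - forward_kinematics(q))
```

With a fixed reachable target, the tracking error shrinks at each iteration, which is the same closed-loop behaviour the `gain` parameters of the tasks are meant to shape.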
The list of examples includes:
- Kuka iiwa local inverse kinematics (single item, vmap over desired trajectory)
- Kuka iiwa global inverse kinematics (single item, vmap over desired trajectory)
- Go2 batched squats example
Note: The Global IK functionality is still under development and does not yet work as expected. It requires further tuning and will be fixed in future updates. Use the Global IK examples with caution and expect suboptimal results.
If you use MJINX in your research, please cite it as follows:
@software{mjinx25,
author = {Domrachev, Ivan and Nedelchev, Simeon},
license = {MIT},
month = mar,
title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}},
url = {https://github.com/based-robotics/mjinx},
version = {0.1.1},
year = {2025}
}
We welcome suggestions and contributions. Please see our CONTRIBUTING.md file for guidelines.
I am deeply grateful to Simeon Nedelchev, whose guidance and expertise were instrumental in bringing this project to life.
This work draws significant inspiration from pink by Stéphane Caron and mink by Kevin Zakka. Their pioneering work in robotics and open source has been a guiding light for this project.
The codebase incorporates utility functions from MuJoCo MJX. Beyond being an excellent tool for batched computations and machine learning, MJX's codebase serves as a masterclass in clean, informative implementation of physical simulations and JAX usage.
Special thanks to IRIS lab for their support.