
based-robotics/mjinx


MJINX


Mjinx is a library for auto-differentiable numerical inverse kinematics, powered by JAX and MuJoCo MJX. It was heavily inspired by the similar Pinocchio-based tool pink and its MuJoCo-based analogue mink.
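To make "auto-differentiable numerical inverse kinematics" concrete, here is a minimal sketch in plain JAX. It is not the mjinx API. A toy planar 2-link arm with unit link lengths (a hypothetical stand-in for a MuJoCo model) is solved with damped least-squares iterations. Then jax.jacfwd differentiates through the entire iterative solver to obtain the sensitivity of the solution with respect to the target.

```python
import jax
import jax.numpy as jnp


def forward_kinematics(q):
    """End-effector (x, y) of a toy planar 2-link arm with unit link lengths."""
    return jnp.array([jnp.cos(q[0]) + jnp.cos(q[0] + q[1]),
                      jnp.sin(q[0]) + jnp.sin(q[0] + q[1])])


def solve_ik(q0, target, iters=50, damping=1e-3):
    """Numerical IK via damped least-squares (Levenberg-Marquardt-style) steps."""
    def step(_, q):
        err = target - forward_kinematics(q)
        J = jax.jacfwd(forward_kinematics)(q)  # Jacobian obtained by autodiff
        H = J.T @ J + damping * jnp.eye(q.shape[0])
        return q + jnp.linalg.solve(H, J.T @ err)
    return jax.lax.fori_loop(0, iters, step, q0)


q0 = jnp.array([0.3, 0.5])
target = jnp.array([1.2, 0.8])  # reachable: |target| < 2
q_sol = solve_ik(q0, target)

# Differentiate through the whole solver: d(q_sol) / d(target)
dq_dtarget = jax.jacfwd(solve_ik, argnums=1)(q0, target)
```

At a regular (non-singular) solution, dq_dtarget should agree with the inverse of the arm Jacobian at q_sol, as the implicit function theorem predicts.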

Key features

  1. Flexibility. Each control problem is assembled from Components, which either enforce a desired behaviour or keep the system within a safety set.
  2. Different solution approaches. JAX, with its efficient sampling and automatic differentiation, makes it possible to implement a variety of solvers, each of which may be better suited to different scenarios.
  3. Fully JAX-compatible. Both the optimal control problem and its solver are JAX-compatible: JIT compilation and automatic vectorization are available for the whole problem.
  4. Convenience. The functionality is wrapped in a simple interface that makes it easy to use.
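The component idea and the JAX compatibility described above can be illustrated with a toy sketch. It is not the mjinx API. The helpers position_task, joint_limit_barrier and assemble are hypothetical. A planar 2-link arm replaces the MuJoCo model, and plain gradient descent replaces a real solver. Each component is a cost function of the configuration, and the problem is the sum of those costs. The resulting solver is JIT-compiled and vmapped over a batch of initial configurations.

```python
import jax
import jax.numpy as jnp

# Toy sketch, NOT the mjinx API: each hypothetical "component" maps a joint
# configuration q of a planar 2-link arm (unit link lengths) to a scalar cost.


def fk(q):
    """End-effector (x, y) position of the toy arm."""
    return jnp.array([jnp.cos(q[0]) + jnp.cos(q[0] + q[1]),
                      jnp.sin(q[0]) + jnp.sin(q[0] + q[1])])


def position_task(target, cost=1.0):
    """Task component: weighted squared distance to a target position."""
    return lambda q: cost * jnp.sum((fk(q) - target) ** 2)


def joint_limit_barrier(lower, upper, gain=1e-3):
    """Barrier component: log-barrier that grows without bound near the limits."""
    return lambda q: -gain * jnp.sum(jnp.log(q - lower) + jnp.log(upper - q))


def assemble(components):
    """Assemble a problem by summing the costs of all components."""
    return lambda q: sum(c(q) for c in components)


target = jnp.array([1.2, 0.8])
total_cost = assemble([
    position_task(target),
    joint_limit_barrier(jnp.array([-jnp.pi, 0.05]),
                        jnp.array([jnp.pi, jnp.pi - 0.05])),
])


@jax.jit
def solve(q0):
    """Plain gradient descent on the assembled cost (500 fixed steps)."""
    grad = jax.grad(total_cost)
    return jax.lax.fori_loop(0, 500, lambda _, q: q - 0.05 * grad(q), q0)


# The jit-compiled solver is vectorized over a batch of initial configurations
batch_q0 = jnp.array([[0.3, 0.5], [0.0, 1.0], [-0.5, 2.0]])
batch_q = jax.vmap(solve)(batch_q0)
residuals = jax.vmap(lambda q: jnp.linalg.norm(fk(q) - target))(batch_q)
```

The log-barrier here is a swapped-in choice for illustration. It is not necessarily how mjinx formulates its barriers.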

Installation

The package is available on PyPI and can be installed via pip:

pip install mjinx

Optional dependency groups (in shells such as zsh, quote the name, e.g. pip install "mjinx[visual]"):

  1. The visualization tool mjinx.visualization.BatchVisualizer is available in mjinx[visual].
  2. To run the examples, install mjinx[examples].
  3. For development, install mjinx[dev] (preferably in editable mode).
  4. To build the docs, install mjinx[docs].
  5. To install the package with all optional dependencies, install mjinx[all].

Usage

Here is an example of mjinx usage:

import jax
import mujoco as mj
import numpy as np
from mujoco import mjx

from mjinx.components.barriers import JointBarrier
from mjinx.components.tasks import FrameTask
from mjinx.configuration import integrate
from mjinx.problem import Problem
from mjinx.solvers import LocalIKSolver

# Initialize the robot model using MuJoCo
MJCF_PATH = "path_to_mjcf.xml"
mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
mjx_model = mjx.put_model(mj_model)

# Create an instance of the problem
problem = Problem(mjx_model)

# Add tasks to track the desired behavior
frame_task = FrameTask("ee_task", cost=1, gain=20, body_name="link7")
problem.add_component(frame_task)

# Add barriers to keep the robot in a safety set
joints_barrier = JointBarrier("jnt_range", gain=10)
problem.add_component(joints_barrier)

# Initialize the solver
solver = LocalIKSolver(mjx_model)

# Initial condition: joint configuration of a 7-DoF arm
q = np.zeros(7)

# Initialize solver data
solver_data = solver.init()

# JIT-compile solve and integrate (dt is a static argument of integrate)
solve_jit = jax.jit(solver.solve)
integrate_jit = jax.jit(integrate, static_argnames=["dt"])

# Integration time step
dt = 1e-2

# === Control loop ===
for t in np.arange(0, 5, dt):
    # Update the target and compile the problem
    # (target: position x, y, z followed by an identity orientation quaternion)
    frame_task.target_frame = np.array([0.1 * np.sin(t), 0.1 * np.cos(t), 0.1, 1, 0, 0, 0])
    problem_data = problem.compile()

    # Solve this instance of the problem
    opt_solution, solver_data = solve_jit(q, solver_data, problem_data)

    # Integrate the optimal velocity over one time step
    q = integrate_jit(
        mjx_model,
        q,
        opt_solution.v_opt,
        dt,
    )
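The control loop above calls the jitted functions from Python at every step. If the target is a pure function of time, the same solve-then-integrate cycle can instead be compiled end to end with jax.lax.scan. The following sketch is an illustration under stated assumptions, not mjinx code. A planar 2-link arm with unit links stands in for the robot. The velocity comes from a damped least-squares solve of J v = gain * err, with no QP and no barriers. Integration is a plain explicit Euler step on the joint angles. The gain of 20 is a hypothetical value.

```python
import jax
import jax.numpy as jnp

DT = 1e-2    # integration step (same as the loop step above)
GAIN = 20.0  # task gain; hypothetical value for this toy arm


def fk(q):
    """End-effector (x, y) position of a toy planar 2-link arm with unit links."""
    return jnp.array([jnp.cos(q[0]) + jnp.cos(q[0] + q[1]),
                      jnp.sin(q[0]) + jnp.sin(q[0] + q[1])])


def local_ik_velocity(q, target, damping=1e-4):
    """Joint velocity v from the damped least-squares problem J v ~ GAIN * err."""
    err = target - fk(q)
    J = jax.jacfwd(fk)(q)
    H = J.T @ J + damping * jnp.eye(q.shape[0])
    return jnp.linalg.solve(H, J.T @ (GAIN * err))


def control_step(q, t):
    """One loop iteration: move the target, solve for velocity, integrate."""
    target = jnp.array([1.0 + 0.1 * jnp.sin(t), 0.5 + 0.1 * jnp.cos(t)])
    v = local_ik_velocity(q, target)
    q_next = q + v * DT  # explicit Euler step on the joint angles
    return q_next, jnp.linalg.norm(target - fk(q_next))


ts = jnp.arange(500) * DT  # 5 s of simulated time


@jax.jit
def rollout(q0):
    """Run the whole control loop as a single compiled scan."""
    return jax.lax.scan(control_step, q0, ts)


q_final, tracking_errors = rollout(jnp.array([0.3, 0.5]))
```

With gain * dt = 0.2, each step removes roughly 20% of the remaining position error, so the arm settles onto the moving target within a few dozen steps.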

Examples

The list of examples includes:

  1. Kuka iiwa local inverse kinematics (single item, vmap over desired trajectory)
  2. Kuka iiwa global inverse kinematics (single item, vmap over desired trajectory)
  3. Go2 batched squats example

Note: The Global IK functionality is still under development and does not yet work as expected; it needs further tuning, which will come in future updates. Use the Global IK examples with caution and expect suboptimal results.
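The "vmap over desired trajectory" pattern mentioned in the examples amounts to writing a solver for a single target and vectorizing it with jax.vmap. Here is a minimal plain-JAX sketch under toy assumptions, not taken from the Kuka iiwa examples. It uses a hypothetical planar 2-link arm with unit links, a damped least-squares solver with clipped steps starting from a fixed guess, and a circular trajectory of 100 target positions.

```python
import jax
import jax.numpy as jnp

Q_INIT = jnp.array([0.3, 1.0])  # shared initial guess for every target


def fk(q):
    """End-effector (x, y) position of a toy planar 2-link arm with unit links."""
    return jnp.array([jnp.cos(q[0]) + jnp.cos(q[0] + q[1]),
                      jnp.sin(q[0]) + jnp.sin(q[0] + q[1])])


def solve_ik(target, iters=200, damping=1e-2, max_step=0.5):
    """Damped least-squares IK for ONE target, with the step length clipped."""
    def body(_, q):
        err = target - fk(q)
        J = jax.jacfwd(fk)(q)
        dq = jnp.linalg.solve(J.T @ J + damping * jnp.eye(2), J.T @ err)
        # Clip the step length to avoid large jumps far from the solution
        scale = jnp.minimum(1.0, max_step / (jnp.linalg.norm(dq) + 1e-9))
        return q + scale * dq
    return jax.lax.fori_loop(0, iters, body, Q_INIT)


# Desired trajectory: 100 points on a circle of radius 0.3 centred at (1.0, 0.5)
angles = jnp.linspace(0.0, 2.0 * jnp.pi, 100)
targets = jnp.stack([1.0 + 0.3 * jnp.cos(angles),
                     0.5 + 0.3 * jnp.sin(angles)], axis=1)

# Solve all 100 IK problems in parallel with one vmapped, jit-compiled call
solutions = jax.jit(jax.vmap(solve_ik))(targets)
errors = jax.vmap(lambda q, p: jnp.linalg.norm(fk(q) - p))(solutions, targets)
```

Because solve_ik only handles a single target, the batching is entirely up to jax.vmap. The same function could therefore be reused for a single query or for a whole trajectory.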

Citing MJINX

If you use MJINX in your research, please cite it as follows:

@software{mjinx25,
  author = {Domrachev, Ivan and Nedelchev, Simeon},
  license = {MIT},
  month = mar,
  title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}},
  url = {https://github.com/based-robotics/mjinx},
  version = {0.1.1},
  year = {2025}
}

Contributing

We welcome suggestions and contributions. Please see our CONTRIBUTING.md file for guidelines.

Acknowledgements

I am deeply grateful to Simeon Nedelchev, whose guidance and expertise were instrumental in bringing this project to life.

This work draws significant inspiration from pink by Stéphane Caron and mink by Kevin Zakka. Their pioneering work in robotics and open source has been a guiding light for this project.

The codebase incorporates utility functions from MuJoCo MJX. Beyond being an excellent tool for batched computations and machine learning, MJX's codebase serves as a masterclass in clean, informative implementations of physics simulation and in JAX usage.

Special thanks to IRIS lab for their support.

About

Numerical Inverse Kinematics solver based on JAX + MJX
