JaxSim is a differentiable physics engine built with JAX, tailored for co-design and robotic learning applications.
- Physically consistent differentiability w.r.t. hardware parameters.
- Closed chain dynamics support.
- Reduced-coordinate physics engine for fixed-base and floating-base robots.
- Fully Python-based, leveraging JAX and following a functional programming paradigm.
- Seamless execution on CPUs, GPUs, and TPUs.
- Supports JIT compilation and automatic vectorization for high performance.
- Compatible with SDF models and URDF (via sdformat conversion).
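The differentiability, JIT, and vectorization points above rest on a generic JAX pattern: write the simulation step as a pure function, then compose `jax.grad`, `jax.jit`, and `jax.vmap` over the rollout. The sketch below shows that pattern on a hypothetical toy (a pendulum whose rod length stands in for a hardware parameter). It is not JaxSim's API or integrator, and the names `step`, `final_angle`, `GRAVITY`, `DT`, and `N_STEPS` are invented for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (NOT JaxSim's API): a frictionless pendulum whose rod
# length plays the role of a "hardware parameter" to differentiate against.
GRAVITY = 9.81  # m/s^2
DT = 1e-3  # s
N_STEPS = 500  # 0.5 s of simulated time


def step(state, length):
    """One semi-implicit Euler step; a pure function returning a new state."""
    theta, omega = state
    omega = omega - DT * (GRAVITY / length) * jnp.sin(theta)
    theta = theta + DT * omega
    return jnp.stack([theta, omega])


def final_angle(length, theta0=0.5):
    """Simulate N_STEPS starting at rest from theta0 and return the final angle."""
    state = jnp.array([theta0, 0.0])
    state = jax.lax.fori_loop(0, N_STEPS, lambda _, s: step(s, length), state)
    return state[0]


# Gradient of the final angle w.r.t. the rod length, JIT-compiled.
d_angle_d_length = jax.jit(jax.grad(final_angle))

# Automatic vectorization: the same gradient for a batch of candidate designs.
lengths = jnp.array([0.5, 1.0, 1.5])
grads = jax.vmap(d_angle_d_length)(lengths)
print(grads)
```

Because `step` returns a new state instead of mutating one, the same rollout can be differentiated, compiled, and batched without modification.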
Warning
This project is still experimental: APIs may change between releases without notice.
Note
JaxSim currently focuses on locomotion applications. Only contacts between bodies and smooth ground surfaces are supported.
import pathlib
import icub_models
import jax.numpy as jnp
import jaxsim.api as js
# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
          'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
          'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
          'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
          'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
          'r_ankle_roll')
# Build and reduce the model
model_description = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_description, time_step=0.0001, is_urdf=True
)
model = js.model.reduce(model=full_model, considered_joints=joints)
# Get the number of degrees of freedom
ndof = model.dofs()
# Initialize data and simulation
# Note that the default velocity representation of the data is the mixed representation
data = js.data.JaxSimModelData.build(
    model=model, base_position=jnp.array([0.0, 0.0, 1.0])
)
T = jnp.arange(start=0, stop=1.0, step=model.time_step)
tau = jnp.zeros(ndof)
# Simulate 1 s with zero joint torque references
for _ in T:
    data = js.model.step(
        model=model, data=data, link_forces=None, joint_force_references=tau
    )
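The Python `for` loop above calls `js.model.step` once per iteration, returning to Python at every step. A common JAX idiom for long rollouts is to express the loop with `jax.lax.scan` inside `jax.jit`, so that the whole trajectory compiles into a single program. The sketch below shows the pattern on a hypothetical unit point mass without contacts; its `step` and `rollout` functions are invented for illustration and are not JaxSim's. Applying the same pattern to `js.model.step` assumes that the model and data objects can be traced as JAX pytrees, which this sketch does not verify.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a simulator step (NOT js.model.step): a unit point
# mass under gravity, without contacts, integrated with semi-implicit Euler.
TIME_STEP = 1e-4
NUM_STEPS = 10_000  # 1 s of simulated time
GRAVITY = jnp.array([0.0, 0.0, -9.81])


def step(state, force):
    """Advance (position, velocity) by one time step under an external force."""
    position, velocity = state
    velocity = velocity + TIME_STEP * (GRAVITY + force)
    position = position + TIME_STEP * velocity
    return position, velocity


@jax.jit
def rollout(initial_state, forces):
    """Apply `step` once per row of `forces`; return final state and positions."""

    def body(state, force):
        new_state = step(state, force)
        return new_state, new_state[0]  # (carry, per-step output)

    return jax.lax.scan(body, initial_state, forces)


forces = jnp.zeros((NUM_STEPS, 3))
initial_state = (jnp.array([0.0, 0.0, 1.0]), jnp.zeros(3))
(final_position, final_velocity), positions = rollout(initial_state, forces)
print(final_position, positions.shape)
```

Keeping the loop inside `scan` also yields the stacked per-step positions as the scan output, at no extra bookkeeping cost.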
Check the example folder for additional use cases!
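The first example notes that the data object defaults to the mixed velocity representation. For a frame B with world position `W_p_B` and orientation `W_R_B`, the mixed velocity stacks the linear velocity of B's origin and B's angular velocity, both expressed in world-aligned coordinates. The body-fixed representation expresses the same two vectors in B's coordinates, and the inertial-fixed one replaces the linear part with the velocity of the body point currently at the world origin. The NumPy sketch below converts between them for 6D vectors with linear components first; that ordering and the helper names (`skew`, `body_to_mixed`, `mixed_to_inertial`) are assumptions of this sketch, not JaxSim functions.

```python
import numpy as np


def skew(x):
    """Return the matrix S such that S @ y == np.cross(x, y)."""
    return np.array([[0.0, -x[2], x[1]],
                     [x[2], 0.0, -x[0]],
                     [-x[1], x[0], 0.0]])


def body_to_mixed(W_R_B, B_v):
    """Body-fixed (linear; angular) velocity -> mixed representation.

    Both representations use the origin of B, so only the orientation changes.
    """
    X = np.zeros((6, 6))
    X[:3, :3] = W_R_B
    X[3:, 3:] = W_R_B
    return X @ B_v


def mixed_to_inertial(W_p_B, mixed_v):
    """Mixed velocity -> inertial-fixed representation.

    The linear part becomes the velocity of the body point at the world origin:
    p_dot - omega x p = p_dot + p x omega.
    """
    X = np.eye(6)
    X[:3, 3:] = skew(W_p_B)
    return X @ mixed_v


# Example: body rotated 90 deg about z, moving along its own x axis at 1 m/s.
W_R_B = np.array([[0.0, -1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0]])
print(body_to_mixed(W_R_B, np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])))
```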
With conda
You can install the project using conda as follows:
conda install jaxsim -c conda-forge
You can enforce GPU support, if needed, by also specifying "jaxlib = * = *cuda*".
With pixi
The minimum version of pixi required is 0.39.0.
Since the pixi.lock file is stored using Git LFS, make sure you have Git LFS installed and properly configured on your system before installation. After cloning the repository, run:
git lfs install && git lfs pull
This ensures all LFS-tracked files are properly downloaded before you proceed with the installation.
You can add the jaxsim dependency to a pixi project as follows:
pixi add jaxsim
If you are on Linux and want to use a CUDA-powered version of JAX, remember to add the appropriate line to the system-requirements table, i.e., add
[system-requirements]
cuda = "12"
if you are using a pixi.toml file, or
[tool.pixi.system-requirements]
cuda = "12"
if you are using a pyproject.toml file.
With pip
You can install the project using pypa/pip, preferably in a virtual environment, as follows:
pip install jaxsim
Check pyproject.toml for the complete list of optional dependencies.
You can obtain a full installation using jaxsim[all].
If you need GPU support, follow the official installation instructions of JAX.
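Whichever installation route you choose, a quick way to check whether JAX actually picked up an accelerator is to query its default backend. Both calls below are part of JAX's public API; the printed values depend on your machine and installation.

```python
import jax

# "gpu" (or "tpu") means an accelerator-enabled jaxlib was found;
# "cpu" means computations will run on the host.
print(jax.default_backend())

# The concrete devices JAX can dispatch computations to.
print(jax.devices())
```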
Contributors installation (with conda)
If you want to contribute to the project, we recommend first creating the following jaxsim conda environment:
conda env create -f environment.yml
Then, activate the environment and install the project in editable mode:
conda activate jaxsim
pip install --no-deps -e .
Contributors installation (with pixi)
The minimum version of pixi required is 0.39.0.
Since the pixi.lock file is stored using Git LFS, make sure you have Git LFS installed and properly configured on your system before installation. After cloning the repository, run:
git lfs install && git lfs pull
This ensures all LFS-tracked files are properly downloaded before you proceed with the installation.
You can install the default dependencies of the project using pixi as follows:
pixi install
Run pixi task list to see the available tasks.
The JaxSim API documentation is available at jaxsim.readthedocs.io.
JaxSim can also be used as a multibody dynamics library, with full support for automatic differentiation of the rigid body dynamics algorithms (RBDAs), in both forward and reverse mode, with respect to both kinematic and dynamic parameters.
import pathlib
import icub_models
import jax.numpy as jnp
import jaxsim.api as js
# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
          'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
          'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
          'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
          'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
          'r_ankle_roll')
# Build and reduce the model
model_description = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_description, time_step=0.0001, is_urdf=True
)
model = js.model.reduce(model=full_model, considered_joints=joints)
# Initialize model data
data = js.data.JaxSimModelData.build(
    model=model,
    base_position=jnp.array([0.0, 0.0, 1.0]),
)
# Frame and dynamics computations
frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
# Frame transformation
W_H_F = js.frame.transform(
    model=model, data=data, frame_index=frame_index
)
# Frame Jacobian
W_J_F = js.frame.jacobian(
    model=model, data=data, frame_index=frame_index
)
# Dynamics properties
M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces
g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces
C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix
# Print dynamics results
print(f"{M.shape=} \n{h.shape=} \n{g.shape=} \n{C.shape=}")
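For context, the four free-floating quantities printed above are the terms of the standard floating-base equations of motion, as found in Featherstone's book. Here q collects the base pose and the n joint positions, ν stacks the 6D base velocity and the n joint velocities, τ are the joint torques, and f_L is a 6D external force acting on link L with Jacobian J_L. This is the textbook form; that JaxSim's bias forces h decompose exactly as C ν + g in its active velocity representation is stated here as an assumption consistent with the usual definitions.

```latex
M(q)\,\dot{\nu} + h(q, \nu) = B\,\tau + \sum_{L} J_L^\top(q)\, f_L,
\qquad
h(q, \nu) = C(q, \nu)\,\nu + g(q),
\qquad
B = \begin{bmatrix} 0_{6 \times n} \\ I_{n} \end{bmatrix}
```

With the 23 joints selected above (n = 23), one would therefore expect M and C to be 29×29 matrices, and h and g to be 29-element vectors.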
The RBDAs are based on the theory of the Rigid Body Dynamics Algorithms book by Roy Featherstone. The algorithms and some simulation features were inspired by its accompanying code.
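Featherstone formulates these algorithms in 6D spatial vector algebra, whose basic building blocks are the spatial cross-product operators for motion vectors (`crm`) and force vectors (`crf`); the names follow the MATLAB code accompanying the book. The NumPy sketch below uses the book's angular-first ordering, which is not necessarily the ordering JaxSim uses internally.

```python
import numpy as np


def skew(x):
    """Return the matrix S such that S @ y == np.cross(x, y)."""
    return np.array([[0.0, -x[2], x[1]],
                     [x[2], 0.0, -x[0]],
                     [-x[1], x[0], 0.0]])


def crm(v):
    """Spatial cross product for motion vectors, v = (angular; linear)."""
    w, u = v[:3], v[3:]
    out = np.zeros((6, 6))
    out[:3, :3] = skew(w)
    out[3:, :3] = skew(u)
    out[3:, 3:] = skew(w)
    return out


def crf(v):
    """Spatial cross product for force vectors, the dual of crm."""
    return -crm(v).T


v = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
print(crm(v) @ v)
```

A quick sanity check is that `crm(v) @ v` vanishes, since a motion vector crossed with itself is zero; defining `crf` as `-crm(v).T` makes the two operators preserve power (m · f) along a motion.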
The development of JaxSim started in late 2021, inspired by early versions of google/brax.
At that time, Brax was implemented in maximal coordinates, and we wanted a physics engine in reduced coordinates.
We are grateful to the Brax team for their work and for showing the potential of JAX in this field.
Brax v2 was later implemented with reduced coordinates, following an approach comparable to JaxSim's. Development then shifted to MJX, which provides a JAX-based implementation of the MuJoCo APIs.
The main differences between MJX/Brax and JaxSim are as follows:
- JaxSim supports, out of the box, all SDF models with Pose Frame Semantics.
- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
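The idea of a compliant ground can be illustrated with its simplest form: a linear spring-damper acting on each contact point that penetrates a flat ground at z = 0. The class below, its parameter names, and its values are hypothetical and chosen only for illustration; this is not JaxSim's actual contact model.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpringDamperGround:
    """Hypothetical compliant flat ground at z = 0 (illustrative values)."""

    stiffness: float = 1e6  # N/m
    damping: float = 2e3  # N*s/m

    def normal_force(self, z: float, z_dot: float) -> float:
        """Vertical ground reaction on a point at height z moving with speed z_dot."""
        penetration = -z
        if penetration <= 0.0:
            return 0.0  # the point is above the ground: no contact
        force = self.stiffness * penetration - self.damping * z_dot
        return max(force, 0.0)  # a compliant ground can push but never pull


ground = SpringDamperGround()
print(ground.normal_force(z=-0.001, z_dot=0.0))
```

Clamping the force at zero keeps the damper from pulling a point that is leaving the ground quickly, which is a common design choice for this family of penalty-based contact models.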
We welcome contributions from the community. Please read the contributing guide to get started.
@software{ferigo_jaxsim_2022,
author = {Diego Ferigo and Filippo Luca Ferretti and Silvio Traversaro and Daniele Pucci},
title = {{JaxSim}: A Differentiable Physics Engine and Multibody Dynamics Library for Control and Robot Learning},
url = {http://github.com/ami-iit/jaxsim},
year = {2022},
}
Theoretical aspects of JaxSim are based on Chapters 7 and 8 of the following Ph.D. thesis:
@phdthesis{ferigo_phd_thesis_2022,
title = {Simulation Architectures for Reinforcement Learning applied to Robotics},
author = {Diego Ferigo},
school = {University of Manchester},
type = {PhD Thesis},
month = {July},
year = {2022},
}