diff --git a/mujoco_playground/_src/locomotion/__init__.py b/mujoco_playground/_src/locomotion/__init__.py
index ce5d98287..a80c3d1c7 100644
--- a/mujoco_playground/_src/locomotion/__init__.py
+++ b/mujoco_playground/_src/locomotion/__init__.py
@@ -22,6 +22,7 @@
from mujoco import mjx
from mujoco_playground._src import mjx_env
+from mujoco_playground._src.locomotion.apollo import joystick as apollo_joystick
from mujoco_playground._src.locomotion.barkour import joystick as barkour_joystick
from mujoco_playground._src.locomotion.berkeley_humanoid import joystick as berkeley_humanoid_joystick
from mujoco_playground._src.locomotion.berkeley_humanoid import randomize as berkeley_humanoid_randomize
@@ -41,6 +42,9 @@
from mujoco_playground._src.locomotion.t1 import randomize as t1_randomize
_envs = {
+ "ApolloJoystickFlatTerrain": functools.partial(
+ apollo_joystick.Joystick, task="flat_terrain"
+ ),
"BarkourJoystick": barkour_joystick.Joystick,
"BerkeleyHumanoidJoystickFlatTerrain": functools.partial(
berkeley_humanoid_joystick.Joystick, task="flat_terrain"
@@ -82,6 +86,7 @@
}
_cfgs = {
+ "ApolloJoystickFlatTerrain": apollo_joystick.default_config,
"BarkourJoystick": barkour_joystick.default_config,
"BerkeleyHumanoidJoystickFlatTerrain": (
berkeley_humanoid_joystick.default_config
diff --git a/mujoco_playground/_src/locomotion/apollo/__init__.py b/mujoco_playground/_src/locomotion/apollo/__init__.py
new file mode 100644
index 000000000..8d9506aab
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/mujoco_playground/_src/locomotion/apollo/base.py b/mujoco_playground/_src/locomotion/apollo/base.py
new file mode 100644
index 000000000..28fec9b4e
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/base.py
@@ -0,0 +1,162 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base classes for Apollo."""
+
+from typing import Any, Dict, Optional, Union
+
+import jax
+import jax.numpy as jp
+import mujoco
+import numpy as np
+from etils import epath
+from ml_collections import config_dict
+from mujoco import mjx
+
+from mujoco_playground._src import mjx_env
+from mujoco_playground._src.locomotion.apollo import constants as consts
+from mujoco_playground._src.collision import geoms_colliding
+
+
+def get_assets() -> Dict[str, bytes]:
+ assets = {}
+ # Playground assets.
+ mjx_env.update_assets(assets, consts.XML_DIR, "*.xml")
+ mjx_env.update_assets(assets, consts.XML_DIR / "assets")
+ # Menagerie assets.
+ path = mjx_env.MENAGERIE_PATH / "apptronik_apollo"
+ mjx_env.update_assets(assets, path, "*.xml")
+ mjx_env.update_assets(assets, path / "assets")
+ mjx_env.update_assets(assets, path / "assets" / "ability_hand")
+ return assets
+
+
+class ApolloEnv(mjx_env.MjxEnv):
+ """Base class for Apollo environments."""
+
+ def __init__(
+ self,
+ xml_path: str,
+ config: config_dict.ConfigDict,
+ config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
+ ) -> None:
+ super().__init__(config, config_overrides)
+
+ self._mj_model = mujoco.MjModel.from_xml_string(
+ epath.Path(xml_path).read_text(), assets=get_assets()
+ )
+ self._mj_model.opt.timestep = self.sim_dt
+
+ self._mj_model.vis.global_.offwidth = 3840
+ self._mj_model.vis.global_.offheight = 2160
+
+ self._mjx_model = mjx.put_model(self._mj_model)
+ self._xml_path = xml_path
+
+ self._init_q = jp.array(self._mj_model.keyframe("knees_bent").qpos)
+ self._default_ctrl = jp.array(self._mj_model.keyframe("knees_bent").ctrl)
+ self._default_pose = jp.array(self._mj_model.keyframe("knees_bent").qpos[7:])
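+    # Upper actuator-force bound per joint (jnt_actfrcrange); [1:] skips the free joint.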
+ self._actuator_torques = self.mj_model.jnt_actfrcrange[1:, 1]
+
+ # Body IDs.
+ self._torso_body_id = self._mj_model.body(consts.ROOT_BODY).id
+
+ # Geom IDs.
+ self._floor_geom_id = self._mj_model.geom("floor").id
+ self._left_feet_geom_id = np.array(
+ [self._mj_model.geom(name).id for name in consts.LEFT_FEET_GEOMS]
+ )
+ self._right_feet_geom_id = np.array(
+ [self._mj_model.geom(name).id for name in consts.RIGHT_FEET_GEOMS]
+ )
+ self._left_hand_geom_id = self._mj_model.geom("collision_l_hand_plate").id
+ self._right_hand_geom_id = self._mj_model.geom("collision_r_hand_plate").id
+ self._left_foot_geom_id = self._mj_model.geom("collision_l_sole").id
+ self._right_foot_geom_id = self._mj_model.geom("collision_r_sole").id
+    self._left_shin_geom_id = self._mj_model.geom(
+        "collision_capsule_body_l_shin"
+    ).id
+    self._right_shin_geom_id = self._mj_model.geom(
+        "collision_capsule_body_r_shin"
+    ).id
+    self._left_thigh_geom_id = self._mj_model.geom(
+        "collision_capsule_body_l_thigh"
+    ).id
+    self._right_thigh_geom_id = self._mj_model.geom(
+        "collision_capsule_body_r_thigh"
+    ).id
+
+ # Site IDs.
+ self._imu_site_id = self._mj_model.site("imu").id
+ self._feet_site_id = np.array(
+ [self._mj_model.site(name).id for name in consts.FEET_SITES]
+ )
+
+ # Sensor readings.
+
+  def get_gravity(self, data: mjx.Data) -> jax.Array:
+    """Return the torso up vector (the "upvector" sensor) in the world frame."""
+    return mjx_env.get_sensor_data(self.mj_model, data, consts.GRAVITY_SENSOR)
+
+  def get_global_linvel(self, data: mjx.Data) -> jax.Array:
+    """Return the linear velocity of the robot in the world frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.GLOBAL_LINVEL_SENSOR
+    )
+
+  def get_global_angvel(self, data: mjx.Data) -> jax.Array:
+    """Return the angular velocity of the robot in the world frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.GLOBAL_ANGVEL_SENSOR
+    )
+
+  def get_local_linvel(self, data: mjx.Data) -> jax.Array:
+    """Return the linear velocity of the robot in the local frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.LOCAL_LINVEL_SENSOR
+    )
+
+  def get_accelerometer(self, data: mjx.Data) -> jax.Array:
+    """Return the accelerometer readings in the local frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.ACCELEROMETER_SENSOR
+    )
+
+  def get_gyro(self, data: mjx.Data) -> jax.Array:
+    """Return the gyroscope readings in the local frame."""
+    return mjx_env.get_sensor_data(self.mj_model, data, consts.GYRO_SENSOR)
+
+ def get_feet_ground_contacts(self, data: mjx.Data) -> jax.Array:
+ """Return an array indicating whether each foot is in contact with the ground."""
+ left_feet_contact = jp.array(
+ [
+ geoms_colliding(data, geom_id, self._floor_geom_id)
+ for geom_id in self._left_feet_geom_id
+ ]
+ )
+ right_feet_contact = jp.array(
+ [
+ geoms_colliding(data, geom_id, self._floor_geom_id)
+ for geom_id in self._right_feet_geom_id
+ ]
+ )
+ return jp.hstack([jp.any(left_feet_contact), jp.any(right_feet_contact)])
+
+ # Accessors.
+
+ @property
+ def xml_path(self) -> str:
+ return self._xml_path
+
+ @property
+ def action_size(self) -> int:
+ return self._mjx_model.nu
+
+ @property
+ def mj_model(self) -> mujoco.MjModel:
+ return self._mj_model
+
+ @property
+ def mjx_model(self) -> mjx.Model:
+ return self._mjx_model
diff --git a/mujoco_playground/_src/locomotion/apollo/constants.py b/mujoco_playground/_src/locomotion/apollo/constants.py
new file mode 100644
index 000000000..3b6b590cd
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/constants.py
@@ -0,0 +1,53 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants for Apollo."""
+
+from etils import epath
+
+from mujoco_playground._src import mjx_env
+
+XML_DIR = mjx_env.ROOT_PATH / "locomotion" / "apollo" / "xmls"
+
+FEET_ONLY_FLAT_TERRAIN_XML = XML_DIR / "scene_mjx_feetonly_flat_terrain.xml"
+
+
+def task_to_xml(task_name: str) -> epath.Path:
+ return {
+ "flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML,
+ }[task_name]
+
+
+FEET_SITES = [
+ "l_foot",
+ "r_foot",
+]
+
+HAND_SITES = [
+ "left_palm",
+ "right_palm",
+]
+
+LEFT_FEET_GEOMS = ["collision_l_sole"]
+RIGHT_FEET_GEOMS = ["collision_r_sole"]
+FEET_GEOMS = LEFT_FEET_GEOMS + RIGHT_FEET_GEOMS
+
+ROOT_BODY = "torso_link"
+
+GRAVITY_SENSOR = "upvector"
+GLOBAL_LINVEL_SENSOR = "global_linvel"
+GLOBAL_ANGVEL_SENSOR = "global_angvel"
+LOCAL_LINVEL_SENSOR = "local_linvel"
+ACCELEROMETER_SENSOR = "accelerometer"
+GYRO_SENSOR = "gyro"
diff --git a/mujoco_playground/_src/locomotion/apollo/joystick.py b/mujoco_playground/_src/locomotion/apollo/joystick.py
new file mode 100644
index 000000000..6472cc453
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/joystick.py
@@ -0,0 +1,405 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Joystick task for Apollo."""
+
+from typing import Any, Dict, Optional, Union
+
+import jax
+import jax.numpy as jp
+from ml_collections import config_dict
+from mujoco import mjx
+from mujoco.mjx._src import math
+
+from mujoco_playground._src import gait, mjx_env
+from mujoco_playground._src.locomotion.apollo import base
+from mujoco_playground._src.locomotion.apollo import constants as consts
+from mujoco_playground._src.collision import geoms_colliding
+
+
+def default_config() -> config_dict.ConfigDict:
+ return config_dict.create(
+ ctrl_dt=0.02,
+ sim_dt=0.005,
+ episode_length=1000,
+ action_repeat=1,
+ action_scale=0.5,
+ noise_config=config_dict.create(
+ level=1.0,
+ scales=config_dict.create(
+ joint_pos=0.03,
+ joint_vel=1.5,
+ gravity=0.05,
+ linvel=0.1,
+ gyro=0.2,
+ ),
+ ),
+ reward_config=config_dict.create(
+ scales=config_dict.create(
+ tracking=1.0,
+ lin_vel_z=0.0,
+ ang_vel_xy=-0.15,
+ orientation=-1.0,
+ torques=0.0,
+ action_rate=0.0,
+ energy=0.0,
+ feet_phase=1.0,
+ alive=0.0,
+ termination=0.0,
+ pose=-1.0,
+ collision=-1.0,
+ ),
+ tracking_sigma=0.25,
+ max_foot_height=0.12,
+ ),
+ push_config=config_dict.create(
+ enable=True,
+ interval_range=[5.0, 10.0],
+ magnitude_range=[0.1, 2.0],
+ ),
+ command_config=config_dict.create(
+ min=[-1.5, -0.8, -1.5],
+ max=[1.5, 0.8, 1.5],
+ zero_prob=[0.9, 0.25, 0.5],
+ ),
+ )
+
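+# Minimal usage sketch (assumes the `mujoco_playground.registry` API used in the
+# training notebook below; `registry` is not imported in this module):
+#
+#   from mujoco_playground import registry
+#   env = registry.load("ApolloJoystickFlatTerrain")
+#   state = jax.jit(env.reset)(jax.random.PRNGKey(0))
+#   state = jax.jit(env.step)(state, jp.zeros(env.action_size))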
+
+class Joystick(base.ApolloEnv):
+ """Track a joystick command."""
+
+ def __init__(
+ self,
+ task: str = "flat_terrain",
+ config: config_dict.ConfigDict = default_config(),
+ config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
+ ):
+ super().__init__(
+ xml_path=consts.task_to_xml(task).as_posix(),
+ config=config,
+ config_overrides=config_overrides,
+ )
+
+ self._cmd_min = jp.array(self._config.command_config.min)
+ self._cmd_max = jp.array(self._config.command_config.max)
+ self._cmd_zero_prob = jp.array(self._config.command_config.zero_prob)
+
+ # fmt: off
+ self._weights = jp.array([
+ 5.0, 5.0, 5.0, # Torso.
+ 1.0, 1.0, 1.0, # Neck.
+ 1.0, 1.0, 0.1, 1.0, 1.0, 1.0, 1.0, # Left arm.
+ 1.0, 1.0, 0.1, 1.0, 1.0, 1.0, 1.0, # Right arm.
+ 1.0, 1.0, 0.01, 0.01, 1.0, 1.0, # Left leg.
+ 1.0, 1.0, 0.01, 0.01, 1.0, 1.0, # Right leg.
+ ])
+ # fmt: on
+
+ def reset(self, rng: jax.Array) -> mjx_env.State:
+ qpos = self._init_q
+ qvel = jp.zeros(self.mjx_model.nv)
+
+    # Randomize xy position and yaw: xy += U(-0.5, 0.5), yaw ~ U(-pi, pi).
+ rng, key = jax.random.split(rng)
+ dxy = jax.random.uniform(key, (2,), minval=-0.5, maxval=0.5)
+ qpos = qpos.at[0:2].set(qpos[0:2] + dxy)
+ rng, key = jax.random.split(rng)
+ yaw = jax.random.uniform(key, (1,), minval=-3.14, maxval=3.14)
+ quat = math.axis_angle_to_quat(jp.array([0, 0, 1]), yaw)
+ new_quat = math.quat_mul(qpos[3:7], quat)
+ qpos = qpos.at[3:7].set(new_quat)
+
+    # Scale initial joint angles: qpos[7:] *= U(0.5, 1.5).
+ rng, key = jax.random.split(rng)
+ qpos = qpos.at[7:].set(
+ qpos[7:]
+ * jax.random.uniform(key, (self.mjx_model.nq - 7,), minval=0.5, maxval=1.5)
+ )
+
+    # Perturb initial root velocity: qvel[:6] ~ U(-0.5, 0.5).
+ rng, key = jax.random.split(rng)
+ qvel = qvel.at[0:6].set(jax.random.uniform(key, (6,), minval=-0.5, maxval=0.5))
+
+ data = mjx_env.init(self.mjx_model, qpos=qpos, qvel=qvel)
+
+    # Sample gait frequency: gait_freq ~ U(1.25, 1.75) Hz.
+ rng, key = jax.random.split(rng)
+ gait_freq = jax.random.uniform(key, (1,), minval=1.25, maxval=1.75)
+ phase_dt = 2 * jp.pi * self.dt * gait_freq
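+    # Feet start half a gait cycle apart (left at 0, right at pi).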
+ phase = jp.array([0, jp.pi])
+
+ # Sample push interval.
+ rng, push_rng = jax.random.split(rng)
+ push_interval = jax.random.uniform(
+ push_rng,
+ minval=self._config.push_config.interval_range[0],
+ maxval=self._config.push_config.interval_range[1],
+ )
+ push_interval_steps = jp.round(push_interval / self.dt).astype(jp.int32)
+
+ # Sample command.
+ rng, key1, key2 = jax.random.split(rng, 3)
+ time_until_next_cmd = jax.random.exponential(key1) * 5.0
+ steps_until_next_cmd = jp.round(time_until_next_cmd / self.dt).astype(jp.int32)
+ cmd = jax.random.uniform(
+ key2, shape=(3,), minval=self._cmd_min, maxval=self._cmd_max
+ )
+
+ info = {
+ "rng": rng,
+ "step": 0,
+ "command": cmd,
+ "steps_until_next_cmd": steps_until_next_cmd,
+ "last_act": jp.zeros(self.mjx_model.nu),
+ "phase_dt": phase_dt,
+ "phase": phase,
+ "push": jp.array([0.0, 0.0]),
+ "push_step": 0,
+ "push_interval_steps": push_interval_steps,
+ "filtered_linvel": jp.zeros(3),
+ "filtered_angvel": jp.zeros(3),
+ }
+ metrics = {
+ "termination/fall_termination": jp.zeros(()),
+ }
+ for k in self._config.reward_config.scales.keys():
+ metrics[f"reward/{k}"] = jp.zeros(())
+
+ obs = self._get_obs(data, info)
+ reward, done = jp.zeros(2)
+ return mjx_env.State(data, obs, reward, done, metrics, info)
+
+ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
+ state = self.apply_push(state)
+ motor_targets = self._default_ctrl + action * self._config.action_scale
+ data = mjx_env.step(self.mjx_model, state.data, motor_targets, self.n_substeps)
+
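+    # Exponential-moving-average slots for the base velocities; with coefficients
+    # (1.0, 0.0) below they currently pass the raw readings through unfiltered.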
+ linvel = self.get_local_linvel(data)
+ state.info["filtered_linvel"] = linvel * 1.0 + state.info["filtered_linvel"] * 0.0
+ angvel = self.get_gyro(data)
+ state.info["filtered_angvel"] = angvel * 1.0 + state.info["filtered_angvel"] * 0.0
+
+ obs = self._get_obs(data, state.info)
+ done = self._get_termination(data, state.metrics)
+ rewards = self._get_reward(data, action, state.info, state.metrics, done)
+ rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()}
+ reward = sum(rewards.values()) * self.dt
+
+ state.info["step"] += 1
+ phase_tp1 = state.info["phase"] + state.info["phase_dt"]
+ state.info["phase"] = jp.fmod(phase_tp1 + jp.pi, 2 * jp.pi) - jp.pi
+ state.info["phase"] = jp.where(
+ jp.linalg.norm(state.info["command"]) > 0.01,
+ state.info["phase"],
+ jp.ones(2) * jp.pi,
+ )
+ state.info["last_act"] = action
+ state.info["steps_until_next_cmd"] -= 1
+ state.info["rng"], key1, key2 = jax.random.split(state.info["rng"], 3)
+ state.info["command"] = jp.where(
+ state.info["steps_until_next_cmd"] <= 0,
+ self.sample_command(key1, state.info["command"]),
+ state.info["command"],
+ )
+ state.info["steps_until_next_cmd"] = jp.where(
+ done | (state.info["steps_until_next_cmd"] <= 0),
+ jp.round(jax.random.exponential(key2) * 5.0 / self.dt).astype(jp.int32),
+ state.info["steps_until_next_cmd"],
+ )
+ for k, v in rewards.items():
+ state.metrics[f"reward/{k}"] = v
+ done = done.astype(reward.dtype)
+ state = state.replace(data=data, obs=obs, reward=reward, done=done)
+ return state
+
+ def _get_termination(self, data: mjx.Data, metrics: dict[str, Any]) -> jax.Array:
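+    # The up vector's world z-component goes negative once the torso has tipped
+    # past horizontal.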
+ fall_termination = self.get_gravity(data)[-1] < 0.0
+ metrics["termination/fall_termination"] = fall_termination.astype(jp.float32)
+ return fall_termination | jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any()
+
+ def _apply_noise(
+ self, info: dict[str, Any], value: jax.Array, scale: float
+ ) -> jax.Array:
+ info["rng"], noise_rng = jax.random.split(info["rng"])
+ noise = 2 * jax.random.uniform(noise_rng, shape=value.shape) - 1
+ noisy_value = value + noise * self._config.noise_config.level * scale
+ return noisy_value
+
+ def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> mjx_env.Observation:
+ # Ground-truth observations.
+ gyro = self.get_gyro(data)
+ gravity = data.site_xmat[self._imu_site_id].T @ jp.array([0, 0, -1])
+ joint_angles = data.qpos[7:]
+ joint_vel = data.qvel[6:]
+ linvel = self.get_local_linvel(data)
+ phase = jp.concatenate([jp.cos(info["phase"]), jp.sin(info["phase"])])
+ root_pos = data.qpos[:3]
+ root_quat = data.qpos[3:7]
+ actuator_torques = data.actuator_force
+ # Noisy observations.
+ noise_scales = self._config.noise_config.scales
+ noisy_gyro = self._apply_noise(info, gyro, noise_scales.gyro)
+ noisy_gravity = self._apply_noise(info, gravity, noise_scales.gravity)
+ noisy_joint_angles = self._apply_noise(info, joint_angles, noise_scales.joint_pos)
+ noisy_joint_vel = self._apply_noise(info, joint_vel, noise_scales.joint_vel)
+ noisy_linvel = self._apply_noise(info, linvel, noise_scales.linvel)
+ state = jp.hstack(
+ [
+ noisy_linvel,
+ noisy_gyro,
+ noisy_gravity,
+ info["command"],
+ noisy_joint_angles - self._init_q[7:],
+ noisy_joint_vel,
+ info["last_act"],
+ phase,
+ ]
+ )
+ privileged_state = jp.hstack(
+ [
+ state,
+ # Unnoised.
+ gyro,
+ gravity,
+ linvel,
+ joint_angles - self._init_q[7:],
+ joint_vel,
+ # Extra.
+ actuator_torques,
+ root_pos,
+ root_quat,
+ ]
+ )
+ return {
+ "state": state,
+ "privileged_state": privileged_state,
+ }
+
+ def _get_reward(
+ self,
+ data: mjx.Data,
+ action: jax.Array,
+ info: dict[str, Any],
+ metrics: dict[str, Any],
+ done: jax.Array,
+ ) -> dict[str, jax.Array]:
+ del metrics # Unused.
+ return {
+ "termination": done,
+ "alive": jp.array(1.0) - done,
+ "tracking": self._reward_tracking(info["command"], info),
+ "lin_vel_z": self._cost_lin_vel_z(info["filtered_linvel"]),
+ "ang_vel_xy": self._cost_ang_vel_xy(info["filtered_angvel"]),
+ "orientation": self._cost_orientation(self.get_gravity(data)),
+ "feet_phase": self._reward_feet_phase(data, info["phase"]),
+ "torques": self._cost_torques(data.actuator_force),
+ "action_rate": self._cost_action_rate(action, info["last_act"]),
+ "energy": self._cost_energy(data.qvel, data.actuator_force),
+ "collision": self._cost_collision(data),
+ "pose": self._cost_pose(data.qpos, info["command"]),
+ }
+
+ def _reward_tracking(self, commands: jax.Array, info: dict[str, Any]) -> jax.Array:
+ lin_vel_error = jp.sum(jp.square(commands[:2] - info["filtered_linvel"][:2]))
+ r_linvel = jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma)
+ ang_vel_error = jp.square(commands[2] - info["filtered_angvel"][2])
+ r_angvel = jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma)
+ return r_linvel + 0.5 * r_angvel
+
+ def _cost_lin_vel_z(self, local_linvel) -> jax.Array:
+ return jp.square(local_linvel[2])
+
+ def _cost_ang_vel_xy(self, local_angvel) -> jax.Array:
+ return jp.sum(jp.square(local_angvel[:2]))
+
+ def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array:
+ return jp.sum(jp.square(torso_zaxis[:2]))
+
+ def _cost_torques(self, torques: jax.Array) -> jax.Array:
+ return jp.sum(jp.abs(torques))
+
+ def _cost_energy(self, qvel: jax.Array, qfrc_actuator: jax.Array) -> jax.Array:
+ torques = qfrc_actuator / self._actuator_torques
+ return jp.sum(jp.abs(qvel[6:] * torques))
+
+ def _cost_action_rate(self, act: jax.Array, last_act: jax.Array) -> jax.Array:
+ return jp.sum(jp.square(act - last_act))
+
+ def _cost_collision(self, data: mjx.Data) -> jax.Array:
+ # Hand - thigh.
+ c = geoms_colliding(data, self._left_hand_geom_id, self._left_thigh_geom_id)
+ c |= geoms_colliding(data, self._right_hand_geom_id, self._right_thigh_geom_id)
+ # Foot - foot.
+ c |= geoms_colliding(data, self._left_foot_geom_id, self._right_foot_geom_id)
+ # Shin - shin.
+ c |= geoms_colliding(
+ data,
+ self._left_shin_geom_id,
+ self._right_shin_geom_id,
+ )
+ # Thigh - thigh.
+ c |= geoms_colliding(
+ data,
+ self._left_thigh_geom_id,
+ self._right_thigh_geom_id,
+ )
+ return jp.any(c)
+
+ def _cost_pose(self, qpos: jax.Array, commands: jax.Array) -> jax.Array:
+ # Uniform weights when standing still.
+ weights = jp.where(
+ jp.linalg.norm(commands) < 0.01,
+ jp.ones_like(self._weights),
+ self._weights,
+ )
+ # Reduce hip roll weight when lateral command is high.
+ lateral_cmd = jp.abs(commands[1])
+ hip_roll_weight = jp.where(lateral_cmd > 0.3, 0.01, 1.0)
+ weights = weights.at[21].set(hip_roll_weight)
+ weights = weights.at[27].set(hip_roll_weight)
+ return jp.sum(jp.square(qpos[7:] - self._init_q[7:]) * weights)
+
+ def _reward_feet_phase(self, data: mjx.Data, phase: jax.Array) -> jax.Array:
+ foot_z = data.site_xpos[self._feet_site_id][..., -1]
+ rz = gait.get_rz(phase, swing_height=self._config.reward_config.max_foot_height)
+ error = jp.sum(jp.square(foot_z - rz))
+ return jp.exp(-error / 0.01)
+
+ def sample_command(self, rng: jax.Array, x_k: jax.Array) -> jax.Array:
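+    """Resample the command.
+
+    Per component: with probability 0.5 keep the previous value x_k; otherwise
+    take a fresh uniform sample y_k, which is zeroed unless the Bernoulli mask
+    z_k (p = zero_prob) fires.
+    """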
+ rng, y_rng, w_rng, z_rng = jax.random.split(rng, 4)
+ y_k = jax.random.uniform(
+ y_rng, shape=(3,), minval=self._cmd_min, maxval=self._cmd_max
+ )
+ z_k = jax.random.bernoulli(z_rng, self._cmd_zero_prob, shape=(3,))
+ w_k = jax.random.bernoulli(w_rng, 0.5, shape=(3,))
+ return x_k - w_k * (x_k - y_k * z_k)
+
+ def apply_push(self, state: mjx_env.State) -> mjx_env.State:
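+    """Every `push_interval_steps` steps, kick the root's planar (xy) velocity
+    with a random horizontal push."""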
+ state.info["rng"], push1_rng, push2_rng = jax.random.split(state.info["rng"], 3)
+ push_theta = jax.random.uniform(push1_rng, maxval=2 * jp.pi)
+ push_magnitude = jax.random.uniform(
+ push2_rng,
+ minval=self._config.push_config.magnitude_range[0],
+ maxval=self._config.push_config.magnitude_range[1],
+ )
+ push = jp.array([jp.cos(push_theta), jp.sin(push_theta)])
+ push *= jp.mod(state.info["push_step"] + 1, state.info["push_interval_steps"]) == 0
+ push *= self._config.push_config.enable
+ state.info["push"] = push
+ state.info["push_step"] += 1
+ qvel = state.data.qvel
+ qvel = qvel.at[:2].set(push * push_magnitude + qvel[:2])
+ data = state.data.replace(qvel=qvel)
+ state = state.replace(data=data)
+ return state
diff --git a/mujoco_playground/_src/locomotion/apollo/xmls/apollo_mjx_feetonly.xml b/mujoco_playground/_src/locomotion/apollo/xmls/apollo_mjx_feetonly.xml
new file mode 100644
index 000000000..7ef7d76c0
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/xmls/apollo_mjx_feetonly.xml
@@ -0,0 +1,538 @@
+<!-- 538 lines of MJCF for the Apollo robot model (feet-only collision geometry) omitted -->
diff --git a/mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml b/mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml
new file mode 100644
index 000000000..2ad479f09
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml
@@ -0,0 +1,60 @@
+<!-- 60 lines of MJCF for the flat-terrain scene omitted -->
diff --git a/mujoco_playground/config/locomotion_params.py b/mujoco_playground/config/locomotion_params.py
index eaa7587d8..cc777a4e1 100644
--- a/mujoco_playground/config/locomotion_params.py
+++ b/mujoco_playground/config/locomotion_params.py
@@ -134,6 +134,19 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
value_obs_key="privileged_state",
)
+ elif env_name in ("ApolloJoystickFlatTerrain",):
+ rl_config.num_timesteps = 200_000_000
+ rl_config.num_evals = 20
+ rl_config.clipping_epsilon = 0.2
+ rl_config.num_resets_per_eval = 1
+ rl_config.entropy_cost = 0.005
+ rl_config.network_factory = config_dict.create(
+ policy_hidden_layer_sizes=(512, 256, 128),
+ value_hidden_layer_sizes=(512, 256, 128),
+ policy_obs_key="state",
+ value_obs_key="privileged_state",
+ )
+
elif env_name in (
"BarkourJoystick",
"H1InplaceGaitTracking",
diff --git a/mujoco_playground/experimental/learning/apollo_joystick.ipynb b/mujoco_playground/experimental/learning/apollo_joystick.ipynb
new file mode 100644
index 000000000..a6a405f61
--- /dev/null
+++ b/mujoco_playground/experimental/learning/apollo_joystick.ipynb
@@ -0,0 +1,301 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
+ "xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n",
+ "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
+ "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
+ "os.environ[\"MUJOCO_GL\"] = \"egl\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import functools\n",
+ "import json\n",
+ "import pickle\n",
+ "from datetime import datetime\n",
+ "\n",
+ "import jax\n",
+ "import mediapy as media\n",
+ "import mujoco\n",
+ "import numpy as np\n",
+ "from brax.training.agents.ppo import networks as ppo_networks\n",
+ "from brax.training.agents.ppo import train as ppo\n",
+ "from etils import epath\n",
+ "from flax.training import orbax_utils\n",
+ "from orbax import checkpoint as ocp\n",
+ "\n",
+ "from mujoco_playground import registry, wrapper\n",
+ "from mujoco_playground.config import locomotion_params\n",
+ "from mujoco_playground.experimental.utils.plotting import TrainingPlotter\n",
+ "\n",
+ "# Enable persistent compilation cache.\n",
+ "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax_cache\")\n",
+ "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)\n",
+ "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env_name = \"ApolloJoystickFlatTerrain\"\n",
+ "env_cfg = registry.get_default_config(env_name)\n",
+ "randomizer = registry.get_domain_randomizer(env_name)\n",
+ "ppo_params = locomotion_params.brax_ppo_config(env_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env_cfg.reward_config.scales.energy = -1e-5\n",
+ "env_cfg.reward_config.scales.action_rate = -1e-3\n",
+ "env_cfg.reward_config.scales.torques = 0.0\n",
+ "\n",
+ "env_cfg.noise_config.level = 0.0 # 1.0\n",
+ "env_cfg.push_config.enable = True\n",
+ "env_cfg.push_config.magnitude_range = [0.1, 2.0]\n",
+ "\n",
+ "ppo_params.num_timesteps = 150_000_000\n",
+ "ppo_params.num_evals = 15"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SUFFIX = None\n",
+ "FINETUNE_PATH = None\n",
+ "\n",
+ "# Generate unique experiment name.\n",
+ "now = datetime.now()\n",
+ "timestamp = now.strftime(\"%Y%m%d-%H%M%S\")\n",
+ "exp_name = f\"{env_name}-{timestamp}\"\n",
+ "if SUFFIX is not None:\n",
+ " exp_name += f\"-{SUFFIX}\"\n",
+ "print(f\"{exp_name}\")\n",
+ "\n",
+ "# Possibly restore from the latest checkpoint.\n",
+ "if FINETUNE_PATH is not None:\n",
+ " FINETUNE_PATH = epath.Path(FINETUNE_PATH)\n",
+ " latest_ckpts = list(FINETUNE_PATH.glob(\"*\"))\n",
+ " latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]\n",
+ " latest_ckpts.sort(key=lambda x: int(x.name))\n",
+ " latest_ckpt = latest_ckpts[-1]\n",
+ " restore_checkpoint_path = latest_ckpt\n",
+ " print(f\"Restoring from: {restore_checkpoint_path}\")\n",
+ "else:\n",
+ " restore_checkpoint_path = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ckpt_path = epath.Path(\"checkpoints\").resolve() / exp_name\n",
+ "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
+ "print(f\"{ckpt_path}\")\n",
+ "\n",
+ "with open(ckpt_path / \"config.json\", \"w\") as fp:\n",
+ " json.dump(env_cfg.to_json(), fp, indent=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plotter = TrainingPlotter(max_timesteps=ppo_params.num_timesteps, figsize=(15, 10))\n",
+ "\n",
+ "\n",
+ "def progress(num_steps, metrics):\n",
+ " plotter.update(num_steps, metrics)\n",
+ "\n",
+ "\n",
+ "def policy_params_fn(current_step, make_policy, params):\n",
+ " del make_policy # Unused.\n",
+ " orbax_checkpointer = ocp.PyTreeCheckpointer()\n",
+ " save_args = orbax_utils.save_args_from_target(params)\n",
+ " path = ckpt_path / f\"{current_step}\"\n",
+ " orbax_checkpointer.save(path, params, force=True, save_args=save_args)\n",
+ "\n",
+ "\n",
+ "training_params = dict(ppo_params)\n",
+ "del training_params[\"network_factory\"]\n",
+ "\n",
+ "train_fn = functools.partial(\n",
+ " ppo.train,\n",
+ " **training_params,\n",
+ " network_factory=functools.partial(\n",
+ " ppo_networks.make_ppo_networks, **ppo_params.network_factory\n",
+ " ),\n",
+ " restore_checkpoint_path=restore_checkpoint_path,\n",
+ " progress_fn=progress,\n",
+ " wrap_env_fn=wrapper.wrap_for_brax_training,\n",
+ " policy_params_fn=policy_params_fn,\n",
+ " randomization_fn=randomizer,\n",
+ ")\n",
+ "\n",
+ "env = registry.load(env_name, config=env_cfg)\n",
+ "eval_env = registry.load(env_name, config=env_cfg)\n",
+ "make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)\n",
+ "if len(plotter.times) > 1:\n",
+ " print(f\"time to jit: {plotter.times[1] - plotter.times[0]}\")\n",
+ " print(f\"time to train: {plotter.times[-1] - plotter.times[1]}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "inference_fn = make_inference_fn(params, deterministic=True)\n",
+ "jit_inference_fn = jax.jit(inference_fn)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save normalizer and policy params to the checkpoint dir.\n",
+ "normalizer_params, policy_params, value_params = params\n",
+ "with open(ckpt_path / \"params.pkl\", \"wb\") as f:\n",
+ " data = {\n",
+ " \"normalizer_params\": normalizer_params,\n",
+ " \"policy_params\": policy_params,\n",
+ " \"value_params\": value_params,\n",
+ " }\n",
+ " pickle.dump(data, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mujoco_playground._src.gait import draw_joystick_command\n",
+ "\n",
+ "eval_env = registry.load(env_name, config=env_cfg)\n",
+ "jit_reset = jax.jit(eval_env.reset)\n",
+ "jit_step = jax.jit(eval_env.step)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rng = jax.random.PRNGKey(12345)\n",
+ "rollout = []\n",
+ "modify_scene_fns = []\n",
+ "state = jit_reset(rng)\n",
+ "for i in range(env_cfg.episode_length):\n",
+ " act_rng, rng = jax.random.split(rng)\n",
+ " ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
+ " state = jit_step(state, ctrl)\n",
+ " if state.done:\n",
+ " print(\"something bad happened\")\n",
+ " break\n",
+ " rollout.append(state)\n",
+ " xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso_link\").id])\n",
+ " xyz += np.array([0, 0.0, 0])\n",
+ " x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n",
+ " yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
+ " modify_scene_fns.append(\n",
+ " functools.partial(\n",
+ " draw_joystick_command,\n",
+ " cmd=state.info[\"command\"],\n",
+ " xyz=xyz,\n",
+ " theta=yaw,\n",
+ " scl=np.linalg.norm(state.info[\"command\"]),\n",
+ " )\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "render_every = 2\n",
+ "fps = 1.0 / eval_env.dt / render_every\n",
+ "print(f\"fps: {fps}\")\n",
+ "traj = rollout[::render_every]\n",
+ "mod_fns = modify_scene_fns[::render_every]\n",
+ "\n",
+ "scene_option = mujoco.MjvOption()\n",
+ "scene_option.geomgroup[2] = True\n",
+ "scene_option.geomgroup[3] = False\n",
+ "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
+ "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False\n",
+ "scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
+ "scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False\n",
+ "\n",
+ "frames = eval_env.render(\n",
+ " traj,\n",
+ " camera=\"track\",\n",
+ " scene_option=scene_option,\n",
+ " width=640,\n",
+ " height=480,\n",
+ " modify_scene_fns=mod_fns,\n",
+ ")\n",
+ "media.show_video(frames, fps=fps, loop=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mujoco_playground/experimental/sim2sim/onnx/apollo_policy.onnx b/mujoco_playground/experimental/sim2sim/onnx/apollo_policy.onnx
new file mode 100644
index 000000000..604e6e78b
Binary files /dev/null and b/mujoco_playground/experimental/sim2sim/onnx/apollo_policy.onnx differ
diff --git a/mujoco_playground/experimental/sim2sim/play_apollo_joystick.py b/mujoco_playground/experimental/sim2sim/play_apollo_joystick.py
new file mode 100644
index 000000000..d879ead9f
--- /dev/null
+++ b/mujoco_playground/experimental/sim2sim/play_apollo_joystick.py
@@ -0,0 +1,142 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deploy an MJX policy in ONNX format to C MuJoCo and play with it."""
+
+import mujoco
+import numpy as np
+import onnxruntime as rt
+from etils import epath
+from mujoco import viewer
+
+from mujoco_playground._src.locomotion.apollo import constants as apollo_constants
+from mujoco_playground._src.locomotion.apollo.base import get_assets
+from mujoco_playground.experimental.sim2sim.gamepad_reader import Gamepad
+
+_HERE = epath.Path(__file__).parent
+_ONNX_DIR = _HERE / "onnx"
+
+
+class OnnxController:
+ """ONNX controller for the Booster Apollo humanoid."""
+
+ def __init__(
+ self,
+ policy_path: str,
+ default_angles: np.ndarray,
+ ctrl_dt: float,
+ n_substeps: int,
+ action_scale: float = 0.5,
+ vel_scale_x: float = 1.0,
+ vel_scale_y: float = 1.0,
+ vel_scale_rot: float = 1.0,
+ ):
+ self._output_names = ["continuous_actions"]
+ self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"])
+
+ self._action_scale = action_scale
+ self._default_angles = default_angles
+ self._last_action = np.zeros_like(default_angles, dtype=np.float32)
+
+ self._counter = 0
+ self._n_substeps = n_substeps
+ self._ctrl_dt = ctrl_dt
+
+    self._phase = np.array([0.0, np.pi])
+    # Phase advance per control step for a 1 Hz gait; multiplied by the current
+    # gait frequency in get_control().
+    self._base_phase_dt = 2 * np.pi * ctrl_dt
+
+ self._joystick = Gamepad(
+ vel_scale_x=vel_scale_x,
+ vel_scale_y=vel_scale_y,
+ vel_scale_rot=vel_scale_rot,
+ deadzone=0.03,
+ )
+
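+  # Assembles the same `state` vector as Joystick._get_obs, without the noise.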
+ def get_obs(self, model, data) -> np.ndarray:
+ linvel = data.sensor("local_linvel").data
+ gyro = data.sensor("gyro").data
+ imu_xmat = data.site_xmat[model.site("imu").id].reshape(3, 3)
+ gravity = imu_xmat.T @ np.array([0, 0, -1])
+ joint_angles = data.qpos[7:] - self._default_angles
+ joint_velocities = data.qvel[6:]
+ command = self._joystick.get_command()
+ ph = self._phase if np.linalg.norm(command) >= 0.01 else np.ones(2) * np.pi
+ phase = np.concatenate([np.cos(ph), np.sin(ph)])
+ obs = np.hstack(
+ [
+ linvel,
+ gyro,
+ gravity,
+ command,
+ joint_angles,
+ joint_velocities,
+ self._last_action,
+ phase,
+ ]
+ )
+ return obs.astype(np.float32)
+
+ def get_control(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
+ self._counter += 1
+ if self._counter % self._n_substeps == 0:
+ obs = self.get_obs(model, data)
+ onnx_input = {"obs": obs.reshape(1, -1)}
+ onnx_pred = self._policy.run(self._output_names, onnx_input)[0][0]
+ self._last_action = onnx_pred.copy()
+ data.ctrl[:] = onnx_pred * self._action_scale + self._default_angles
+ command = self._joystick.get_command()
+ cmd_magnitude = np.linalg.norm(command)
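+      # Scale gait frequency with command magnitude: 1.25 Hz at standstill, up
+      # to 1.75 Hz at a command norm of 1.5 or more.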
+ if cmd_magnitude < 0.01:
+ gait_freq = 1.25
+ else:
+ gait_freq = 1.25 + 0.5 * min(cmd_magnitude, 1.5) / 1.5
+ phase_dt = self._base_phase_dt * gait_freq
+ phase_tp1 = self._phase + phase_dt
+ self._phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
+
+
+def load_callback(model=None, data=None):
+ mujoco.set_mjcb_control(None)
+
+ model = mujoco.MjModel.from_xml_path(
+ apollo_constants.FEET_ONLY_FLAT_TERRAIN_XML.as_posix(),
+ assets=get_assets(),
+ )
+ data = mujoco.MjData(model)
+
+ mujoco.mj_resetDataKeyframe(model, data, 0)
+
+ ctrl_dt = 0.02
+ sim_dt = 0.005
+ n_substeps = int(round(ctrl_dt / sim_dt))
+ model.opt.timestep = sim_dt
+
+ policy = OnnxController(
+ policy_path=(_ONNX_DIR / "apollo_policy.onnx").as_posix(),
+ default_angles=np.array(model.keyframe("knees_bent").qpos[7:]),
+ ctrl_dt=ctrl_dt,
+ n_substeps=n_substeps,
+ action_scale=0.5,
+ vel_scale_x=1.5,
+ vel_scale_y=0.8,
+ vel_scale_rot=1.5,
+ )
+
+ mujoco.set_mjcb_control(policy.get_control)
+
+ return model, data
+
+
+if __name__ == "__main__":
+ viewer.launch(loader=load_callback)
diff --git a/mujoco_playground/experimental/utils/plotting.py b/mujoco_playground/experimental/utils/plotting.py
new file mode 100644
index 000000000..be7809abb
--- /dev/null
+++ b/mujoco_playground/experimental/utils/plotting.py
@@ -0,0 +1,337 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
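+"""Live plotting of training metrics inside notebooks."""
+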
+from datetime import datetime
+from typing import Any, Dict, List, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from IPython.display import clear_output, display
+
+
+class TrainingPlotter:
+ def __init__(
+ self,
+ max_timesteps: int = 50_000_000,
+ figsize: Tuple[int, int] = (12, 8),
+ max_cols: int = 3,
+ ):
+ self.max_timesteps = max_timesteps
+ self.max_cols = max_cols
+
+ # Default main metrics that we always want to plot.
+ self.default_metrics = [
+ "eval/episode_reward",
+ "eval/avg_episode_length",
+ "steps_per_second",
+ ]
+
+ self.metrics = []
+ self.metrics_std = []
+ self.metric_labels = []
+ self.error_metrics = []
+ self.error_metrics_std = []
+ self.error_metric_labels = []
+ self.reward_detail_metrics = []
+ self.reward_detail_metrics_std = []
+ self.reward_detail_metric_labels = []
+ self.termination_metrics = []
+ self.termination_metrics_std = []
+ self.termination_metric_labels = []
+
+ self.x_data = []
+ self.times = [datetime.now()]
+ self.metrics_data = {}
+ self.metrics_std_data = {}
+
+ self.fps_data = [] # Store calculated steps per second
+
+ # Use default matplotlib style
+ plt.rcParams["axes.grid"] = False
+ plt.rcParams["axes.edgecolor"] = "#888888"
+ plt.rcParams["axes.linewidth"] = 0.8
+
+    # Create an initial 1x1 figure; the grid is resized on the first update once
+    # the set of metrics to plot is known.
+    n_cols = 1
+    n_rows = 1
+    self.fig, self.axes = plt.subplots(n_rows, n_cols, figsize=figsize)
+    self.axes = np.array([[self.axes]])
+
+ # Set figure background to white for clean look
+ self.fig.patch.set_facecolor("white")
+
+ # Set up the layout with reasonable spacing
+ plt.tight_layout(pad=2.5, h_pad=1.5, w_pad=1.0)
+
+  def _get_label_from_metric(self, metric: str) -> str:
+    label = metric.split("/")[-1]
+    if label == "episode_reward":
+      return "reward_per_episode"
+    if label == "avg_episode_length":
+      return "episode_length"
+    return label
+
+ def _initialize_metrics(self, metrics: Dict[str, Any]) -> None:
+ """Initialize all metrics from the first metrics dictionary."""
+ # Start with default metrics
+ self.metrics = self.default_metrics.copy()
+ self.metrics_std = [f"{m}_std" for m in self.metrics]
+ self.metric_labels = [self._get_label_from_metric(m) for m in self.metrics]
+
+ # Initialize data storage for default metrics
+ for metric in self.metrics:
+ self.metrics_data[metric] = []
+ for metric_std in self.metrics_std:
+ self.metrics_std_data[metric_std] = []
+
+ # Find all reward detail metrics (eval/episode_reward/*)
+ reward_prefix = "eval/episode_reward/"
+ for key in metrics:
+ if (
+ key.startswith(reward_prefix)
+ and not key.endswith("_std")
+ and key != "eval/episode_reward"
+ ):
+ self.reward_detail_metrics.append(key)
+ self.reward_detail_metrics_std.append(f"{key}_std")
+ label = key[len(reward_prefix) :]
+ self.reward_detail_metric_labels.append(label) # Keep underscores
+
+ # Initialize data storage
+ self.metrics_data[key] = []
+ self.metrics_std_data[f"{key}_std"] = []
+
+ # Find all error metrics (eval/episode_error/*)
+ error_prefix = "eval/episode_error/"
+ for key in metrics:
+ if key.startswith(error_prefix) and not key.endswith("_std"):
+ self.error_metrics.append(key)
+ self.error_metrics_std.append(f"{key}_std")
+ label = key[len(error_prefix) :]
+ self.error_metric_labels.append(label) # Keep underscores
+
+ # Initialize data storage
+ self.metrics_data[key] = []
+ self.metrics_std_data[f"{key}_std"] = []
+
+ # Find all termination metrics (eval/episode_termination/*)
+ termination_prefix = "eval/episode_termination/"
+ for key in metrics:
+ if key.startswith(termination_prefix) and not key.endswith("_std"):
+ self.termination_metrics.append(key)
+ self.termination_metrics_std.append(f"{key}_std")
+ label = key[len(termination_prefix) :]
+ self.termination_metric_labels.append(label) # Keep underscores
+
+ # Initialize data storage
+ self.metrics_data[key] = []
+ self.metrics_std_data[f"{key}_std"] = []
+
+ def update(self, num_steps: int, metrics: Dict[str, float]) -> None:
+ self.x_data.append(num_steps)
+ current_time = datetime.now()
+ self.times.append(current_time)
+
+ # Calculate steps per second if we have at least two data points
+ if len(self.x_data) > 1:
+ time_diff = (current_time - self.times[-2]).total_seconds()
+ steps_diff = self.x_data[-1] - self.x_data[-2]
+ if time_diff > 0:
+ fps = steps_diff / time_diff
+ else:
+ fps = 0
+ self.fps_data.append(fps)
+ else:
+ self.fps_data.append(0) # First point has no previous data to compare
+
+ # Initialize metrics if this is the first update.
+ if len(self.x_data) == 1:
+ self._initialize_metrics(metrics)
+ # Add fps to metrics data structure
+ self.metrics_data["steps_per_second"] = []
+ self.metrics_std_data["steps_per_second_std"] = []
+
+ # Update all metrics data.
+ all_metrics = (
+ self.metrics
+ + self.reward_detail_metrics
+ + self.error_metrics
+ + self.termination_metrics
+ )
+ all_metrics_std = (
+ self.metrics_std
+ + self.reward_detail_metrics_std
+ + self.error_metrics_std
+ + self.termination_metrics_std
+ )
+
+ for metric in all_metrics:
+ if metric == "steps_per_second":
+ self.metrics_data[metric].append(self.fps_data[-1])
+ elif metric in metrics:
+ self.metrics_data[metric].append(metrics[metric])
+ else:
+ last_value = self.metrics_data[metric][-1] if self.metrics_data[metric] else 0
+ self.metrics_data[metric].append(last_value)
+
+ for metric_std in all_metrics_std:
+ if metric_std in metrics:
+ self.metrics_std_data[metric_std].append(metrics[metric_std])
+ else:
+ last_value = (
+ self.metrics_std_data[metric_std][-1]
+ if self.metrics_std_data[metric_std]
+ else 0
+ )
+ self.metrics_std_data[metric_std].append(last_value)
+
+ clear_output(wait=True)
+
+ # Combine all metrics for plotting.
+ all_metrics = (
+ self.metrics
+ + self.reward_detail_metrics
+ + self.error_metrics
+ + self.termination_metrics
+ )
+ all_metrics_std = (
+ self.metrics_std
+ + self.reward_detail_metrics_std
+ + self.error_metrics_std
+ + self.termination_metrics_std
+ )
+ all_labels = (
+ self.metric_labels
+ + self.reward_detail_metric_labels
+ + self.error_metric_labels
+ + self.termination_metric_labels
+ )
+
+ # Calculate grid dimensions using max_cols.
+ total_plots = len(all_metrics)
+ n_cols = min(self.max_cols, total_plots) # Use max_cols parameter.
+ n_rows = (total_plots + n_cols - 1) // n_cols # Ceiling division.
+
+ # Check if we need to resize the axes grid
+ if n_rows > self.axes.shape[0] or n_cols > self.axes.shape[1]:
+ plt.close(self.fig)
+ # Calculate a better figure size based on the number of plots and columns
+ width = max(12, n_cols * 3.5) # 3.5 inches per column
+ height = max(8, n_rows * 2.5) # 2.5 inches per row
+ self.fig, self.axes = plt.subplots(n_rows, n_cols, figsize=(width, height))
+
+ # Handle case where there's only one plot.
+ if n_rows == 1 and n_cols == 1:
+ self.axes = np.array([[self.axes]])
+ elif n_rows == 1:
+ self.axes = np.array([self.axes])
+ elif n_cols == 1:
+ self.axes = np.array([[ax] for ax in self.axes])
+
+ # Plot all metrics
+ self._plot_metrics(all_metrics, all_metrics_std, all_labels, self.axes)
+
+ # Add a single x-axis label at the bottom of the figure.
+ self.fig.text(
+ 0.5, 0.01, "# environment steps", ha="center", fontsize=12, fontweight="bold"
+ )
+
+ # Update layout and display.
+ self.fig.tight_layout(pad=2.5, h_pad=1.5, w_pad=1.0)
+ self.fig.subplots_adjust(bottom=0.08)
+ display(self.fig)
+
+ def _plot_metrics(
+ self,
+ metrics_list: List[str],
+ metrics_std_list: List[str],
+ labels_list: List[str],
+ axes_grid: np.ndarray,
+ ) -> None:
+ """Plot a set of metrics on the given axes grid."""
+ for i, (metric, metric_std, label) in enumerate(
+ zip(metrics_list, metrics_std_list, labels_list)
+ ):
+ row, col = i // axes_grid.shape[1], i % axes_grid.shape[1]
+ if row < axes_grid.shape[0] and col < axes_grid.shape[1]:
+ ax = axes_grid[row][col]
+ ax.clear()
+ ax.set_xlim([0, self.max_timesteps * 1.25])
+
+ # Remove x-axis labels from all subplots
+ ax.set_xlabel("")
+
+ # Make tick labels smaller to save space
+ ax.tick_params(axis="both", which="major", labelsize=9)
+
+ # Add subtle grid for better readability
+ ax.grid(True, linestyle="-", linewidth=0.5, alpha=0.2)
+
+ # Clean background
+ ax.set_facecolor("white")
+
+ # Format y-axis with fewer decimal places for cleaner look
+ ax.ticklabel_format(axis="y", style="plain", useOffset=False)
+
+ y_values = self.metrics_data[metric]
+ yerr_values = (
+ self.metrics_std_data[metric_std]
+ if metric_std in self.metrics_std_data
+ else None
+ )
+
+ if y_values:
+ # Add prefix based on metric type
+ prefix = ""
+ if "eval/episode_error/" in metric:
+ prefix = "error/"
+ elif "eval/episode_reward/" in metric:
+ prefix = "reward/"
+ elif "eval/episode_termination/" in metric:
+ prefix = "termination/"
+
+ # Use smaller font for title to save space
+ ax.set_title(
+ f"{prefix}{label}: {y_values[-1]:.3f}", fontsize=10, fontweight="bold"
+ )
+
+ # Plot the line with improved styling
+ line = ax.errorbar(
+ self.x_data,
+ y_values,
+ yerr=yerr_values,
+ color="black",
+ linewidth=1.5,
+ elinewidth=0.7,
+ capsize=2,
+ )
+
+ # Add very subtle shading under the curve for better visibility
+ ax.fill_between(self.x_data, 0, y_values, alpha=0.05, color="black")
+
+ # Hide unused subplots
+ for i in range(len(metrics_list), axes_grid.shape[0] * axes_grid.shape[1]):
+ row, col = i // axes_grid.shape[1], i % axes_grid.shape[1]
+ if row < axes_grid.shape[0] and col < axes_grid.shape[1]:
+ axes_grid[row][col].set_visible(False)
+
+ def save_figure(self, filename: str) -> None:
+ self.fig.savefig(filename, dpi=300, bbox_inches="tight")