diff --git a/mujoco_playground/_src/locomotion/__init__.py b/mujoco_playground/_src/locomotion/__init__.py index ce5d98287..a80c3d1c7 100644 --- a/mujoco_playground/_src/locomotion/__init__.py +++ b/mujoco_playground/_src/locomotion/__init__.py @@ -22,6 +22,7 @@ from mujoco import mjx from mujoco_playground._src import mjx_env +from mujoco_playground._src.locomotion.apollo import joystick as apollo_joystick from mujoco_playground._src.locomotion.barkour import joystick as barkour_joystick from mujoco_playground._src.locomotion.berkeley_humanoid import joystick as berkeley_humanoid_joystick from mujoco_playground._src.locomotion.berkeley_humanoid import randomize as berkeley_humanoid_randomize @@ -41,6 +42,9 @@ from mujoco_playground._src.locomotion.t1 import randomize as t1_randomize _envs = { + "ApolloJoystickFlatTerrain": functools.partial( + apollo_joystick.Joystick, task="flat_terrain" + ), "BarkourJoystick": barkour_joystick.Joystick, "BerkeleyHumanoidJoystickFlatTerrain": functools.partial( berkeley_humanoid_joystick.Joystick, task="flat_terrain" @@ -82,6 +86,7 @@ } _cfgs = { + "ApolloJoystickFlatTerrain": apollo_joystick.default_config, "BarkourJoystick": barkour_joystick.default_config, "BerkeleyHumanoidJoystickFlatTerrain": ( berkeley_humanoid_joystick.default_config diff --git a/mujoco_playground/_src/locomotion/apollo/__init__.py b/mujoco_playground/_src/locomotion/apollo/__init__.py new file mode 100644 index 000000000..8d9506aab --- /dev/null +++ b/mujoco_playground/_src/locomotion/apollo/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/mujoco_playground/_src/locomotion/apollo/base.py b/mujoco_playground/_src/locomotion/apollo/base.py new file mode 100644 index 000000000..28fec9b4e --- /dev/null +++ b/mujoco_playground/_src/locomotion/apollo/base.py @@ -0,0 +1,162 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Base classes for Apollo.""" + +from typing import Any, Dict, Optional, Union + +import jax +import jax.numpy as jp +import mujoco +import numpy as np +from etils import epath +from ml_collections import config_dict +from mujoco import mjx + +from mujoco_playground._src import mjx_env +from mujoco_playground._src.locomotion.apollo import constants as consts +from mujoco_playground._src.collision import geoms_colliding + + +def get_assets() -> Dict[str, bytes]: + assets = {} + # Playground assets. + mjx_env.update_assets(assets, consts.XML_DIR, "*.xml") + mjx_env.update_assets(assets, consts.XML_DIR / "assets") + # Menagerie assets. + path = mjx_env.MENAGERIE_PATH / "apptronik_apollo" + mjx_env.update_assets(assets, path, "*.xml") + mjx_env.update_assets(assets, path / "assets") + mjx_env.update_assets(assets, path / "assets" / "ability_hand") + return assets + + +class ApolloEnv(mjx_env.MjxEnv): + """Base class for Apollo environments.""" + + def __init__( + self, + xml_path: str, + config: config_dict.ConfigDict, + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ) -> None: + super().__init__(config, config_overrides) + + self._mj_model = mujoco.MjModel.from_xml_string( + epath.Path(xml_path).read_text(), assets=get_assets() + ) + self._mj_model.opt.timestep = self.sim_dt + + self._mj_model.vis.global_.offwidth = 3840 + self._mj_model.vis.global_.offheight = 2160 + + self._mjx_model = mjx.put_model(self._mj_model) + self._xml_path = xml_path + + self._init_q = jp.array(self._mj_model.keyframe("knees_bent").qpos) + self._default_ctrl = jp.array(self._mj_model.keyframe("knees_bent").ctrl) + self._default_pose = jp.array(self._mj_model.keyframe("knees_bent").qpos[7:]) + self._actuator_torques = self.mj_model.jnt_actfrcrange[1:, 1] + + # Body IDs. + self._torso_body_id = self._mj_model.body(consts.ROOT_BODY).id + + # Geom IDs. + self._floor_geom_id = self._mj_model.geom("floor").id + self._left_feet_geom_id = np.array( + [self._mj_model.geom(name).id for name in consts.LEFT_FEET_GEOMS] + ) + self._right_feet_geom_id = np.array( + [self._mj_model.geom(name).id for name in consts.RIGHT_FEET_GEOMS] + ) + self._left_hand_geom_id = self._mj_model.geom("collision_l_hand_plate").id + self._right_hand_geom_id = self._mj_model.geom("collision_r_hand_plate").id + self._left_foot_geom_id = self._mj_model.geom("collision_l_sole").id + self._right_foot_geom_id = self._mj_model.geom("collision_r_sole").id + self._left_shin_geom_id = self._mj_model.geom("collision_capsule_body_l_shin").id + self._right_shin_geom_id = self._mj_model.geom("collision_capsule_body_r_shin").id + self._left_thigh_geom_id = self._mj_model.geom("collision_capsule_body_l_thigh").id + self._right_thigh_geom_id = self._mj_model.geom("collision_capsule_body_r_thigh").id + + # Site IDs. + self._imu_site_id = self._mj_model.site("imu").id + self._feet_site_id = np.array( + [self._mj_model.site(name).id for name in consts.FEET_SITES] + ) + + # Sensor readings. 
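+  # Each getter below reads a named MuJoCo sensor (names defined in
+  # constants.py) via mjx_env.get_sensor_data.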
+
+  def get_gravity(self, data: mjx.Data) -> jax.Array:
+    """Return the gravity vector in the world frame."""
+    return mjx_env.get_sensor_data(self.mj_model, data, consts.GRAVITY_SENSOR)
+
+  def get_global_linvel(self, data: mjx.Data) -> jax.Array:
+    """Return the linear velocity of the robot in the world frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.GLOBAL_LINVEL_SENSOR
+    )
+
+  def get_global_angvel(self, data: mjx.Data) -> jax.Array:
+    """Return the angular velocity of the robot in the world frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.GLOBAL_ANGVEL_SENSOR
+    )
+
+  def get_local_linvel(self, data: mjx.Data) -> jax.Array:
+    """Return the linear velocity of the robot in the local frame."""
+    return mjx_env.get_sensor_data(self.mj_model, data, consts.LOCAL_LINVEL_SENSOR)
+
+  def get_accelerometer(self, data: mjx.Data) -> jax.Array:
+    """Return the accelerometer readings in the local frame."""
+    return mjx_env.get_sensor_data(
+        self.mj_model, data, consts.ACCELEROMETER_SENSOR
+    )
+
+  def get_gyro(self, data: mjx.Data) -> jax.Array:
+    """Return the gyroscope readings in the local frame."""
+    return mjx_env.get_sensor_data(self.mj_model, data, consts.GYRO_SENSOR)
+
+  def get_feet_ground_contacts(self, data: mjx.Data) -> jax.Array:
+    """Return an array indicating whether each foot is in contact with the ground."""
+    left_feet_contact = jp.array([
+        geoms_colliding(data, geom_id, self._floor_geom_id)
+        for geom_id in self._left_feet_geom_id
+    ])
+    right_feet_contact = jp.array([
+        geoms_colliding(data, geom_id, self._floor_geom_id)
+        for geom_id in self._right_feet_geom_id
+    ])
+    return jp.hstack([jp.any(left_feet_contact), jp.any(right_feet_contact)])
+
+  # Accessors.
+
+  @property
+  def xml_path(self) -> str:
+    return self._xml_path
+
+  @property
+  def action_size(self) -> int:
+    return self._mjx_model.nu
+
+  @property
+  def mj_model(self) -> mujoco.MjModel:
+    return self._mj_model
+
+  @property
+  def mjx_model(self) -> mjx.Model:
+    return self._mjx_model
diff --git a/mujoco_playground/_src/locomotion/apollo/constants.py b/mujoco_playground/_src/locomotion/apollo/constants.py
new file mode 100644
index 000000000..3b6b590cd
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/apollo/constants.py
@@ -0,0 +1,53 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+# ============================================================================== +"""Constants for Apollo.""" + +from etils import epath + +from mujoco_playground._src import mjx_env + +XML_DIR = mjx_env.ROOT_PATH / "locomotion" / "apollo" / "xmls" + +FEET_ONLY_FLAT_TERRAIN_XML = XML_DIR / "scene_mjx_feetonly_flat_terrain.xml" + + +def task_to_xml(task_name: str) -> epath.Path: + return { + "flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML, + }[task_name] + + +FEET_SITES = [ + "l_foot", + "r_foot", +] + +HAND_SITES = [ + "left_palm", + "right_palm", +] + +LEFT_FEET_GEOMS = ["collision_l_sole"] +RIGHT_FEET_GEOMS = ["collision_r_sole"] +FEET_GEOMS = LEFT_FEET_GEOMS + RIGHT_FEET_GEOMS + +ROOT_BODY = "torso_link" + +GRAVITY_SENSOR = "upvector" +GLOBAL_LINVEL_SENSOR = "global_linvel" +GLOBAL_ANGVEL_SENSOR = "global_angvel" +LOCAL_LINVEL_SENSOR = "local_linvel" +ACCELEROMETER_SENSOR = "accelerometer" +GYRO_SENSOR = "gyro" diff --git a/mujoco_playground/_src/locomotion/apollo/joystick.py b/mujoco_playground/_src/locomotion/apollo/joystick.py new file mode 100644 index 000000000..6472cc453 --- /dev/null +++ b/mujoco_playground/_src/locomotion/apollo/joystick.py @@ -0,0 +1,405 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Joystick task for Apollo.""" + +from typing import Any, Dict, Optional, Union + +import jax +import jax.numpy as jp +from ml_collections import config_dict +from mujoco import mjx +from mujoco.mjx._src import math + +from mujoco_playground._src import gait, mjx_env +from mujoco_playground._src.locomotion.apollo import base +from mujoco_playground._src.locomotion.apollo import constants as consts +from mujoco_playground._src.collision import geoms_colliding + + +def default_config() -> config_dict.ConfigDict: + return config_dict.create( + ctrl_dt=0.02, + sim_dt=0.005, + episode_length=1000, + action_repeat=1, + action_scale=0.5, + noise_config=config_dict.create( + level=1.0, + scales=config_dict.create( + joint_pos=0.03, + joint_vel=1.5, + gravity=0.05, + linvel=0.1, + gyro=0.2, + ), + ), + reward_config=config_dict.create( + scales=config_dict.create( + tracking=1.0, + lin_vel_z=0.0, + ang_vel_xy=-0.15, + orientation=-1.0, + torques=0.0, + action_rate=0.0, + energy=0.0, + feet_phase=1.0, + alive=0.0, + termination=0.0, + pose=-1.0, + collision=-1.0, + ), + tracking_sigma=0.25, + max_foot_height=0.12, + ), + push_config=config_dict.create( + enable=True, + interval_range=[5.0, 10.0], + magnitude_range=[0.1, 2.0], + ), + command_config=config_dict.create( + min=[-1.5, -0.8, -1.5], + max=[1.5, 0.8, 1.5], + zero_prob=[0.9, 0.25, 0.5], + ), + ) + + +class Joystick(base.ApolloEnv): + """Track a joystick command.""" + + def __init__( + self, + task: str = "flat_terrain", + config: config_dict.ConfigDict = default_config(), + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ): + super().__init__( + xml_path=consts.task_to_xml(task).as_posix(), + config=config, + config_overrides=config_overrides, + ) + + self._cmd_min = jp.array(self._config.command_config.min) + self._cmd_max = jp.array(self._config.command_config.max) + self._cmd_zero_prob = jp.array(self._config.command_config.zero_prob) + + # fmt: off + self._weights = jp.array([ + 5.0, 5.0, 5.0, # Torso. + 1.0, 1.0, 1.0, # Neck. + 1.0, 1.0, 0.1, 1.0, 1.0, 1.0, 1.0, # Left arm. + 1.0, 1.0, 0.1, 1.0, 1.0, 1.0, 1.0, # Right arm. + 1.0, 1.0, 0.01, 0.01, 1.0, 1.0, # Left leg. + 1.0, 1.0, 0.01, 0.01, 1.0, 1.0, # Right leg. + ]) + # fmt: on + + def reset(self, rng: jax.Array) -> mjx_env.State: + qpos = self._init_q + qvel = jp.zeros(self.mjx_model.nv) + + # Randomize xy position and yaw, xy=+U(-0.5, 0.5), yaw=U(-pi, pi). + rng, key = jax.random.split(rng) + dxy = jax.random.uniform(key, (2,), minval=-0.5, maxval=0.5) + qpos = qpos.at[0:2].set(qpos[0:2] + dxy) + rng, key = jax.random.split(rng) + yaw = jax.random.uniform(key, (1,), minval=-3.14, maxval=3.14) + quat = math.axis_angle_to_quat(jp.array([0, 0, 1]), yaw) + new_quat = math.quat_mul(qpos[3:7], quat) + qpos = qpos.at[3:7].set(new_quat) + + # Perturb initial joint angles, qpos[7:]=*U(0.5, 1.5) + rng, key = jax.random.split(rng) + qpos = qpos.at[7:].set( + qpos[7:] + * jax.random.uniform(key, (self.mjx_model.nq - 7,), minval=0.5, maxval=1.5) + ) + + # Perturb initial joint velocities, d(xyzrpy)=U(-0.5, 0.5) + rng, key = jax.random.split(rng) + qvel = qvel.at[0:6].set(jax.random.uniform(key, (6,), minval=-0.5, maxval=0.5)) + + data = mjx_env.init(self.mjx_model, qpos=qpos, qvel=qvel) + + # Sample gait frequency =U(1.25, 1.75). 
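+    # The two phase entries are per-foot gait clocks offset by pi so the feet
+    # alternate; phase_dt is the clock increment applied each control step.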
+ rng, key = jax.random.split(rng) + gait_freq = jax.random.uniform(key, (1,), minval=1.25, maxval=1.75) + phase_dt = 2 * jp.pi * self.dt * gait_freq + phase = jp.array([0, jp.pi]) + + # Sample push interval. + rng, push_rng = jax.random.split(rng) + push_interval = jax.random.uniform( + push_rng, + minval=self._config.push_config.interval_range[0], + maxval=self._config.push_config.interval_range[1], + ) + push_interval_steps = jp.round(push_interval / self.dt).astype(jp.int32) + + # Sample command. + rng, key1, key2 = jax.random.split(rng, 3) + time_until_next_cmd = jax.random.exponential(key1) * 5.0 + steps_until_next_cmd = jp.round(time_until_next_cmd / self.dt).astype(jp.int32) + cmd = jax.random.uniform( + key2, shape=(3,), minval=self._cmd_min, maxval=self._cmd_max + ) + + info = { + "rng": rng, + "step": 0, + "command": cmd, + "steps_until_next_cmd": steps_until_next_cmd, + "last_act": jp.zeros(self.mjx_model.nu), + "phase_dt": phase_dt, + "phase": phase, + "push": jp.array([0.0, 0.0]), + "push_step": 0, + "push_interval_steps": push_interval_steps, + "filtered_linvel": jp.zeros(3), + "filtered_angvel": jp.zeros(3), + } + metrics = { + "termination/fall_termination": jp.zeros(()), + } + for k in self._config.reward_config.scales.keys(): + metrics[f"reward/{k}"] = jp.zeros(()) + + obs = self._get_obs(data, info) + reward, done = jp.zeros(2) + return mjx_env.State(data, obs, reward, done, metrics, info) + + def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: + state = self.apply_push(state) + motor_targets = self._default_ctrl + action * self._config.action_scale + data = mjx_env.step(self.mjx_model, state.data, motor_targets, self.n_substeps) + + linvel = self.get_local_linvel(data) + state.info["filtered_linvel"] = linvel * 1.0 + state.info["filtered_linvel"] * 0.0 + angvel = self.get_gyro(data) + state.info["filtered_angvel"] = angvel * 1.0 + state.info["filtered_angvel"] * 0.0 + + obs = self._get_obs(data, state.info) + done = self._get_termination(data, state.metrics) + rewards = self._get_reward(data, action, state.info, state.metrics, done) + rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()} + reward = sum(rewards.values()) * self.dt + + state.info["step"] += 1 + phase_tp1 = state.info["phase"] + state.info["phase_dt"] + state.info["phase"] = jp.fmod(phase_tp1 + jp.pi, 2 * jp.pi) - jp.pi + state.info["phase"] = jp.where( + jp.linalg.norm(state.info["command"]) > 0.01, + state.info["phase"], + jp.ones(2) * jp.pi, + ) + state.info["last_act"] = action + state.info["steps_until_next_cmd"] -= 1 + state.info["rng"], key1, key2 = jax.random.split(state.info["rng"], 3) + state.info["command"] = jp.where( + state.info["steps_until_next_cmd"] <= 0, + self.sample_command(key1, state.info["command"]), + state.info["command"], + ) + state.info["steps_until_next_cmd"] = jp.where( + done | (state.info["steps_until_next_cmd"] <= 0), + jp.round(jax.random.exponential(key2) * 5.0 / self.dt).astype(jp.int32), + state.info["steps_until_next_cmd"], + ) + for k, v in rewards.items(): + state.metrics[f"reward/{k}"] = v + done = done.astype(reward.dtype) + state = state.replace(data=data, obs=obs, reward=reward, done=done) + return state + + def _get_termination(self, data: mjx.Data, metrics: dict[str, Any]) -> jax.Array: + fall_termination = self.get_gravity(data)[-1] < 0.0 + metrics["termination/fall_termination"] = fall_termination.astype(jp.float32) + return fall_termination | jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any() + + def 
_apply_noise( + self, info: dict[str, Any], value: jax.Array, scale: float + ) -> jax.Array: + info["rng"], noise_rng = jax.random.split(info["rng"]) + noise = 2 * jax.random.uniform(noise_rng, shape=value.shape) - 1 + noisy_value = value + noise * self._config.noise_config.level * scale + return noisy_value + + def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> mjx_env.Observation: + # Ground-truth observations. + gyro = self.get_gyro(data) + gravity = data.site_xmat[self._imu_site_id].T @ jp.array([0, 0, -1]) + joint_angles = data.qpos[7:] + joint_vel = data.qvel[6:] + linvel = self.get_local_linvel(data) + phase = jp.concatenate([jp.cos(info["phase"]), jp.sin(info["phase"])]) + root_pos = data.qpos[:3] + root_quat = data.qpos[3:7] + actuator_torques = data.actuator_force + # Noisy observations. + noise_scales = self._config.noise_config.scales + noisy_gyro = self._apply_noise(info, gyro, noise_scales.gyro) + noisy_gravity = self._apply_noise(info, gravity, noise_scales.gravity) + noisy_joint_angles = self._apply_noise(info, joint_angles, noise_scales.joint_pos) + noisy_joint_vel = self._apply_noise(info, joint_vel, noise_scales.joint_vel) + noisy_linvel = self._apply_noise(info, linvel, noise_scales.linvel) + state = jp.hstack( + [ + noisy_linvel, + noisy_gyro, + noisy_gravity, + info["command"], + noisy_joint_angles - self._init_q[7:], + noisy_joint_vel, + info["last_act"], + phase, + ] + ) + privileged_state = jp.hstack( + [ + state, + # Unnoised. + gyro, + gravity, + linvel, + joint_angles - self._init_q[7:], + joint_vel, + # Extra. + actuator_torques, + root_pos, + root_quat, + ] + ) + return { + "state": state, + "privileged_state": privileged_state, + } + + def _get_reward( + self, + data: mjx.Data, + action: jax.Array, + info: dict[str, Any], + metrics: dict[str, Any], + done: jax.Array, + ) -> dict[str, jax.Array]: + del metrics # Unused. 
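+    # Unscaled terms: step() multiplies each entry by reward_config.scales[k]
+    # (costs carry negative scales) and by dt before summing.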
+ return { + "termination": done, + "alive": jp.array(1.0) - done, + "tracking": self._reward_tracking(info["command"], info), + "lin_vel_z": self._cost_lin_vel_z(info["filtered_linvel"]), + "ang_vel_xy": self._cost_ang_vel_xy(info["filtered_angvel"]), + "orientation": self._cost_orientation(self.get_gravity(data)), + "feet_phase": self._reward_feet_phase(data, info["phase"]), + "torques": self._cost_torques(data.actuator_force), + "action_rate": self._cost_action_rate(action, info["last_act"]), + "energy": self._cost_energy(data.qvel, data.actuator_force), + "collision": self._cost_collision(data), + "pose": self._cost_pose(data.qpos, info["command"]), + } + + def _reward_tracking(self, commands: jax.Array, info: dict[str, Any]) -> jax.Array: + lin_vel_error = jp.sum(jp.square(commands[:2] - info["filtered_linvel"][:2])) + r_linvel = jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma) + ang_vel_error = jp.square(commands[2] - info["filtered_angvel"][2]) + r_angvel = jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma) + return r_linvel + 0.5 * r_angvel + + def _cost_lin_vel_z(self, local_linvel) -> jax.Array: + return jp.square(local_linvel[2]) + + def _cost_ang_vel_xy(self, local_angvel) -> jax.Array: + return jp.sum(jp.square(local_angvel[:2])) + + def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array: + return jp.sum(jp.square(torso_zaxis[:2])) + + def _cost_torques(self, torques: jax.Array) -> jax.Array: + return jp.sum(jp.abs(torques)) + + def _cost_energy(self, qvel: jax.Array, qfrc_actuator: jax.Array) -> jax.Array: + torques = qfrc_actuator / self._actuator_torques + return jp.sum(jp.abs(qvel[6:] * torques)) + + def _cost_action_rate(self, act: jax.Array, last_act: jax.Array) -> jax.Array: + return jp.sum(jp.square(act - last_act)) + + def _cost_collision(self, data: mjx.Data) -> jax.Array: + # Hand - thigh. + c = geoms_colliding(data, self._left_hand_geom_id, self._left_thigh_geom_id) + c |= geoms_colliding(data, self._right_hand_geom_id, self._right_thigh_geom_id) + # Foot - foot. + c |= geoms_colliding(data, self._left_foot_geom_id, self._right_foot_geom_id) + # Shin - shin. + c |= geoms_colliding( + data, + self._left_shin_geom_id, + self._right_shin_geom_id, + ) + # Thigh - thigh. + c |= geoms_colliding( + data, + self._left_thigh_geom_id, + self._right_thigh_geom_id, + ) + return jp.any(c) + + def _cost_pose(self, qpos: jax.Array, commands: jax.Array) -> jax.Array: + # Uniform weights when standing still. + weights = jp.where( + jp.linalg.norm(commands) < 0.01, + jp.ones_like(self._weights), + self._weights, + ) + # Reduce hip roll weight when lateral command is high. 
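+    # Indices 21 and 27 index into the 32-entry weight vector from __init__,
+    # where the left and right leg blocks start at 20 and 26, respectively.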
+ lateral_cmd = jp.abs(commands[1]) + hip_roll_weight = jp.where(lateral_cmd > 0.3, 0.01, 1.0) + weights = weights.at[21].set(hip_roll_weight) + weights = weights.at[27].set(hip_roll_weight) + return jp.sum(jp.square(qpos[7:] - self._init_q[7:]) * weights) + + def _reward_feet_phase(self, data: mjx.Data, phase: jax.Array) -> jax.Array: + foot_z = data.site_xpos[self._feet_site_id][..., -1] + rz = gait.get_rz(phase, swing_height=self._config.reward_config.max_foot_height) + error = jp.sum(jp.square(foot_z - rz)) + return jp.exp(-error / 0.01) + + def sample_command(self, rng: jax.Array, x_k: jax.Array) -> jax.Array: + rng, y_rng, w_rng, z_rng = jax.random.split(rng, 4) + y_k = jax.random.uniform( + y_rng, shape=(3,), minval=self._cmd_min, maxval=self._cmd_max + ) + z_k = jax.random.bernoulli(z_rng, self._cmd_zero_prob, shape=(3,)) + w_k = jax.random.bernoulli(w_rng, 0.5, shape=(3,)) + return x_k - w_k * (x_k - y_k * z_k) + + def apply_push(self, state: mjx_env.State) -> mjx_env.State: + state.info["rng"], push1_rng, push2_rng = jax.random.split(state.info["rng"], 3) + push_theta = jax.random.uniform(push1_rng, maxval=2 * jp.pi) + push_magnitude = jax.random.uniform( + push2_rng, + minval=self._config.push_config.magnitude_range[0], + maxval=self._config.push_config.magnitude_range[1], + ) + push = jp.array([jp.cos(push_theta), jp.sin(push_theta)]) + push *= jp.mod(state.info["push_step"] + 1, state.info["push_interval_steps"]) == 0 + push *= self._config.push_config.enable + state.info["push"] = push + state.info["push_step"] += 1 + qvel = state.data.qvel + qvel = qvel.at[:2].set(push * push_magnitude + qvel[:2]) + data = state.data.replace(qvel=qvel) + state = state.replace(data=data) + return state diff --git a/mujoco_playground/_src/locomotion/apollo/xmls/apollo_mjx_feetonly.xml b/mujoco_playground/_src/locomotion/apollo/xmls/apollo_mjx_feetonly.xml new file mode 100644 index 000000000..7ef7d76c0 --- /dev/null +++ b/mujoco_playground/_src/locomotion/apollo/xmls/apollo_mjx_feetonly.xml @@ -0,0 +1,538 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml b/mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml new file mode 100644 index 000000000..2ad479f09 --- /dev/null +++ b/mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mujoco_playground/config/locomotion_params.py b/mujoco_playground/config/locomotion_params.py index eaa7587d8..cc777a4e1 100644 --- a/mujoco_playground/config/locomotion_params.py +++ 
b/mujoco_playground/config/locomotion_params.py @@ -134,6 +134,19 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict: value_obs_key="privileged_state", ) + elif env_name in ("ApolloJoystickFlatTerrain",): + rl_config.num_timesteps = 200_000_000 + rl_config.num_evals = 20 + rl_config.clipping_epsilon = 0.2 + rl_config.num_resets_per_eval = 1 + rl_config.entropy_cost = 0.005 + rl_config.network_factory = config_dict.create( + policy_hidden_layer_sizes=(512, 256, 128), + value_hidden_layer_sizes=(512, 256, 128), + policy_obs_key="state", + value_obs_key="privileged_state", + ) + elif env_name in ( "BarkourJoystick", "H1InplaceGaitTracking", diff --git a/mujoco_playground/experimental/learning/apollo_joystick.ipynb b/mujoco_playground/experimental/learning/apollo_joystick.ipynb new file mode 100644 index 000000000..a6a405f61 --- /dev/null +++ b/mujoco_playground/experimental/learning/apollo_joystick.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n", + "xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n", + "os.environ[\"XLA_FLAGS\"] = xla_flags\n", + "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", + "os.environ[\"MUJOCO_GL\"] = \"egl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "import json\n", + "import pickle\n", + "from datetime import datetime\n", + "\n", + "import jax\n", + "import mediapy as media\n", + "import mujoco\n", + "import numpy as np\n", + "from brax.training.agents.ppo import networks as ppo_networks\n", + "from brax.training.agents.ppo import train as ppo\n", + "from etils import epath\n", + "from flax.training import orbax_utils\n", + "from orbax import checkpoint as ocp\n", + "\n", + "from mujoco_playground import registry, wrapper\n", + "from mujoco_playground.config import locomotion_params\n", + "from mujoco_playground.experimental.utils.plotting import TrainingPlotter\n", + "\n", + "# Enable persistent compilation cache.\n", + "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax_cache\")\n", + "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)\n", + "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env_name = \"ApolloJoystickFlatTerrain\"\n", + "env_cfg = registry.get_default_config(env_name)\n", + "randomizer = registry.get_domain_randomizer(env_name)\n", + "ppo_params = locomotion_params.brax_ppo_config(env_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "env_cfg.reward_config.scales.energy = -1e-5\n", + "env_cfg.reward_config.scales.action_rate = -1e-3\n", + "env_cfg.reward_config.scales.torques = 0.0\n", + "\n", + "env_cfg.noise_config.level = 0.0 # 1.0\n", + "env_cfg.push_config.enable = True\n", + "env_cfg.push_config.magnitude_range = [0.1, 2.0]\n", + "\n", + "ppo_params.num_timesteps = 150_000_000\n", + "ppo_params.num_evals = 15" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SUFFIX = None\n", + "FINETUNE_PATH = None\n", + "\n", + "# Generate unique experiment name.\n", + "now = datetime.now()\n", + "timestamp = now.strftime(\"%Y%m%d-%H%M%S\")\n", + "exp_name = 
f\"{env_name}-{timestamp}\"\n", + "if SUFFIX is not None:\n", + " exp_name += f\"-{SUFFIX}\"\n", + "print(f\"{exp_name}\")\n", + "\n", + "# Possibly restore from the latest checkpoint.\n", + "if FINETUNE_PATH is not None:\n", + " FINETUNE_PATH = epath.Path(FINETUNE_PATH)\n", + " latest_ckpts = list(FINETUNE_PATH.glob(\"*\"))\n", + " latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]\n", + " latest_ckpts.sort(key=lambda x: int(x.name))\n", + " latest_ckpt = latest_ckpts[-1]\n", + " restore_checkpoint_path = latest_ckpt\n", + " print(f\"Restoring from: {restore_checkpoint_path}\")\n", + "else:\n", + " restore_checkpoint_path = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_path = epath.Path(\"checkpoints\").resolve() / exp_name\n", + "ckpt_path.mkdir(parents=True, exist_ok=True)\n", + "print(f\"{ckpt_path}\")\n", + "\n", + "with open(ckpt_path / \"config.json\", \"w\") as fp:\n", + " json.dump(env_cfg.to_json(), fp, indent=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter = TrainingPlotter(max_timesteps=ppo_params.num_timesteps, figsize=(15, 10))\n", + "\n", + "\n", + "def progress(num_steps, metrics):\n", + " plotter.update(num_steps, metrics)\n", + "\n", + "\n", + "def policy_params_fn(current_step, make_policy, params):\n", + " del make_policy # Unused.\n", + " orbax_checkpointer = ocp.PyTreeCheckpointer()\n", + " save_args = orbax_utils.save_args_from_target(params)\n", + " path = ckpt_path / f\"{current_step}\"\n", + " orbax_checkpointer.save(path, params, force=True, save_args=save_args)\n", + "\n", + "\n", + "training_params = dict(ppo_params)\n", + "del training_params[\"network_factory\"]\n", + "\n", + "train_fn = functools.partial(\n", + " ppo.train,\n", + " **training_params,\n", + " network_factory=functools.partial(\n", + " ppo_networks.make_ppo_networks, **ppo_params.network_factory\n", + " ),\n", + " restore_checkpoint_path=restore_checkpoint_path,\n", + " progress_fn=progress,\n", + " wrap_env_fn=wrapper.wrap_for_brax_training,\n", + " policy_params_fn=policy_params_fn,\n", + " randomization_fn=randomizer,\n", + ")\n", + "\n", + "env = registry.load(env_name, config=env_cfg)\n", + "eval_env = registry.load(env_name, config=env_cfg)\n", + "make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)\n", + "if len(plotter.times) > 1:\n", + " print(f\"time to jit: {plotter.times[1] - plotter.times[0]}\")\n", + " print(f\"time to train: {plotter.times[-1] - plotter.times[1]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "inference_fn = make_inference_fn(params, deterministic=True)\n", + "jit_inference_fn = jax.jit(inference_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Save normalizer and policy params to the checkpoint dir.\n", + "normalizer_params, policy_params, value_params = params\n", + "with open(ckpt_path / \"params.pkl\", \"wb\") as f:\n", + " data = {\n", + " \"normalizer_params\": normalizer_params,\n", + " \"policy_params\": policy_params,\n", + " \"value_params\": value_params,\n", + " }\n", + " pickle.dump(data, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from mujoco_playground._src.gait import draw_joystick_command\n", + "\n", + "eval_env = registry.load(env_name, 
config=env_cfg)\n", + "jit_reset = jax.jit(eval_env.reset)\n", + "jit_step = jax.jit(eval_env.step)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = jax.random.PRNGKey(12345)\n", + "rollout = []\n", + "modify_scene_fns = []\n", + "state = jit_reset(rng)\n", + "for i in range(env_cfg.episode_length):\n", + " act_rng, rng = jax.random.split(rng)\n", + " ctrl, _ = jit_inference_fn(state.obs, act_rng)\n", + " state = jit_step(state, ctrl)\n", + " if state.done:\n", + " print(\"something bad happened\")\n", + " break\n", + " rollout.append(state)\n", + " xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso_link\").id])\n", + " xyz += np.array([0, 0.0, 0])\n", + " x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n", + " yaw = -np.arctan2(x_axis[1], x_axis[0])\n", + " modify_scene_fns.append(\n", + " functools.partial(\n", + " draw_joystick_command,\n", + " cmd=state.info[\"command\"],\n", + " xyz=xyz,\n", + " theta=yaw,\n", + " scl=np.linalg.norm(state.info[\"command\"]),\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "render_every = 2\n", + "fps = 1.0 / eval_env.dt / render_every\n", + "print(f\"fps: {fps}\")\n", + "traj = rollout[::render_every]\n", + "mod_fns = modify_scene_fns[::render_every]\n", + "\n", + "scene_option = mujoco.MjvOption()\n", + "scene_option.geomgroup[2] = True\n", + "scene_option.geomgroup[3] = False\n", + "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n", + "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False\n", + "scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n", + "scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False\n", + "\n", + "frames = eval_env.render(\n", + " traj,\n", + " camera=\"track\",\n", + " scene_option=scene_option,\n", + " width=640,\n", + " height=480,\n", + " modify_scene_fns=mod_fns,\n", + ")\n", + "media.show_video(frames, fps=fps, loop=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mujoco_playground/experimental/sim2sim/onnx/apollo_policy.onnx b/mujoco_playground/experimental/sim2sim/onnx/apollo_policy.onnx new file mode 100644 index 000000000..604e6e78b Binary files /dev/null and b/mujoco_playground/experimental/sim2sim/onnx/apollo_policy.onnx differ diff --git a/mujoco_playground/experimental/sim2sim/play_apollo_joystick.py b/mujoco_playground/experimental/sim2sim/play_apollo_joystick.py new file mode 100644 index 000000000..d879ead9f --- /dev/null +++ b/mujoco_playground/experimental/sim2sim/play_apollo_joystick.py @@ -0,0 +1,142 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deploy an MJX policy in ONNX format to C MuJoCo and play with it."""
+
+import mujoco
+import numpy as np
+import onnxruntime as rt
+from etils import epath
+from mujoco import viewer
+
+from mujoco_playground._src.locomotion.apollo import constants as apollo_constants
+from mujoco_playground._src.locomotion.apollo.base import get_assets
+from mujoco_playground.experimental.sim2sim.gamepad_reader import Gamepad
+
+_HERE = epath.Path(__file__).parent
+_ONNX_DIR = _HERE / "onnx"
+
+
+class OnnxController:
+  """ONNX controller for the Apptronik Apollo humanoid."""
+
+  def __init__(
+      self,
+      policy_path: str,
+      default_angles: np.ndarray,
+      ctrl_dt: float,
+      n_substeps: int,
+      action_scale: float = 0.5,
+      vel_scale_x: float = 1.0,
+      vel_scale_y: float = 1.0,
+      vel_scale_rot: float = 1.0,
+  ):
+    self._output_names = ["continuous_actions"]
+    self._policy = rt.InferenceSession(
+        policy_path, providers=["CPUExecutionProvider"]
+    )
+
+    self._action_scale = action_scale
+    self._default_angles = default_angles
+    self._last_action = np.zeros_like(default_angles, dtype=np.float32)
+
+    self._counter = 0
+    self._n_substeps = n_substeps
+    self._ctrl_dt = ctrl_dt
+
+    self._phase = np.array([0.0, np.pi])
+    # Base phase increment per control step; scaled by gait frequency below.
+    self._base_phase_dt = 2 * np.pi * ctrl_dt
+
+    self._joystick = Gamepad(
+        vel_scale_x=vel_scale_x,
+        vel_scale_y=vel_scale_y,
+        vel_scale_rot=vel_scale_rot,
+        deadzone=0.03,
+    )
+
+  def get_obs(self, model, data) -> np.ndarray:
+    linvel = data.sensor("local_linvel").data
+    gyro = data.sensor("gyro").data
+    imu_xmat = data.site_xmat[model.site("imu").id].reshape(3, 3)
+    gravity = imu_xmat.T @ np.array([0, 0, -1])
+    joint_angles = data.qpos[7:] - self._default_angles
+    joint_velocities = data.qvel[6:]
+    command = self._joystick.get_command()
+    ph = self._phase if np.linalg.norm(command) >= 0.01 else np.ones(2) * np.pi
+    phase = np.concatenate([np.cos(ph), np.sin(ph)])
+    obs = np.hstack(
+        [
+            linvel,
+            gyro,
+            gravity,
+            command,
+            joint_angles,
+            joint_velocities,
+            self._last_action,
+            phase,
+        ]
+    )
+    return obs.astype(np.float32)
+
+  def get_control(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
+    self._counter += 1
+    if self._counter % self._n_substeps == 0:
+      obs = self.get_obs(model, data)
+      onnx_input = {"obs": obs.reshape(1, -1)}
+      onnx_pred = self._policy.run(self._output_names, onnx_input)[0][0]
+      self._last_action = onnx_pred.copy()
+      data.ctrl[:] = onnx_pred * self._action_scale + self._default_angles
+      command = self._joystick.get_command()
+      cmd_magnitude = np.linalg.norm(command)
+      if cmd_magnitude < 0.01:
+        gait_freq = 1.25
+      else:
+        gait_freq = 1.25 + 0.5 * min(cmd_magnitude, 1.5) / 1.5
+      phase_dt = self._base_phase_dt * gait_freq
+      phase_tp1 = self._phase + phase_dt
+      self._phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
+
+
+def load_callback(model=None, data=None):
+  mujoco.set_mjcb_control(None)
+
+  model = mujoco.MjModel.from_xml_path(
+      apollo_constants.FEET_ONLY_FLAT_TERRAIN_XML.as_posix(),
+      assets=get_assets(),
+  )
+  data = 
mujoco.MjData(model) + + mujoco.mj_resetDataKeyframe(model, data, 0) + + ctrl_dt = 0.02 + sim_dt = 0.005 + n_substeps = int(round(ctrl_dt / sim_dt)) + model.opt.timestep = sim_dt + + policy = OnnxController( + policy_path=(_ONNX_DIR / "apollo_policy.onnx").as_posix(), + default_angles=np.array(model.keyframe("knees_bent").qpos[7:]), + ctrl_dt=ctrl_dt, + n_substeps=n_substeps, + action_scale=0.5, + vel_scale_x=1.5, + vel_scale_y=0.8, + vel_scale_rot=1.5, + ) + + mujoco.set_mjcb_control(policy.get_control) + + return model, data + + +if __name__ == "__main__": + viewer.launch(loader=load_callback) diff --git a/mujoco_playground/experimental/utils/plotting.py b/mujoco_playground/experimental/utils/plotting.py new file mode 100644 index 000000000..be7809abb --- /dev/null +++ b/mujoco_playground/experimental/utils/plotting.py @@ -0,0 +1,337 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from datetime import datetime +from typing import Any, Dict, List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from IPython.display import clear_output, display + + +class TrainingPlotter: + def __init__( + self, + max_timesteps: int = 50_000_000, + figsize: Tuple[int, int] = (12, 8), + max_cols: int = 3, + ): + self.max_timesteps = max_timesteps + self.max_cols = max_cols + + # Default main metrics that we always want to plot. 
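+        # Per-term reward/error/termination metrics are discovered from the
+        # first metrics dict in _initialize_metrics and plotted alongside.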
+ self.default_metrics = [ + "eval/episode_reward", + "eval/avg_episode_length", + "steps_per_second", + ] + + self.metrics = [] + self.metrics_std = [] + self.metric_labels = [] + self.error_metrics = [] + self.error_metrics_std = [] + self.error_metric_labels = [] + self.reward_detail_metrics = [] + self.reward_detail_metrics_std = [] + self.reward_detail_metric_labels = [] + self.termination_metrics = [] + self.termination_metrics_std = [] + self.termination_metric_labels = [] + + self.x_data = [] + self.times = [datetime.now()] + self.metrics_data = {} + self.metrics_std_data = {} + + self.fps_data = [] # Store calculated steps per second + + # Use default matplotlib style + plt.rcParams["axes.grid"] = False + plt.rcParams["axes.edgecolor"] = "#888888" + plt.rcParams["axes.linewidth"] = 0.8 + + # Create initial figure and axes - we'll resize this later + n_cols = min(self.max_cols, 1) # Start with at least 1 column, but respect max_cols + n_rows = 1 + self.fig, self.axes = plt.subplots(n_rows, n_cols, figsize=figsize) + self.axes = np.array([[self.axes]]) + + # Set figure background to white for clean look + self.fig.patch.set_facecolor("white") + + # Set up the layout with reasonable spacing + plt.tight_layout(pad=2.5, h_pad=1.5, w_pad=1.0) + + def _get_label_from_metric(self, metric: str) -> str: + parts = metric.split("/") + if len(parts) > 1: + label = parts[-1] + else: + label = metric + if label == "episode_reward": + return "reward_per_episode" + elif label == "avg_episode_length": + return "episode_length" + elif label == "steps_per_second": + return "steps_per_second" + else: + return label + + def _initialize_metrics(self, metrics: Dict[str, Any]) -> None: + """Initialize all metrics from the first metrics dictionary.""" + # Start with default metrics + self.metrics = self.default_metrics.copy() + self.metrics_std = [f"{m}_std" for m in self.metrics] + self.metric_labels = [self._get_label_from_metric(m) for m in self.metrics] + + # Initialize data storage for default metrics + for metric in self.metrics: + self.metrics_data[metric] = [] + for metric_std in self.metrics_std: + self.metrics_std_data[metric_std] = [] + + # Find all reward detail metrics (eval/episode_reward/*) + reward_prefix = "eval/episode_reward/" + for key in metrics: + if ( + key.startswith(reward_prefix) + and not key.endswith("_std") + and key != "eval/episode_reward" + ): + self.reward_detail_metrics.append(key) + self.reward_detail_metrics_std.append(f"{key}_std") + label = key[len(reward_prefix) :] + self.reward_detail_metric_labels.append(label) # Keep underscores + + # Initialize data storage + self.metrics_data[key] = [] + self.metrics_std_data[f"{key}_std"] = [] + + # Find all error metrics (eval/episode_error/*) + error_prefix = "eval/episode_error/" + for key in metrics: + if key.startswith(error_prefix) and not key.endswith("_std"): + self.error_metrics.append(key) + self.error_metrics_std.append(f"{key}_std") + label = key[len(error_prefix) :] + self.error_metric_labels.append(label) # Keep underscores + + # Initialize data storage + self.metrics_data[key] = [] + self.metrics_std_data[f"{key}_std"] = [] + + # Find all termination metrics (eval/episode_termination/*) + termination_prefix = "eval/episode_termination/" + for key in metrics: + if key.startswith(termination_prefix) and not key.endswith("_std"): + self.termination_metrics.append(key) + self.termination_metrics_std.append(f"{key}_std") + label = key[len(termination_prefix) :] + self.termination_metric_labels.append(label) # Keep 
underscores + + # Initialize data storage + self.metrics_data[key] = [] + self.metrics_std_data[f"{key}_std"] = [] + + def update(self, num_steps: int, metrics: Dict[str, float]) -> None: + self.x_data.append(num_steps) + current_time = datetime.now() + self.times.append(current_time) + + # Calculate steps per second if we have at least two data points + if len(self.x_data) > 1: + time_diff = (current_time - self.times[-2]).total_seconds() + steps_diff = self.x_data[-1] - self.x_data[-2] + if time_diff > 0: + fps = steps_diff / time_diff + else: + fps = 0 + self.fps_data.append(fps) + else: + self.fps_data.append(0) # First point has no previous data to compare + + # Initialize metrics if this is the first update. + if len(self.x_data) == 1: + self._initialize_metrics(metrics) + # Add fps to metrics data structure + self.metrics_data["steps_per_second"] = [] + self.metrics_std_data["steps_per_second_std"] = [] + + # Update all metrics data. + all_metrics = ( + self.metrics + + self.reward_detail_metrics + + self.error_metrics + + self.termination_metrics + ) + all_metrics_std = ( + self.metrics_std + + self.reward_detail_metrics_std + + self.error_metrics_std + + self.termination_metrics_std + ) + + for metric in all_metrics: + if metric == "steps_per_second": + self.metrics_data[metric].append(self.fps_data[-1]) + elif metric in metrics: + self.metrics_data[metric].append(metrics[metric]) + else: + last_value = self.metrics_data[metric][-1] if self.metrics_data[metric] else 0 + self.metrics_data[metric].append(last_value) + + for metric_std in all_metrics_std: + if metric_std in metrics: + self.metrics_std_data[metric_std].append(metrics[metric_std]) + else: + last_value = ( + self.metrics_std_data[metric_std][-1] + if self.metrics_std_data[metric_std] + else 0 + ) + self.metrics_std_data[metric_std].append(last_value) + + clear_output(wait=True) + + # Combine all metrics for plotting. + all_metrics = ( + self.metrics + + self.reward_detail_metrics + + self.error_metrics + + self.termination_metrics + ) + all_metrics_std = ( + self.metrics_std + + self.reward_detail_metrics_std + + self.error_metrics_std + + self.termination_metrics_std + ) + all_labels = ( + self.metric_labels + + self.reward_detail_metric_labels + + self.error_metric_labels + + self.termination_metric_labels + ) + + # Calculate grid dimensions using max_cols. + total_plots = len(all_metrics) + n_cols = min(self.max_cols, total_plots) # Use max_cols parameter. + n_rows = (total_plots + n_cols - 1) // n_cols # Ceiling division. + + # Check if we need to resize the axes grid + if n_rows > self.axes.shape[0] or n_cols > self.axes.shape[1]: + plt.close(self.fig) + # Calculate a better figure size based on the number of plots and columns + width = max(12, n_cols * 3.5) # 3.5 inches per column + height = max(8, n_rows * 2.5) # 2.5 inches per row + self.fig, self.axes = plt.subplots(n_rows, n_cols, figsize=(width, height)) + + # Handle case where there's only one plot. + if n_rows == 1 and n_cols == 1: + self.axes = np.array([[self.axes]]) + elif n_rows == 1: + self.axes = np.array([self.axes]) + elif n_cols == 1: + self.axes = np.array([[ax] for ax in self.axes]) + + # Plot all metrics + self._plot_metrics(all_metrics, all_metrics_std, all_labels, self.axes) + + # Add a single x-axis label at the bottom of the figure. + self.fig.text( + 0.5, 0.01, "# environment steps", ha="center", fontsize=12, fontweight="bold" + ) + + # Update layout and display. 
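+        # Re-run tight_layout on every update since the subplot grid may have
+        # been resized above when new metric panels appeared.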
+ self.fig.tight_layout(pad=2.5, h_pad=1.5, w_pad=1.0) + self.fig.subplots_adjust(bottom=0.08) + display(self.fig) + + def _plot_metrics( + self, + metrics_list: List[str], + metrics_std_list: List[str], + labels_list: List[str], + axes_grid: np.ndarray, + ) -> None: + """Plot a set of metrics on the given axes grid.""" + for i, (metric, metric_std, label) in enumerate( + zip(metrics_list, metrics_std_list, labels_list) + ): + row, col = i // axes_grid.shape[1], i % axes_grid.shape[1] + if row < axes_grid.shape[0] and col < axes_grid.shape[1]: + ax = axes_grid[row][col] + ax.clear() + ax.set_xlim([0, self.max_timesteps * 1.25]) + + # Remove x-axis labels from all subplots + ax.set_xlabel("") + + # Make tick labels smaller to save space + ax.tick_params(axis="both", which="major", labelsize=9) + + # Add subtle grid for better readability + ax.grid(True, linestyle="-", linewidth=0.5, alpha=0.2) + + # Clean background + ax.set_facecolor("white") + + # Format y-axis with fewer decimal places for cleaner look + ax.ticklabel_format(axis="y", style="plain", useOffset=False) + + y_values = self.metrics_data[metric] + yerr_values = ( + self.metrics_std_data[metric_std] + if metric_std in self.metrics_std_data + else None + ) + + if y_values: + # Add prefix based on metric type + prefix = "" + if "eval/episode_error/" in metric: + prefix = "error/" + elif "eval/episode_reward/" in metric: + prefix = "reward/" + elif "eval/episode_termination/" in metric: + prefix = "termination/" + + # Use smaller font for title to save space + ax.set_title( + f"{prefix}{label}: {y_values[-1]:.3f}", fontsize=10, fontweight="bold" + ) + + # Plot the line with improved styling + line = ax.errorbar( + self.x_data, + y_values, + yerr=yerr_values, + color="black", + linewidth=1.5, + elinewidth=0.7, + capsize=2, + ) + + # Add very subtle shading under the curve for better visibility + ax.fill_between(self.x_data, 0, y_values, alpha=0.05, color="black") + + # Hide unused subplots + for i in range(len(metrics_list), axes_grid.shape[0] * axes_grid.shape[1]): + row, col = i // axes_grid.shape[1], i % axes_grid.shape[1] + if row < axes_grid.shape[0] and col < axes_grid.shape[1]: + axes_grid[row][col].set_visible(False) + + def save_figure(self, filename: str) -> None: + self.fig.savefig(filename, dpi=300, bbox_inches="tight")
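As a quick sanity check of the new registration, a minimal sketch like the
following (not part of the diff; it assumes mujoco_playground is installed)
loads "ApolloJoystickFlatTerrain" from the registry and rolls it forward with
zero actions:

import jax
import jax.numpy as jp

from mujoco_playground import registry

# Load the environment registered in locomotion/__init__.py.
env = registry.load("ApolloJoystickFlatTerrain")

# Reset, then take a few control steps with zero actions.
state = jax.jit(env.reset)(jax.random.PRNGKey(0))
step = jax.jit(env.step)
for _ in range(10):
  state = step(state, jp.zeros(env.action_size))

print(state.obs["state"].shape, float(state.reward), float(state.done))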