diff --git a/documentation/source/log_msg.ipynb b/documentation/source/log_msg.ipynb
index 6922a816b..dc9949028 100644
--- a/documentation/source/log_msg.ipynb
+++ b/documentation/source/log_msg.ipynb
@@ -78,7 +78,7 @@
     "        super(MyMgr, self).__init__()\n",
     "        logger.info(\"Init MyMgr...\")\n",
     "    \n",
-    "    def after_step(self):\n",
+    "    def after_step(self, *args, **kwargs):\n",
     "        logger.info(\"Step {}...\".format(self.episode_step))\n",
     "        return dict()\n",
     "\n",
diff --git a/documentation/source/system_design.ipynb b/documentation/source/system_design.ipynb
index bb1569372..ba0d54fee 100644
--- a/documentation/source/system_design.ipynb
+++ b/documentation/source/system_design.ipynb
@@ -144,7 +144,7 @@
     "            self.generated_v.before_step([0.5, 0.4]) # set action\n",
     "    \n",
     "\n",
-    "    def after_step(self):\n",
+    "    def after_step(self, *args, **kwargs):\n",
     "        if self.episode_step == self.generate_ts:\n",
     "            self.generated_v = self.spawn_object(DefaultVehicle, \n",
     "                                                 vehicle_config=dict(), \n",
@@ -494,7 +494,7 @@
     "                              heading=0)\n",
     "        self.add_policy(obj.id, IDMPolicy, obj, self.generate_seed())\n",
     "\n",
-    "    def after_step(self):\n",
+    "    def after_step(self, *args, **kwargs):\n",
     "        for obj in self.spawned_objects.values():\n",
     "            obj.after_step()\n",
     "        if self.episode_step == 180:\n",
@@ -572,7 +572,7 @@
     "            self.generated_v.before_step([0.5, 0.4]) # set action\n",
     "    \n",
     "\n",
-    "    def after_step(self):\n",
+    "    def after_step(self, *args, **kwargs):\n",
     "        if self.episode_step == self.generate_ts:\n",
     "            self.generated_v = self.spawn_object(DefaultVehicle, \n",
     "                                                 vehicle_config=dict(), \n",
diff --git a/metadrive/envs/multigoal_intersection.py b/metadrive/envs/multigoal_intersection.py
index f94263e80..74254eb7d 100644
--- a/metadrive/envs/multigoal_intersection.py
+++ b/metadrive/envs/multigoal_intersection.py
@@ -236,7 +236,7 @@ def after_reset(self):
             navi.reset(self.agent, dest=self.goals[name])
             navi.update_localization(self.agent)
 
-    def after_step(self):
+    def after_step(self, *args, **kwargs):
         """Update all navigation modules."""
         # print("[DEBUG]: after_step in MultiGoalIntersectionNavigationManager")
         for name, navi in self.navigations.items():
diff --git a/metadrive/envs/scenario_env.py b/metadrive/envs/scenario_env.py
index 456ebd2fc..876524902 100644
--- a/metadrive/envs/scenario_env.py
+++ b/metadrive/envs/scenario_env.py
@@ -15,6 +15,7 @@
 from metadrive.manager.scenario_map_manager import ScenarioMapManager
 from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager
 from metadrive.policy.replay_policy import ReplayEgoCarPolicy
+from metadrive.policy.waypoint_policy import WaypointPolicy
 from metadrive.utils import get_np_random
 from metadrive.utils.math import wrap_to_pi
 
@@ -63,8 +64,8 @@
     ),
 
     # If set_static=True, then the agent will not "fall from the sky". This will be helpful if you want to
    # capture per-frame data for the agent (for example for collecting static sensor data).
-    # However, the physics engine will not update the position of the agent. So in the visualization, the image will be
-    # very chunky as the agent will not suddenly move to the next position for each step.
+    # However, the physics simulation of the agent will be disabled too. So in the visualization, the motion will
+    # look choppy as the agent suddenly jumps to the next position at each step.
     # Set to False for better visualization.
     set_static=False,
@@ -102,6 +103,16 @@
     use_bounding_box=False,  # Set True to use a cube in visualization to represent every dynamic objects.
 )
 
+SCENARIO_WAYPOINT_ENV_CONFIG = dict(
+    # How many waypoints will be used at each environment step. Check out ScenarioWaypointEnv for details.
+    waypoint_horizon=5,
+    agent_policy=WaypointPolicy,
+
+    # Must be set to True, otherwise the agent will drift away from the waypoints when
+    # "self.engine.step(self.config["decision_repeat"])" is called in "_step_simulator".
+    set_static=True,
+)
+
 
 class ScenarioEnv(BaseEnv):
     @classmethod
@@ -418,6 +429,40 @@ def set_scenario(self, scenario_data):
         self.engine.data_manager.set_scenario(scenario_data)
 
 
+class ScenarioWaypointEnv(ScenarioEnv):
+    """
+    This environment uses WaypointPolicy. Even though the environment still runs at 10 Hz, we allow an external
+    waypoint generator to produce up to 5 waypoints at each step (controlled by the config "waypoint_horizon").
+    Say at step t we receive 5 waypoints. Then we will set the agent states for t+1, t+2, t+3, t+4, t+5, provided
+    that no additional waypoints are received at t+1 ~ t+4. Here is the full timeline:
+
+    step t=0: env.reset(), initial positions/obs are sent out. This corresponds to t=0 or t=10 in the WOMD dataset
+    (TODO: we should allow control over the meaning of t=0).
+    step t=1: env.step(), the agent receives 5 waypoints. We record the waypoint sequence, set the agent state for
+    t=1, and send out the obs for t=1.
+    step t=2: env.step(), it is possible to get action=None, which means the agent uses the cached waypoint for
+    t=2 and sets the agent state for t=2. The obs for t=2 is sent out. If new waypoints are received, we will
+    instead set the agent state to the first new waypoint.
+    step t=3: ... the loop continues, receiving action=None or new waypoints.
+    step t=4: ...
+    step t=5: ...
+    step t=6: if we only received an action at t=1 and t=2 ~ t=5 were all None, then this step must receive
+    new waypoints. We will set the agent state to the first new waypoint.
+
+    Most of the functionality is implemented in WaypointPolicy.
+    """
+    @classmethod
+    def default_config(cls):
+        config = super(ScenarioWaypointEnv, cls).default_config()
+        config.update(SCENARIO_WAYPOINT_ENV_CONFIG)
+        return config
+
+    def _post_process_config(self, config):
+        ret = super(ScenarioWaypointEnv, self)._post_process_config(config)
+        assert config["set_static"], "Waypoint policy requires set_static=True"
+        return ret
+
+
 if __name__ == "__main__":
     env = ScenarioEnv(
         {
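Note: to make the timeline in the docstring above concrete, here is a minimal standalone sketch (illustrative only, not part of the patch) of how an episode step maps to an index into the cached waypoint batch. The names `waypoint_horizon` and `cache_last_update` mirror the policy's bookkeeping; everything else is made up.

```python
# Illustrative sketch of the waypoint-caching timeline, assuming waypoint_horizon=5.
waypoint_horizon = 5
cache_last_update = 0  # episode step at which the last waypoint batch arrived

for episode_step in range(1, 7):
    received_new_batch = episode_step in (1, 6)  # waypoints arrive at t=1 and t=6
    if received_new_batch:
        cache_last_update = episode_step
    cache_index = episode_step - cache_last_update
    # t=1 uses index 0 of the first batch, t=2..5 use indices 1..4 of the cache,
    # and t=6 must use index 0 of a fresh batch since the old cache is exhausted.
    assert cache_index < waypoint_horizon
    print(f"step {episode_step}: waypoint index {cache_index}")
```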
+""" +import argparse + +import numpy as np + +from metadrive.engine.asset_loader import AssetLoader +from metadrive.envs.scenario_env import ScenarioWaypointEnv + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--top_down", "--topdown", action="store_true") + parser.add_argument("--waymo", action="store_true") + args = parser.parse_args() + extra_args = dict(film_size=(2000, 2000)) if args.top_down else {} + asset_path = AssetLoader.asset_path + use_waymo = args.waymo + + waypoint_horizon = 5 + + cfg = { + "map_region_size": 1024, # use a large number if your map is toooooo big + "sequential_seed": True, + "use_render": False, + "data_directory": AssetLoader.file_path(asset_path, "waymo" if use_waymo else "nuscenes", unix_style=False), + "num_scenarios": 3 if use_waymo else 10, + "waypoint_horizon": waypoint_horizon, + } + + try: + env = ScenarioWaypointEnv(cfg) + o, _ = env.reset() + i = 0 + for _ in range(0, 100000): + if i % waypoint_horizon == 0: + # X-coordinate is the forward direction of the vehicle, Y-coordinate is the left of the vehicle + x_displacement = np.linspace(1, 6, waypoint_horizon) + y_displacement = np.linspace(0, 0.05, waypoint_horizon) # pos y is left. + action = dict(position=np.stack([x_displacement, y_displacement], axis=1)) + + else: + action = None + + o, r, tm, tc, info = env.step(actions=action) + env.render(mode="top_down") + i += 1 + if tm or tc: + env.reset() + i = 0 + finally: + env.close() diff --git a/metadrive/policy/README.md b/metadrive/policy/README.md index 996966cc7..31d2e6594 100644 --- a/metadrive/policy/README.md +++ b/metadrive/policy/README.md @@ -10,5 +10,7 @@ They are: * ReplayEgoCarPolicy(BasePolicy): Make the ego car replay the logged trajectory. +* WaypointPolicy(BasePolicy): A policy that follows the waypoint. + Change the `env_config["agent_policy"]` to `IDMPolicy|WaymoIDMPolicy|ReplayEgoCarPolicy` to let the ego car follow different policies. diff --git a/metadrive/policy/replay_policy.py b/metadrive/policy/replay_policy.py index f66b3097d..203e623d9 100644 --- a/metadrive/policy/replay_policy.py +++ b/metadrive/policy/replay_policy.py @@ -66,9 +66,8 @@ def act(self, *args, **kwargs): # If set_static, then the agent will not "fall from the sky". # However, the physics engine will not update the position of the agent. - # So in the visualization, the image will be very chunky as the agent will not suddenly move to the next + # So in the visualization, the image will be very chunky as the agent will suddenly move to the next # position for each step. - if self.engine.global_config.get("set_static", False): self.control_object.set_static(True) diff --git a/metadrive/policy/waypoint_policy.py b/metadrive/policy/waypoint_policy.py new file mode 100644 index 000000000..1cb1bdd80 --- /dev/null +++ b/metadrive/policy/waypoint_policy.py @@ -0,0 +1,98 @@ +import gymnasium as gym +import numpy as np + +from metadrive.policy.base_policy import BasePolicy +from metadrive.utils import waypoint_utils + + +class WaypointPolicy(BasePolicy): + """ + This policy will have the trajectory data being overwritten on the fly. 
+ """ + def __init__(self, obj, seed): + super(WaypointPolicy, self).__init__(control_object=obj, random_seed=seed) + self.horizon = self.engine.global_config.get("waypoint_horizon", 10) + self.cache = None + self.cache_last_update = 0 + + @classmethod + def get_input_space(cls): + from metadrive.engine.engine_utils import get_global_config + horizon = get_global_config().get("waypoint_horizon", 10) + return gym.spaces.Dict( + dict(position=gym.spaces.Box(float("-inf"), float("inf"), shape=(horizon, 2), dtype=np.float32), ) + ) + + def _convert_to_world_coordinates(self, waypoint_positions): + """ + This function is used to convert the waypoint positions from the local frame to the world frame + """ + obj_heading = np.array(self.control_object.heading_theta).reshape(1, ).repeat(waypoint_positions.shape[0]) + obj_position = np.array(self.control_object.position).reshape(1, 2) + rotated = waypoint_utils.rotate( + waypoint_positions[:, 0], + waypoint_positions[:, 1], + obj_heading, + ) + translated = rotated + obj_position + return translated + + def reset(self): + """ + Reset the policy + """ + self.cache = None + self.cache_last_update = 0 + super(WaypointPolicy, self).reset() + + def act(self, agent_id): + assert self.engine.external_actions is not None + actions = self.engine.external_actions[agent_id] + + if actions is not None: + + waypoint_positions = actions["position"] + assert waypoint_positions.ndim == 2 + assert waypoint_positions.shape[1] == 2 + + world_positions = self._convert_to_world_coordinates(waypoint_positions) + headings = np.array(waypoint_utils.reconstruct_heading(world_positions)) + + # dt should be 0.1s in default settings + dt = self.engine.global_config["physics_world_step_size"] * self.engine.global_config["decision_repeat"] + + angular_velocities = np.array(waypoint_utils.reconstruct_angular_velocity(headings, dt)) + velocities = np.array(waypoint_utils.reconstruct_velocity(world_positions, dt)) + + duration = len(waypoint_positions) + assert duration == self.horizon, "The length of the waypoint positions should be equal to the horizon: {} vs {}".format( + duration, self.horizon + ) + + self.cache = dict( + position=world_positions, + velocity=velocities, + heading=headings, + angular_velocity=angular_velocities, + ) + self.cache_last_update = self.engine.episode_step + + assert self.cache is not None + + cache_index = self.engine.episode_step - self.cache_last_update + assert cache_index < self.horizon, "Cache index out of range: {} vs {}".format(cache_index, self.horizon) + + self.control_object.set_position(self.cache["position"][cache_index]) + self.control_object.set_velocity(self.cache["velocity"][cache_index]) + self.control_object.set_heading_theta(self.cache["heading"][cache_index]) + self.control_object.set_angular_velocity(self.cache["angular_velocity"][cache_index]) + + # A legacy code to set the static mode of the agent + # If set_static, then the agent will not "fall from the sky". + # However, the physics simulation will not apply too to the agent. + # So in the visualization, the image will be very chunky as the agent will suddenly move to the next + # position for each step. 
diff --git a/metadrive/tests/test_policy/test_waypoint_policy.py b/metadrive/tests/test_policy/test_waypoint_policy.py
new file mode 100644
index 000000000..706371789
--- /dev/null
+++ b/metadrive/tests/test_policy/test_waypoint_policy.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+
+import numpy as np
+
+from metadrive.engine.asset_loader import AssetLoader
+from metadrive.envs.scenario_env import ScenarioWaypointEnv
+
+
+def test_waypoint_policy(render=False):
+    """
+    Test the waypoint policy by replaying logged ego trajectories in the ScenarioWaypointEnv.
+    """
+    asset_path = AssetLoader.asset_path
+    env = ScenarioWaypointEnv(
+        {
+            "sequential_seed": True,
+            "use_render": False,
+            "data_directory": AssetLoader.file_path(asset_path, "nuscenes", unix_style=False),
+            "num_scenarios": 10,
+        }
+    )
+    o, _ = env.reset()
+    seen_seeds = set()
+    seen_seeds.add(env.engine.data_manager.current_scenario["id"])
+
+    # Load the ground-truth ego trajectory of the current scenario.
+    scenario = env.engine.data_manager.current_scenario
+    ego_id = scenario["metadata"]["sdc_id"]
+    ego_track = scenario["tracks"][ego_id]
+    ego_traj = ego_track["state"]["position"][..., :2]
+
+    waypoint_horizon = env.engine.global_config["waypoint_horizon"]
+
+    ADEs, FDEs = [], []
+
+    episode_step = 0
+    replay_traj = []
+    for _ in range(1, 100000):
+
+        if episode_step % waypoint_horizon == 0:
+            # Prepare the action. X-coordinate is the forward direction of the vehicle,
+            # Y-coordinate points to the left of the vehicle. Since we start with step 0,
+            # the first waypoint should be at step 1. Thus we start from step 1.
+            local_traj = np.array(
+                [
+                    env.agent.convert_to_local_coordinates(point, env.agent.position)
+                    for point in ego_traj[episode_step + 1:episode_step + waypoint_horizon + 1]
+                ]
+            )
+            action = dict(position=local_traj)
+        else:
+            action = None
+        replay_traj.append(env.agent.position)  # Store the pre-step position
+        o, r, tm, tc, info = env.step(actions=action)
+        episode_step += 1
+        if render:
+            env.render(mode="top_down")
+        if tm or tc:
+            if episode_step > 100:
+                replay_traj = np.array(replay_traj)
+
+                # Align the shapes of the replayed and ground-truth trajectories.
+                if replay_traj.shape[0] > ego_traj.shape[0]:
+                    replay_traj = replay_traj[:ego_traj.shape[0]]
+                elif replay_traj.shape[0] < ego_traj.shape[0]:
+                    ego_traj = ego_traj[:replay_traj.shape[0]]
+
+                ade = np.mean(np.linalg.norm(replay_traj - ego_traj, axis=-1))
+                fde = np.linalg.norm(replay_traj[-1] - ego_traj[-1], axis=-1)
+                ADEs.append(ade)
+                FDEs.append(fde)
+                print(
+                    "For seed {}, horizon: {}, ADE: {}, FDE: {}".format(
+                        env.engine.data_manager.current_scenario["id"], episode_step, ade, fde
+                    )
+                )
+            else:
+                # An early-terminated episode. Skip it.
+                print(
+                    "Early terminated episode {}, horizon: {}".format(
+                        env.engine.data_manager.current_scenario["id"], episode_step
+                    )
+                )
+
+            episode_step = 0
+            env.reset()
+            replay_traj = []
+            if env.engine.data_manager.current_scenario["id"] in seen_seeds:
+                break
+            else:
+                seen_seeds.add(env.engine.data_manager.current_scenario["id"])
+            scenario = env.engine.data_manager.current_scenario
+            ego_id = scenario["metadata"]["sdc_id"]
+            ego_track = scenario["tracks"][ego_id]
+            ego_traj = ego_track["state"]["position"][..., :2]
+
+    mean_ade = np.mean(ADEs)
+    mean_fde = np.mean(FDEs)
+    print(f"Mean ADE: {mean_ade}, Mean FDE: {mean_fde}")
+    assert mean_ade < 1e-4 and mean_fde < 1e-4
+
+    env.close()
+
+
+if __name__ == '__main__':
+    test_waypoint_policy(render=True)
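Note: in the test above, ADE (average displacement error) is the mean Euclidean distance between the replayed and ground-truth positions over all steps, and FDE (final displacement error) is the distance at the last step. A tiny worked example with made-up trajectories:

```python
import numpy as np

ground_truth = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
replayed = np.array([[0.0, 0.1], [1.0, 0.1], [2.0, 0.1]])  # uniformly offset 0.1 m to the left

ade = np.mean(np.linalg.norm(replayed - ground_truth, axis=-1))  # mean per-step error
fde = np.linalg.norm(replayed[-1] - ground_truth[-1])            # error at the final step
print(ade, fde)  # both 0.1
```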
diff --git a/metadrive/utils/waypoint_utils.py b/metadrive/utils/waypoint_utils.py
new file mode 100644
index 000000000..b8fee1621
--- /dev/null
+++ b/metadrive/utils/waypoint_utils.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+
+def reconstruct_heading(waypoints):
+    """
+    Reconstruct the headings from the waypoints.
+    Returns the yaw angles (in world coordinates), positive when turning left and negative when turning right.
+    """
+    # Calculate the headings from consecutive waypoint differences
+    headings = np.arctan2(np.diff(waypoints[:, 1]), np.diff(waypoints[:, 0]))
+    # Repeat the last heading to match the length of waypoints
+    headings = np.append(headings, headings[-1])
+    return headings
+
+
+def reconstruct_angular_velocity(headings, time_interval):
+    """
+    Reconstruct the angular velocities from the headings and time interval.
+    Returns rad/s (in world coordinates), positive when turning left and negative when turning right.
+    """
+    # Calculate the angular velocities
+    angular_velocities = np.diff(headings) / time_interval
+    # Repeat the last angular velocity to match the length of headings
+    angular_velocities = np.append(angular_velocities, angular_velocities[-1])
+    return angular_velocities
+
+
+def reconstruct_velocity(waypoints, dt):
+    """
+    Reconstruct the velocities from the waypoints and time interval.
+    """
+    diff = np.diff(waypoints, axis=0)
+    velocities = diff / dt
+    # Repeat the last velocity to match the length of waypoints
+    velocities = np.append(velocities, velocities[-1].reshape(1, -1), axis=0)
+    return velocities
+
+
+def rotate(x, y, angle, z=None, assert_shape=True):
+    """
+    Rotate the coordinates (x, y) by the given angle in radians.
+    """
+    if assert_shape:
+        assert angle.shape == x.shape == y.shape, (angle.shape, x.shape, y.shape)
+        if z is not None:
+            assert x.shape == z.shape
+    other_x_trans = np.cos(angle) * x - np.sin(angle) * y
+    other_y_trans = np.cos(angle) * y + np.sin(angle) * x
+    if z is None:
+        output_coords = np.stack((other_x_trans, other_y_trans), axis=-1)
+    else:
+        output_coords = np.stack((other_x_trans, other_y_trans, z), axis=-1)
+    return output_coords
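Note: a quick sanity check for the helpers above: rotating by +angle and then by -angle should recover the original points, and headings reconstructed along a circular arc should advance smoothly. A hedged sketch, assuming the module is importable as shown in the diff:

```python
import numpy as np

from metadrive.utils import waypoint_utils

# Round trip: rotate points by +angle, then by -angle, and recover the originals.
x = np.array([1.0, 2.0, 3.0])
y = np.array([0.5, -0.5, 1.5])
angle = np.full_like(x, np.pi / 3)
forward = waypoint_utils.rotate(x, y, angle)                        # shape (3, 2)
back = waypoint_utils.rotate(forward[:, 0], forward[:, 1], -angle)  # undo the rotation
assert np.allclose(back, np.stack([x, y], axis=-1))

# Headings along a quarter-circle arc increase smoothly from ~pi/2 toward ~pi.
theta = np.linspace(0, np.pi / 2, 10)
arc = np.stack([np.cos(theta), np.sin(theta)], axis=-1)
print(waypoint_utils.reconstruct_heading(arc))
```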