Waypoint policy #834

Merged: 31 commits, Apr 14, 2025

Commits
d3af1f7
Implemented a variant of ReplayEgoCarPolicy called "WaypointPolicy", …
WeizhenWang-1210 Apr 8, 2025
15a35b7
Formatted the files
WeizhenWang-1210 Apr 8, 2025
635b5db
Update the after_step of MultiGoalIntersectionNavigationManager to ta…
WeizhenWang-1210 Apr 8, 2025
587b4a7
Merge remote-tracking branch 'origin/waypoint_policy' into waypoint_p…
WeizhenWang-1210 Apr 8, 2025
edf530e
Formatting
WeizhenWang-1210 Apr 8, 2025
5169539
Fix the argument passing by enforcing a method signature checker for …
WeizhenWang-1210 Apr 8, 2025
0b83245
Merge remote-tracking branch 'origin/waypoint_policy' into waypoint_p…
WeizhenWang-1210 Apr 8, 2025
c6a33cd
Formatting
WeizhenWang-1210 Apr 8, 2025
e1fba01
Revert certain changes
pengzhenghao Apr 10, 2025
0f3f13b
format
pengzhenghao Apr 10, 2025
6ca3ba7
Waypoint policy in ego coordinates implemented.
WeizhenWang-1210 Apr 11, 2025
7d2af92
Updated formulation detail
WeizhenWang-1210 Apr 11, 2025
01e4181
Test added
WeizhenWang-1210 Apr 11, 2025
21c9015
Remove uneeded file
WeizhenWang-1210 Apr 11, 2025
db702f6
Updated ExampleManager
WeizhenWang-1210 Apr 11, 2025
8ee8423
Formatted the files
WeizhenWang-1210 Apr 11, 2025
3b77b05
No cubic splint
WeizhenWang-1210 Apr 11, 2025
f0e006b
Merge remote-tracking branch 'origin/waypoint_policy' into waypoint_p…
WeizhenWang-1210 Apr 11, 2025
ab158eb
Formatted the files
WeizhenWang-1210 Apr 11, 2025
58d9c01
args, kwargs update
WeizhenWang-1210 Apr 11, 2025
e2786f5
args, kwargs update
WeizhenWang-1210 Apr 11, 2025
924413c
Maybe relax the mem leak a little bit?
WeizhenWang-1210 Apr 11, 2025
5b52cac
Remove not related code vectorized.py
pengzhenghao Apr 12, 2025
b78e5fb
Roll back a memory leak test
pengzhenghao Apr 12, 2025
5f35f49
Reformat the waypoint policy unit test, and IT DOESN'T TEST THE CLOSE…
pengzhenghao Apr 12, 2025
d60c9f1
Implement the WaypointPolicy
pengzhenghao Apr 13, 2025
1110a6d
Revert some changes
pengzhenghao Apr 13, 2025
e4b94ba
Finish the unit test for WaypointPolicy/Env
pengzhenghao Apr 13, 2025
8bc0da9
Add some TODO. Please fix @Weizhen
pengzhenghao Apr 13, 2025
1d504c4
Fix bugs
pengzhenghao Apr 14, 2025
5f7f8f8
rename and fix
pengzhenghao Apr 14, 2025
Files changed
2 changes: 1 addition & 1 deletion documentation/source/log_msg.ipynb
@@ -78,7 +78,7 @@
" super(MyMgr, self).__init__()\n",
" logger.info(\"Init MyMgr...\")\n",
" \n",
" def after_step(self):\n",
" def after_step(self, *args, **kwargs):\n",
" logger.info(\"Step {}...\".format(self.episode_step))\n",
" return dict()\n",
"\n",
6 changes: 3 additions & 3 deletions documentation/source/system_design.ipynb
@@ -144,7 +144,7 @@
" self.generated_v.before_step([0.5, 0.4]) # set action\n",
" \n",
"\n",
" def after_step(self):\n",
" def after_step(self, *args, **kwargs):\n",
" if self.episode_step == self.generate_ts:\n",
" self.generated_v = self.spawn_object(DefaultVehicle, \n",
" vehicle_config=dict(), \n",
@@ -494,7 +494,7 @@
" heading=0)\n",
" self.add_policy(obj.id, IDMPolicy, obj, self.generate_seed())\n",
"\n",
" def after_step(self):\n",
" def after_step(self, *args, **kwargs):\n",
" for obj in self.spawned_objects.values():\n",
" obj.after_step()\n",
" if self.episode_step == 180:\n",
@@ -572,7 +572,7 @@
" self.generated_v.before_step([0.5, 0.4]) # set action\n",
" \n",
"\n",
" def after_step(self):\n",
" def after_step(self, *args, **kwargs):\n",
" if self.episode_step == self.generate_ts:\n",
" self.generated_v = self.spawn_object(DefaultVehicle, \n",
" vehicle_config=dict(), \n",
2 changes: 1 addition & 1 deletion metadrive/envs/multigoal_intersection.py
@@ -236,7 +236,7 @@ def after_reset(self):
navi.reset(self.agent, dest=self.goals[name])
navi.update_localization(self.agent)

def after_step(self):
def after_step(self, *args, **kwargs):
"""Update all navigation modules."""
# print("[DEBUG]: after_step in MultiGoalIntersectionNavigationManager")
for name, navi in self.navigations.items():
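Note: the `after_step(self, *args, **kwargs)` change above recurs across the notebook snippets and managers in this PR; per commit 5169539, the engine now enforces a signature check so that manager hooks tolerate forwarded arguments. A minimal sketch of the convention (the `MyManager` class is illustrative, not part of the PR):

```python
class MyManager:
    """Sketch of the after_step convention these diffs converge on."""
    def __init__(self):
        self.episode_step = 0

    def after_step(self, *args, **kwargs):
        # The engine may forward extra positional/keyword arguments to every
        # manager's after_step; accepting *args/**kwargs keeps custom managers
        # compatible with the new calling convention.
        self.episode_step += 1
        return dict()
```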
49 changes: 47 additions & 2 deletions metadrive/envs/scenario_env.py
@@ -15,6 +15,7 @@
from metadrive.manager.scenario_map_manager import ScenarioMapManager
from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.policy.waypoint_policy import WaypointPolicy
from metadrive.utils import get_np_random
from metadrive.utils.math import wrap_to_pi

@@ -63,8 +64,8 @@
),
# If set_static=True, then the agent will not "fall from the sky". This will be helpful if you want to
# capture per-frame data for the agent (for example for collecting static sensor data).
# However, the physics engine will not update the position of the agent. So in the visualization, the image will be
# very chunky as the agent will not suddenly move to the next position for each step.
# However, the physics simulation of the agent will be disabled too. So in the visualization, the image will be
# very chunky as the agent will suddenly move to the next position for each step.
# Set to False for better visualization.
set_static=False,

@@ -102,6 +103,16 @@
use_bounding_box=False, # Set True to use a cube in visualization to represent every dynamic objects.
)

SCENARIO_WAYPOINT_ENV_CONFIG = dict(
# How many waypoints will be used at each environment step. Check out ScenarioWaypointEnv for details.
waypoint_horizon=5,
agent_policy=WaypointPolicy,

# Must set this to True, otherwise the agent will drift away from the waypoint when doing
# "self.engine.step(self.config["decision_repeat"])" in "_step_simulator".
set_static=True,
)


class ScenarioEnv(BaseEnv):
@classmethod
@@ -418,6 +429,40 @@ def set_scenario(self, scenario_data):
self.engine.data_manager.set_scenario(scenario_data)


class ScenarioWaypointEnv(ScenarioEnv):
"""
This environment uses WaypointPolicy. Even though the environment still runs at 10 Hz, we allow the external
waypoint generator to generate up to 5 waypoints at each step (controlled by the config "waypoint_horizon").
Say at step t we receive 5 waypoints. Then we will set the agent states for t+1, t+2, t+3, t+4, t+5 if
no additional waypoints are received at t+1 ~ t+4. Here is the full timeline:

step t=0: env.reset(), initial positions/obs are sent out. This corresponds to t=0 or t=10 in the WOMD dataset
(TODO: we should allow control over the meaning of t=0)
step t=1: env.step(), the agent receives 5 waypoints; we record the waypoint sequence, set the agent state for
t=1, and send out the obs for t=1.
step t=2: env.step(), it's possible to get action=None, which means the agent will use the cached waypoint for
t=2, and we set the agent state for t=2. The obs for t=2 will be sent out. If new waypoints are received,
we will instead set the agent state to the first new waypoint.
step t=3: ... the loop continues, receiving action=None or new waypoints.
step t=4: ...
step t=5: ...
step t=6: if we only received waypoints at t=1 and t=2 ~ t=5 were all None, then this step forces new
waypoints to be received. We will set the agent state to the first new waypoint.

Most of the functions are implemented in WaypointPolicy.
"""
@classmethod
def default_config(cls):
config = super(ScenarioWaypointEnv, cls).default_config()
config.update(SCENARIO_WAYPOINT_ENV_CONFIG)
return config

def _post_process_config(self, config):
ret = super(ScenarioWaypointEnv, self)._post_process_config(config)
assert config["set_static"], "Waypoint policy requires set_static=True"
return ret


if __name__ == "__main__":
env = ScenarioEnv(
{
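To make the timeline in the `ScenarioWaypointEnv` docstring concrete, here is a minimal driving loop; a sketch assuming the bundled nuScenes assets, with a straight-line placeholder trajectory rather than a meaningful plan:

```python
import numpy as np

from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioWaypointEnv

env = ScenarioWaypointEnv({
    "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "nuscenes", unix_style=False),
    "num_scenarios": 1,
    "waypoint_horizon": 5,
})
obs, info = env.reset()  # t=0: initial observation is sent out
for t in range(1, 11):
    if (t - 1) % 5 == 0:
        # t=1, 6, ...: feed a fresh (5, 2) ego-frame waypoint array (x forward, y left).
        xs = np.linspace(1, 5, 5)  # 1 m per step straight ahead (placeholder path)
        ys = np.zeros(5)
        action = dict(position=np.stack([xs, ys], axis=1))
    else:
        action = None  # t=2..5: consume the waypoints cached at the last update
    obs, r, tm, tc, info = env.step(actions=action)
env.close()
```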
54 changes: 54 additions & 0 deletions metadrive/examples/run_waypoint_policy.py
@@ -0,0 +1,54 @@
"""
This script demonstrates how to use the Waypoint Policy, which feeds a (5, 2) array, i.e. 5 waypoints, to the ego agent.
The waypoints are in the local frame of the vehicle, where the x-axis points along the forward direction of the
vehicle and the y-axis points to the left of the vehicle.
"""
import argparse

import numpy as np

from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioWaypointEnv

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--top_down", "--topdown", action="store_true")
parser.add_argument("--waymo", action="store_true")
args = parser.parse_args()
extra_args = dict(film_size=(2000, 2000)) if args.top_down else {}
asset_path = AssetLoader.asset_path
use_waymo = args.waymo

waypoint_horizon = 5

cfg = {
"map_region_size": 1024, # use a large number if your map is toooooo big
"sequential_seed": True,
"use_render": False,
"data_directory": AssetLoader.file_path(asset_path, "waymo" if use_waymo else "nuscenes", unix_style=False),
"num_scenarios": 3 if use_waymo else 10,
"waypoint_horizon": waypoint_horizon,
}

try:
env = ScenarioWaypointEnv(cfg)
o, _ = env.reset()
i = 0
for _ in range(0, 100000):
if i % waypoint_horizon == 0:
# X-coordinate is the forward direction of the vehicle, Y-coordinate is the left of the vehicle
x_displacement = np.linspace(1, 6, waypoint_horizon)
y_displacement = np.linspace(0, 0.05, waypoint_horizon) # pos y is left.
action = dict(position=np.stack([x_displacement, y_displacement], axis=1))

else:
action = None

o, r, tm, tc, info = env.step(actions=action)
env.render(mode="top_down")
i += 1
if tm or tc:
env.reset()
i = 0
finally:
env.close()
2 changes: 2 additions & 0 deletions metadrive/policy/README.md
@@ -10,5 +10,7 @@ They are:

* ReplayEgoCarPolicy(BasePolicy): Make the ego car replay the logged trajectory.

* WaypointPolicy(BasePolicy): Make the ego car follow externally provided waypoints.

Change the `env_config["agent_policy"]` to `IDMPolicy|WaymoIDMPolicy|ReplayEgoCarPolicy|WaypointPolicy` to let the ego car follow different policies.
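For reference, switching the ego policy is a one-line config change; a minimal sketch, assuming the standard `MetaDriveEnv` entry point (for `WaypointPolicy` specifically, the new `ScenarioWaypointEnv` already wires this up together with `set_static=True`):

```python
from metadrive.envs import MetaDriveEnv
from metadrive.policy.idm_policy import IDMPolicy

# Hand control of the ego car to IDMPolicy instead of the default controls.
env = MetaDriveEnv(dict(agent_policy=IDMPolicy))
obs, info = env.reset()
# The sampled action is ignored; IDMPolicy drives the ego car.
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```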

3 changes: 1 addition & 2 deletions metadrive/policy/replay_policy.py
@@ -66,9 +66,8 @@ def act(self, *args, **kwargs):

# If set_static, then the agent will not "fall from the sky".
# However, the physics engine will not update the position of the agent.
# So in the visualization, the image will be very chunky as the agent will not suddenly move to the next
# So in the visualization, the image will be very chunky as the agent will suddenly move to the next
# position for each step.

if self.engine.global_config.get("set_static", False):
self.control_object.set_static(True)

98 changes: 98 additions & 0 deletions metadrive/policy/waypoint_policy.py
@@ -0,0 +1,98 @@
import gymnasium as gym
import numpy as np

from metadrive.policy.base_policy import BasePolicy
from metadrive.utils import waypoint_utils


class WaypointPolicy(BasePolicy):
"""
This policy allows the trajectory data to be overwritten on the fly.
"""
def __init__(self, obj, seed):
super(WaypointPolicy, self).__init__(control_object=obj, random_seed=seed)
self.horizon = self.engine.global_config.get("waypoint_horizon", 10)
self.cache = None
self.cache_last_update = 0

@classmethod
def get_input_space(cls):
from metadrive.engine.engine_utils import get_global_config
horizon = get_global_config().get("waypoint_horizon", 10)
return gym.spaces.Dict(
dict(position=gym.spaces.Box(float("-inf"), float("inf"), shape=(horizon, 2), dtype=np.float32), )
)

def _convert_to_world_coordinates(self, waypoint_positions):
"""
This function is used to convert the waypoint positions from the local frame to the world frame
"""
obj_heading = np.array(self.control_object.heading_theta).reshape(1, ).repeat(waypoint_positions.shape[0])
obj_position = np.array(self.control_object.position).reshape(1, 2)
rotated = waypoint_utils.rotate(
waypoint_positions[:, 0],
waypoint_positions[:, 1],
obj_heading,
)
translated = rotated + obj_position
return translated

def reset(self):
"""
Reset the policy
"""
self.cache = None
self.cache_last_update = 0
super(WaypointPolicy, self).reset()

def act(self, agent_id):
assert self.engine.external_actions is not None
actions = self.engine.external_actions[agent_id]

if actions is not None:

waypoint_positions = actions["position"]
assert waypoint_positions.ndim == 2
assert waypoint_positions.shape[1] == 2

world_positions = self._convert_to_world_coordinates(waypoint_positions)
headings = np.array(waypoint_utils.reconstruct_heading(world_positions))

# dt should be 0.1s in default settings
dt = self.engine.global_config["physics_world_step_size"] * self.engine.global_config["decision_repeat"]
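# For example, with MetaDrive's default physics_world_step_size=0.02 s and
# decision_repeat=5, dt = 0.02 * 5 = 0.1 s between consecutive waypoints.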

angular_velocities = np.array(waypoint_utils.reconstruct_angular_velocity(headings, dt))
velocities = np.array(waypoint_utils.reconstruct_velocity(world_positions, dt))

duration = len(waypoint_positions)
assert duration == self.horizon, "The length of the waypoint positions should be equal to the horizon: {} vs {}".format(
duration, self.horizon
)

self.cache = dict(
position=world_positions,
velocity=velocities,
heading=headings,
angular_velocity=angular_velocities,
)
self.cache_last_update = self.engine.episode_step

assert self.cache is not None

cache_index = self.engine.episode_step - self.cache_last_update
assert cache_index < self.horizon, "Cache index out of range: {} vs {}".format(cache_index, self.horizon)

self.control_object.set_position(self.cache["position"][cache_index])
self.control_object.set_velocity(self.cache["velocity"][cache_index])
self.control_object.set_heading_theta(self.cache["heading"][cache_index])
self.control_object.set_angular_velocity(self.cache["angular_velocity"][cache_index])

# Legacy code to set the static mode of the agent.
# If set_static, then the agent will not "fall from the sky".
# However, the physics simulation will not be applied to the agent either.
# So in the visualization, the image will be very chunky as the agent will suddenly move to the next
# position for each step.
if self.engine.global_config.get("set_static", False):
self.control_object.set_static(True)

return None # Return None action so the base vehicle will not overwrite the steering & throttle
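The local-to-world conversion in `_convert_to_world_coordinates` is a standard 2D rotation by the ego heading followed by a translation to the ego position. A self-contained sketch of the same math using plain numpy (the `local_to_world` helper is illustrative, not part of the PR):

```python
import numpy as np

def local_to_world(waypoints_local, ego_xy, ego_heading):
    """Rotate ego-frame waypoints (x forward, y left) counterclockwise by the
    ego heading, then translate by the ego position."""
    c, s = np.cos(ego_heading), np.sin(ego_heading)
    rot = np.array([[c, -s], [s, c]])  # CCW rotation matrix
    return waypoints_local @ rot.T + np.asarray(ego_xy)

# A waypoint 2 m ahead of an ego at (10, 5) facing +y (heading pi/2)
# lands at (10, 7) in world coordinates.
print(local_to_world(np.array([[2.0, 0.0]]), (10.0, 5.0), np.pi / 2))
```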
109 changes: 109 additions & 0 deletions metadrive/tests/test_policy/test_waypoint_policy.py
@@ -0,0 +1,109 @@
#!/usr/bin/env python

import numpy as np

from metadrive.constants import HELP_MESSAGE
from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioWaypointEnv


def test_waypoint_policy(render=False):
"""
Test the waypoint policy by running scenarios in the ScenarioWaypointEnv.
"""
asset_path = AssetLoader.asset_path
env = ScenarioWaypointEnv(
{
"sequential_seed": True,
"use_render": False, # True if not args.top_down else False,
"data_directory": AssetLoader.file_path(asset_path, "nuscenes", unix_style=False),
"num_scenarios": 10,
}
)
o, _ = env.reset()
seen_seeds = set()
seen_seeds.add(env.engine.data_manager.current_scenario["id"])
# Load the information
scenario = env.engine.data_manager.current_scenario
ego_id = scenario["metadata"]["sdc_id"]
ego_track = scenario["tracks"][ego_id]
ego_traj = ego_track["state"]["position"][..., :2]
# ego_traj_ego = np.array([env.agent.convert_to_local_coordinates(point, env.agent.position) for point in ego_traj])

waypoint_horizon = env.engine.global_config["waypoint_horizon"]

ADEs, FDEs, flag = [], [], True

episode_step = 0
replay_traj = []
for _ in range(1, 100000):

if episode_step % waypoint_horizon == 0:
# prepare action:
# X-coordinate is the forward direction of the vehicle, Y-coordinate is the left of the vehicle
# Since we start with step 0, the first waypoint should be at step 1. Thus we start from step 1.
local_traj = np.array(
[
env.agent.convert_to_local_coordinates(point, env.agent.position)
for point in ego_traj[episode_step + 1:episode_step + waypoint_horizon + 1]
]
)
action = dict(position=local_traj)
else:
action = None
replay_traj.append(env.agent.position) # Store the pre-step position
o, r, tm, tc, info = env.step(actions=action)
episode_step += 1
if render:
env.render(mode="top_down")
if tm or tc:
if episode_step > 100:
replay_traj = np.array(replay_traj)

# Align their shape
if replay_traj.shape[0] > ego_traj.shape[0]:
replay_traj = replay_traj[:ego_traj.shape[0]]
elif replay_traj.shape[0] < ego_traj.shape[0]:
ego_traj = ego_traj[:replay_traj.shape[0]]

ade = np.mean(np.linalg.norm(replay_traj - ego_traj, axis=-1))
fde = np.linalg.norm(replay_traj[-1] - ego_traj[-1], axis=-1)
ADEs.append(ade)
FDEs.append(fde)
print(
"For seed {}, horizon: {}, ADE: {}, FDE: {}".format(
env.engine.data_manager.current_scenario["id"], episode_step, ade, fde
)
)
else:
# An early terminated episode. Skip.
print(
"Early terminated episode {}, horizon: {}".format(
env.engine.data_manager.current_scenario["id"], episode_step
)
)
pass

episode_step = 0
env.reset()
replay_traj = []
if env.engine.data_manager.current_scenario["id"] in seen_seeds:
break
else:
seen_seeds.add(env.engine.data_manager.current_scenario["id"])
scenario = env.engine.data_manager.current_scenario
ego_id = scenario["metadata"]["sdc_id"]
ego_track = scenario["tracks"][ego_id]
ego_traj = ego_track["state"]["position"][..., :2]

print(f"Mean ADE: {np.mean(ADEs)}, Mean FDE: {np.mean(FDEs)}")

mean_ade = np.mean(ADEs)
mean_fde = np.mean(FDEs)
assert mean_ade < 1e-4 and mean_fde < 1e-4

env.close()


if __name__ == '__main__':
test_waypoint_policy(render=True)
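The pass criterion above uses the standard ADE/FDE metrics from trajectory forecasting. A minimal sketch of the two quantities, assuming aligned (T, 2) trajectories:

```python
import numpy as np

def ade_fde(pred, gt):
    """Average and Final Displacement Error for aligned (T, 2) trajectories."""
    dists = np.linalg.norm(pred - gt, axis=-1)  # per-step Euclidean error
    return dists.mean(), dists[-1]

pred = np.array([[0.0, 0.0], [1.0, 0.1], [2.0, 0.2]])
gt = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
ade, fde = ade_fde(pred, gt)  # ade = (0 + 0.1 + 0.2) / 3, fde = 0.2
```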
53 changes: 53 additions & 0 deletions metadrive/utils/waypoint_utils.py
@@ -0,0 +1,53 @@
import numpy as np


def reconstruct_heading(waypoints):
"""
Reconstructs the headings based on the waypoints.
Returns the yaw angles (in world coordinates), with positive values turning left and negative values turning right.
"""
# Calculate the headings based on the waypoints
headings = np.arctan2(np.diff(waypoints[:, 1]), np.diff(waypoints[:, 0]))
# Append the last heading to match the length of waypoints
headings = np.append(headings, headings[-1])
return headings


def reconstruct_angular_velocity(headings, time_interval):
"""
Reconstructs the angular velocities based on the headings and time interval.
Returns values in rad/s (in world coordinates), with positive values turning left and negative values turning right.
"""
# Calculate the angular velocities
angular_velocities = np.diff(headings) / time_interval
# Append the last angular velocity to match the length of headings
angular_velocities = np.append(angular_velocities, angular_velocities[-1])
return angular_velocities


def reconstruct_velocity(waypoints, dt):
"""
Reconstructs the velocities based on the waypoints and time interval.
"""
diff = np.diff(waypoints, axis=0)
velocities = diff / dt
# Append the last velocity to match the length of waypoints
velocities = np.append(velocities, velocities[-1].reshape(1, -1), axis=0)
return velocities


def rotate(x, y, angle, z=None, assert_shape=True):
"""
Rotate the coordinates (x, y) by the given angle in radians.
"""
if assert_shape:
assert angle.shape == x.shape == y.shape, (angle.shape, x.shape, y.shape)
if z is not None:
assert x.shape == z.shape
other_x_trans = np.cos(angle) * x - np.sin(angle) * y
other_y_trans = np.cos(angle) * y + np.sin(angle) * x
if z is None:
output_coords = np.stack((other_x_trans, other_y_trans), axis=-1)
else:
output_coords = np.stack((other_x_trans, other_y_trans, z), axis=-1)
return output_coords
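A quick sanity check of the reconstruction helpers on a toy straight-line trajectory, assuming dt = 0.1 s as in the policy:

```python
import numpy as np
from metadrive.utils import waypoint_utils

# Four waypoints advancing 1 m per step along +x: headings should be 0 rad,
# velocities (10, 0) m/s at dt = 0.1 s, and angular velocities 0 rad/s.
wps = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
headings = waypoint_utils.reconstruct_heading(wps)                    # 4 zeros
ang_vel = waypoint_utils.reconstruct_angular_velocity(headings, 0.1)  # 4 zeros
vel = waypoint_utils.reconstruct_velocity(wps, 0.1)                   # [[10, 0]] x 4
```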