Skip to content

Commit e8d94c9

Browse files
Waypoint policy (#834)
* Implemented a variant of ReplayEgoCarPolicy called "WaypointPolicy", which allows on-line modification of position, heading, velocity, and angular velocity of the ego car. * Formatted the files * Update the after_step of MultiGoalIntersectionNavigationManager to take args and kwargs (without any actual effect though) * Formatting * Fix the argument passing by enforcing a method signature checker for BaseEngine.after_step(). Now, only managers whose "after_step()" takes args and kwargs will have those passed into. * Formatting * Revert certain changes * format * Waypoint policy in ego coordinates implemented. * Updated formulation detail * Test added * Remove unneeded file * Updated ExampleManager * Formatted the files * No cubic spline * Formatted the files * args, kwargs update * args, kwargs update * Maybe relax the mem leak a little bit? * Remove unrelated code in vectorized.py * Roll back a memory leak test * Reformat the waypoint policy unit test, and IT DOESN'T TEST THE CLOSED-LOOP YET! * Implement the WaypointPolicy * Revert some changes * Finish the unit test for WaypointPolicy/Env * Add some TODO. Please fix @Weizhen * Fix bugs * rename and fix --------- Co-authored-by: pengzhenghao <[email protected]>
1 parent 4c8daad commit e8d94c9

File tree

10 files changed

+369
-9
lines changed

10 files changed

+369
-9
lines changed

documentation/source/log_msg.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
" super(MyMgr, self).__init__()\n",
7979
" logger.info(\"Init MyMgr...\")\n",
8080
" \n",
81-
" def after_step(self):\n",
81+
" def after_step(self, *args, **kwargs):\n",
8282
" logger.info(\"Step {}...\".format(self.episode_step))\n",
8383
" return dict()\n",
8484
"\n",

documentation/source/system_design.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@
144144
" self.generated_v.before_step([0.5, 0.4]) # set action\n",
145145
" \n",
146146
"\n",
147-
" def after_step(self):\n",
147+
" def after_step(self, *args, **kwargs):\n",
148148
" if self.episode_step == self.generate_ts:\n",
149149
" self.generated_v = self.spawn_object(DefaultVehicle, \n",
150150
" vehicle_config=dict(), \n",
@@ -494,7 +494,7 @@
494494
" heading=0)\n",
495495
" self.add_policy(obj.id, IDMPolicy, obj, self.generate_seed())\n",
496496
"\n",
497-
" def after_step(self):\n",
497+
" def after_step(self, *args, **kwargs):\n",
498498
" for obj in self.spawned_objects.values():\n",
499499
" obj.after_step()\n",
500500
" if self.episode_step == 180:\n",
@@ -572,7 +572,7 @@
572572
" self.generated_v.before_step([0.5, 0.4]) # set action\n",
573573
" \n",
574574
"\n",
575-
" def after_step(self):\n",
575+
" def after_step(self, *args, **kwargs):\n",
576576
" if self.episode_step == self.generate_ts:\n",
577577
" self.generated_v = self.spawn_object(DefaultVehicle, \n",
578578
" vehicle_config=dict(), \n",

metadrive/envs/multigoal_intersection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def after_reset(self):
236236
navi.reset(self.agent, dest=self.goals[name])
237237
navi.update_localization(self.agent)
238238

239-
def after_step(self):
239+
def after_step(self, *args, **kwargs):
240240
"""Update all navigation modules."""
241241
# print("[DEBUG]: after_step in MultiGoalIntersectionNavigationManager")
242242
for name, navi in self.navigations.items():

metadrive/envs/scenario_env.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from metadrive.manager.scenario_map_manager import ScenarioMapManager
1616
from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager
1717
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
18+
from metadrive.policy.waypoint_policy import WaypointPolicy
1819
from metadrive.utils import get_np_random
1920
from metadrive.utils.math import wrap_to_pi
2021

@@ -63,8 +64,8 @@
6364
),
6465
# If set_static=True, then the agent will not "fall from the sky". This will be helpful if you want to
6566
# capture per-frame data for the agent (for example for collecting static sensor data).
66-
# However, the physics engine will not update the position of the agent. So in the visualization, the image will be
67-
# very chunky as the agent will not suddenly move to the next position for each step.
67+
# However, the physics simulation of the agent will be disabled too. So in the visualization, the image will be
68+
# very chunky as the agent will suddenly move to the next position for each step.
6869
# Set to False for better visualization.
6970
set_static=False,
7071

@@ -102,6 +103,16 @@
102103
use_bounding_box=False, # Set True to use a cube in visualization to represent every dynamic objects.
103104
)
104105

106+
# Config preset applied on top of ScenarioEnv's defaults when using waypoint control.
SCENARIO_WAYPOINT_ENV_CONFIG = dict(
    # Number of waypoints consumed at each environmental step. Checkout ScenarioWaypointEnv for details.
    waypoint_horizon=5,
    agent_policy=WaypointPolicy,

    # This flag is mandatory: without it the physics engine keeps integrating the agent during
    # "self.engine.step(self.config["decision_repeat"])" in "_step_simulator", so the vehicle
    # would drift away from the commanded waypoint.
    set_static=True,
)
115+
105116

106117
class ScenarioEnv(BaseEnv):
107118
@classmethod
@@ -418,6 +429,40 @@ def set_scenario(self, scenario_data):
418429
self.engine.data_manager.set_scenario(scenario_data)
419430

420431

432+
class ScenarioWaypointEnv(ScenarioEnv):
    """
    A ScenarioEnv variant controlled by WaypointPolicy. Even though the environment still runs at 10 Hz,
    an external waypoint generator may submit up to 5 waypoints per step (the exact number is the config
    "waypoint_horizon").

    Say at step t we receive 5 waypoints. They will determine the agent states for t+1, t+2, t+3, t+4, t+5,
    provided no additional waypoints arrive during t+1 ~ t+4. The full timeline:

    step t=0: env.reset(), initial positions/obs are sent out. This corresponds to the t=0 or t=10 in WOMD
        dataset (TODO: we should allow control on the meaning of the t=0).
    step t=1: env.step(), agent receives 5 waypoints; we record the waypoint sequence, set the agent state
        for t=1, and send out the obs for t=1.
    step t=2: env.step(); it is possible to get action=None, meaning the agent uses the cached waypoint for
        t=2 and the agent state for t=2 is set from it. The obs for t=2 is sent out. If new waypoints are
        received instead, the agent state is set to the first new waypoint.
    step t=3: ... the loop continues, receiving either action=None or new waypoints.
    step t=4: ...
    step t=5: ...
    step t=6: if the only action arrived at t=1 and t=2 ~ t=5 were all None, the cache is exhausted, so this
        step is forced to receive new waypoints; the agent state is set to their first entry.

    Most of the functionality is implemented in WaypointPolicy.
    """
    @classmethod
    def default_config(cls):
        """Return ScenarioEnv's default config overlaid with the waypoint-specific settings."""
        base = super().default_config()
        base.update(SCENARIO_WAYPOINT_ENV_CONFIG)
        return base

    def _post_process_config(self, config):
        """Run the parent post-processing, then enforce the static-agent requirement."""
        processed = super()._post_process_config(config)
        # WaypointPolicy teleports the agent each step; the physics engine must not move it as well.
        assert config["set_static"], "Waypoint policy requires set_static=True"
        return processed
464+
465+
421466
if __name__ == "__main__":
422467
env = ScenarioEnv(
423468
{
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
This script demonstrates how to use the Waypoint Policy, which feeds (5, 2), that is 5 waypoints, to the ego agent.
3+
The waypoint is in the local frame of the vehicle, where the x-axis is the forward direction of the vehicle and
4+
the y-axis is the left direction of the vehicle.
5+
"""
6+
import argparse
7+
8+
import numpy as np
9+
10+
from metadrive.engine.asset_loader import AssetLoader
11+
from metadrive.envs.scenario_env import ScenarioWaypointEnv
12+
13+
if __name__ == "__main__":
14+
parser = argparse.ArgumentParser()
15+
parser.add_argument("--top_down", "--topdown", action="store_true")
16+
parser.add_argument("--waymo", action="store_true")
17+
args = parser.parse_args()
18+
extra_args = dict(film_size=(2000, 2000)) if args.top_down else {}
19+
asset_path = AssetLoader.asset_path
20+
use_waymo = args.waymo
21+
22+
waypoint_horizon = 5
23+
24+
cfg = {
25+
"map_region_size": 1024, # use a large number if your map is toooooo big
26+
"sequential_seed": True,
27+
"use_render": False,
28+
"data_directory": AssetLoader.file_path(asset_path, "waymo" if use_waymo else "nuscenes", unix_style=False),
29+
"num_scenarios": 3 if use_waymo else 10,
30+
"waypoint_horizon": waypoint_horizon,
31+
}
32+
33+
try:
34+
env = ScenarioWaypointEnv(cfg)
35+
o, _ = env.reset()
36+
i = 0
37+
for _ in range(0, 100000):
38+
if i % waypoint_horizon == 0:
39+
# X-coordinate is the forward direction of the vehicle, Y-coordinate is the left of the vehicle
40+
x_displacement = np.linspace(1, 6, waypoint_horizon)
41+
y_displacement = np.linspace(0, 0.05, waypoint_horizon) # pos y is left.
42+
action = dict(position=np.stack([x_displacement, y_displacement], axis=1))
43+
44+
else:
45+
action = None
46+
47+
o, r, tm, tc, info = env.step(actions=action)
48+
env.render(mode="top_down")
49+
i += 1
50+
if tm or tc:
51+
env.reset()
52+
i = 0
53+
finally:
54+
env.close()

metadrive/policy/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,7 @@ They are:
1010

1111
* ReplayEgoCarPolicy(BasePolicy): Make the ego car replay the logged trajectory.
1212

13+
* WaypointPolicy(BasePolicy): A policy that drives the ego car along externally provided waypoints.
14+
1315
Change the `env_config["agent_policy"]` to `IDMPolicy|WaymoIDMPolicy|ReplayEgoCarPolicy|WaypointPolicy` to let the ego car follow different policies.
1416

metadrive/policy/replay_policy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ def act(self, *args, **kwargs):
6666

6767
# If set_static, then the agent will not "fall from the sky".
6868
# However, the physics engine will not update the position of the agent.
69-
# So in the visualization, the image will be very chunky as the agent will not suddenly move to the next
69+
# So in the visualization, the image will be very chunky as the agent will suddenly move to the next
7070
# position for each step.
71-
7271
if self.engine.global_config.get("set_static", False):
7372
self.control_object.set_static(True)
7473

metadrive/policy/waypoint_policy.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import gymnasium as gym
2+
import numpy as np
3+
4+
from metadrive.policy.base_policy import BasePolicy
5+
from metadrive.utils import waypoint_utils
6+
7+
8+
class WaypointPolicy(BasePolicy):
    """
    A policy whose trajectory data is overwritten on the fly: each batch of externally
    submitted waypoints replaces the states the controlled object will be set to.
    """
    def __init__(self, obj, seed):
        super(WaypointPolicy, self).__init__(control_object=obj, random_seed=seed)
        # Number of waypoints expected in every submitted batch.
        self.horizon = self.engine.global_config.get("waypoint_horizon", 10)
        # Most recent batch of reconstructed world-frame states, or None before the first batch.
        self.cache = None
        # Episode step at which self.cache was last refreshed.
        self.cache_last_update = 0

    @classmethod
    def get_input_space(cls):
        """The action space: a (horizon, 2) array of waypoint positions in the ego frame."""
        from metadrive.engine.engine_utils import get_global_config
        horizon = get_global_config().get("waypoint_horizon", 10)
        unbounded = gym.spaces.Box(float("-inf"), float("inf"), shape=(horizon, 2), dtype=np.float32)
        return gym.spaces.Dict(dict(position=unbounded))

    def _convert_to_world_coordinates(self, waypoint_positions):
        """
        Convert waypoint positions from the local (ego) frame to the world frame:
        rotate by the current heading, then translate by the current position.
        """
        num_points = waypoint_positions.shape[0]
        heading_per_point = np.array(self.control_object.heading_theta).reshape(1, ).repeat(num_points)
        rotated = waypoint_utils.rotate(
            waypoint_positions[:, 0],
            waypoint_positions[:, 1],
            heading_per_point,
        )
        return rotated + np.array(self.control_object.position).reshape(1, 2)

    def reset(self):
        """
        Reset the policy, dropping any cached waypoints from the previous episode.
        """
        self.cache = None
        self.cache_last_update = 0
        super(WaypointPolicy, self).reset()

    def act(self, agent_id):
        assert self.engine.external_actions is not None
        submitted = self.engine.external_actions[agent_id]

        if submitted is not None:
            # A fresh batch arrived: rebuild the state cache from it.
            local_positions = submitted["position"]
            assert local_positions.ndim == 2
            assert local_positions.shape[1] == 2

            world_positions = self._convert_to_world_coordinates(local_positions)
            headings = np.array(waypoint_utils.reconstruct_heading(world_positions))

            # dt should be 0.1s in default settings
            dt = self.engine.global_config["physics_world_step_size"] * self.engine.global_config["decision_repeat"]

            angular_velocities = np.array(waypoint_utils.reconstruct_angular_velocity(headings, dt))
            velocities = np.array(waypoint_utils.reconstruct_velocity(world_positions, dt))

            duration = len(local_positions)
            assert duration == self.horizon, "The length of the waypoint positions should be equal to the horizon: {} vs {}".format(
                duration, self.horizon
            )

            self.cache = dict(
                position=world_positions,
                velocity=velocities,
                heading=headings,
                angular_velocity=angular_velocities,
            )
            self.cache_last_update = self.engine.episode_step

        assert self.cache is not None

        # Index into the cached batch: how many steps have elapsed since the batch arrived.
        cache_index = self.engine.episode_step - self.cache_last_update
        assert cache_index < self.horizon, "Cache index out of range: {} vs {}".format(cache_index, self.horizon)

        self.control_object.set_position(self.cache["position"][cache_index])
        self.control_object.set_velocity(self.cache["velocity"][cache_index])
        self.control_object.set_heading_theta(self.cache["heading"][cache_index])
        self.control_object.set_angular_velocity(self.cache["angular_velocity"][cache_index])

        # A legacy code to set the static mode of the agent.
        # If set_static, the agent will not "fall from the sky", but the physics simulation
        # no longer applies to it either. So in the visualization the image will be very chunky,
        # as the agent suddenly jumps to the next position at each step.
        if self.engine.global_config.get("set_static", False):
            self.control_object.set_static(True)

        # Return None action so the base vehicle will not overwrite the steering & throttle
        return None

0 commit comments

Comments
 (0)