Skip to content

Commit 4c8daad

Browse files
A basic traffic mode for PGTrafficManager (#827)
* Implemeted a basic traffic mode * Format
1 parent d578dc8 commit 4c8daad

File tree

2 files changed

+79
-29
lines changed

2 files changed

+79
-29
lines changed

documentation/source/rl_environments.ipynb

Lines changed: 53 additions & 7 deletions
Large diffs are not rendered by default.

metadrive/manager/traffic_manager.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818

1919

2020
class TrafficMode:
21+
# Traffic vehicles will be spawned once
22+
Basic = "basic"
23+
2124
# Traffic vehicles will be respawned, once they arrive at the destinations
2225
Respawn = "respawn"
2326

24-
# Traffic vehicles will be triggered only once
27+
# Traffic vehicles will be triggered only once, and will be triggered when agent comes close
2528
Trigger = "trigger"
2629

2730
# Hybrid, some vehicles are triggered once on map and disappear when arriving at destination, others exist all time
@@ -63,13 +66,15 @@ def reset(self):
6366
if abs(traffic_density) < 1e-2:
6467
return
6568
self.respawn_lanes = self._get_available_respawn_lanes(map)
66-
if self.mode == TrafficMode.Respawn:
67-
# add respawn vehicle
68-
self._create_respawn_vehicles(map, traffic_density)
69-
elif self.mode == TrafficMode.Trigger or self.mode == TrafficMode.Hybrid:
70-
self._create_vehicles_once(map, traffic_density)
69+
70+
logging.debug(f"Resetting Traffic Manager with mode {self.mode} and density {traffic_density}")
71+
72+
if self.mode in {TrafficMode.Basic, TrafficMode.Respawn}:
73+
self._create_basic_vehicles(map, traffic_density)
74+
elif self.mode in {TrafficMode.Trigger, TrafficMode.Hybrid}:
75+
self._create_trigger_vehicles(map, traffic_density)
7176
else:
72-
raise ValueError("No such mode named {}".format(self.mode))
77+
raise ValueError(f"No such mode named {self.mode}")
7378

7479
def before_step(self):
7580
"""
@@ -78,14 +83,15 @@ def before_step(self):
7883
"""
7984
# trigger vehicles
8085
engine = self.engine
81-
if self.mode != TrafficMode.Respawn:
86+
if self.mode in {TrafficMode.Trigger, TrafficMode.Hybrid}:
8287
for v in engine.agent_manager.active_agents.values():
8388
if len(self.block_triggered_vehicles) > 0:
8489
ego_lane_idx = v.lane_index[:-1]
8590
ego_road = Road(ego_lane_idx[0], ego_lane_idx[1])
8691
if ego_road == self.block_triggered_vehicles[-1].trigger_road:
8792
block_vehicles = self.block_triggered_vehicles.pop()
8893
self._traffic_vehicles += list(self.get_objects(block_vehicles.vehicles).values())
94+
8995
for v in self._traffic_vehicles:
9096
p = self.engine.get_policy(v.name)
9197
v.before_step(p.act())
@@ -99,17 +105,15 @@ def after_step(self, *args, **kwargs):
99105
for v in self._traffic_vehicles:
100106
v.after_step()
101107
if not v.on_lane:
102-
if self.mode == TrafficMode.Trigger:
103-
v_to_remove.append(v)
104-
elif self.mode == TrafficMode.Respawn or self.mode == TrafficMode.Hybrid:
105-
v_to_remove.append(v)
106-
else:
107-
raise ValueError("Traffic mode error: {}".format(self.mode))
108+
v_to_remove.append(v)
109+
108110
for v in v_to_remove:
109111
vehicle_type = type(v)
110112
self.clear_objects([v.id])
111113
self._traffic_vehicles.remove(v)
112-
if self.mode == TrafficMode.Respawn or self.mode == TrafficMode.Hybrid:
114+
115+
# Spawn new vehicles to replace the removed one
116+
if self.mode in {TrafficMode.Respawn, TrafficMode.Hybrid}:
113117
lane = self.respawn_lanes[self.np_random.randint(0, len(self.respawn_lanes))]
114118
lane_idx = lane.index
115119
long = self.np_random.rand() * lane.length / 2
@@ -136,7 +140,7 @@ def get_vehicle_num(self):
136140
Get the vehicles on road
137141
:return:
138142
"""
139-
if self.mode == TrafficMode.Respawn:
143+
if self.mode in {TrafficMode.Basic, TrafficMode.Respawn}:
140144
return len(self._traffic_vehicles)
141145
return sum(len(block_vehicle_set.vehicles) for block_vehicle_set in self.block_triggered_vehicles)
142146

@@ -151,7 +155,7 @@ def get_global_states(self) -> Dict:
151155
traffic_states[vehicle.index] = vehicle.get_state()
152156

153157
# collect other vehicles
154-
if self.mode != TrafficMode.Respawn:
158+
if self.mode in {TrafficMode.Trigger, TrafficMode.Hybrid}:
155159
for v_b in self.block_triggered_vehicles:
156160
for vehicle in v_b.vehicles:
157161
traffic_states[vehicle.index] = vehicle.get_state()
@@ -188,7 +192,7 @@ def get_global_init_states(self) -> Dict:
188192
vehicles[vehicle.index] = init_state
189193

190194
# collect other vehicles
191-
if self.mode != TrafficMode.Respawn:
195+
if self.mode in {TrafficMode.Trigger, TrafficMode.Hybrid}:
192196
for v_b in self.block_triggered_vehicles:
193197
for vehicle in v_b.vehicles:
194198
init_state = vehicle.get_state()
@@ -208,7 +212,7 @@ def _propose_vehicle_configs(self, lane: AbstractLane):
208212
potential_vehicle_configs.append(random_vehicle_config)
209213
return potential_vehicle_configs
210214

211-
def _create_respawn_vehicles(self, map: BaseMap, traffic_density: float):
215+
def _create_basic_vehicles(self, map: BaseMap, traffic_density: float):
212216
total_num = len(self.respawn_lanes)
213217
for lane in self.respawn_lanes:
214218
_traffic_vehicles = []
@@ -227,7 +231,7 @@ def _create_respawn_vehicles(self, map: BaseMap, traffic_density: float):
227231
self.add_policy(random_v.id, IDMPolicy, random_v, self.generate_seed())
228232
self._traffic_vehicles.append(random_v)
229233

230-
def _create_vehicles_once(self, map: BaseMap, traffic_density: float) -> None:
234+
def _create_trigger_vehicles(self, map: BaseMap, traffic_density: float) -> None:
231235
"""
232236
Trigger mode, vehicles will be triggered only once, and disappear when arriving destination
233237
:param map: Map
@@ -365,10 +369,10 @@ def set_state(self, state: dict, old_name_to_current=None):
365369

366370

367371
class MixedPGTrafficManager(PGTrafficManager):
368-
def _create_respawn_vehicles(self, *args, **kwargs):
372+
def _create_basic_vehicles(self, *args, **kwargs):
369373
raise NotImplementedError()
370374

371-
def _create_vehicles_once(self, map: BaseMap, traffic_density: float) -> None:
375+
def _create_trigger_vehicles(self, map: BaseMap, traffic_density: float) -> None:
372376
vehicle_num = 0
373377
for block in map.blocks[1:]:
374378

0 commit comments

Comments
 (0)