Skip to content

Commit e9a22c0

Browse files
committed
Add agent communication example.
1 parent b7dc02e commit e9a22c0

File tree

2 files changed

+228
-11
lines changed

2 files changed

+228
-11
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import sys
2+
from pathlib import Path
3+
from typing import Any, Dict, Union
4+
5+
from smarts.core.agent import Agent
6+
from smarts.core.agent_interface import AgentInterface, AgentType
7+
from smarts.core.utils.episodes import episodes
8+
from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1
9+
from smarts.env.gymnasium.wrappers.agent_communication import (
10+
Bands,
11+
Header,
12+
Message,
13+
MessagePasser,
14+
V2XReceiver,
15+
V2XTransmitter,
16+
)
17+
from smarts.env.utils.action_conversion import ActionOptions
18+
from smarts.env.utils.observation_conversion import ObservationOptions
19+
from smarts.sstudio.scenario_construction import build_scenarios
20+
21+
sys.path.insert(0, str(Path(__file__).parents[2].absolute()))
22+
import gymnasium as gym
23+
24+
from examples.tools.argument_parser import default_argument_parser
25+
26+
# Length of one step in seconds (presumably the simulation step; it is only
# used below to size the message budget).
TIMESTEP = 0.1
# Bytes in one megabit: 1_000_000 bits / 8 bits per byte.
BYTES_IN_MEGABIT = 125000
# Link bandwidth used to size messages, in megabits per second.
MESSAGE_MEGABITS_PER_SECOND = 10
# Byte budget for a single step: (bytes / second) * (seconds / step).
# Multiplying by TIMESTEP (not dividing) keeps the units consistent, and
# ``round`` avoids truncating a product that lands just below an integer
# because of floating point error.
MESSAGE_BYTES = round(BYTES_IN_MEGABIT * MESSAGE_MEGABITS_PER_SECOND * TIMESTEP)
30+
31+
32+
def filter_useless(transmissions):
    """Lazily yield the ``(header, msg)`` pairs worth reading.

    A transmission is dropped when its header's ``sender`` is
    "parked_agent" or "broken_stoplight", or when its ``sender_type`` is
    "advertisement". Everything else is passed through in order.
    """
    ignored_senders = ("parked_agent", "broken_stoplight")
    ignored_sender_types = ("advertisement",)
    for transmission in transmissions:
        header, msg = transmission
        if (
            header.sender not in ignored_senders
            and header.sender_type not in ignored_sender_types
        ):
            yield header, msg
39+
40+
41+
class LaneFollowerAgent(Agent):
    """Minimal driving agent that targets the first waypoint's speed limit."""

    def act(self, obs: Dict[Any, Union[Any, Dict]]):
        """Return a two-element action built from the observation.

        The first element is the speed limit of the first waypoint on the
        first waypoint path; the second is always 0 (presumably "keep lane",
        given the agents are configured as ``AgentType.LanerWithSpeed``).
        """
        speed_limits = obs["waypoint_paths"]["speed_limit"]
        target_speed = speed_limits[0][0]
        return (target_speed, 0)
44+
45+
46+
class GossiperAgent(Agent):
    """Agent that answers position requests while a base agent drives.

    Incoming transmissions are first passed through ``filter_``. Any
    remaining "position_request" addressed to this agent (by id or through
    "__all__") is answered with this agent's position, cc'd to the requester
    and bcc'd to ``friends``.
    """

    def __init__(self, id_: str, base_agent: Agent, filter_, friends):
        self._id = id_
        self._base_agent = base_agent
        self._filter = filter_
        self._friends = friends

    def act(self, obs, **configs):
        """Return ``(base_action, replies)`` for the given observation.

        ``base_action`` comes from the wrapped base agent; ``replies`` is the
        list of ``(Header, Message)`` pairs to transmit this step.
        """
        replies = []
        for header, msg in self._filter(obs["transmissions"]):
            # ``header`` is a Header and ``msg`` a Message.
            # Ignore anything not addressed to this agent or to everyone.
            if {self._id, "__all__"}.isdisjoint(header.cc | header.bcc):
                continue
            if header.channel != "position_request":
                continue

            print()
            print("On step: ", obs["steps_completed"])
            print("Gossiper received position request: ", header)
            # TODO: optimize building the reply header and message later.
            reply_header = Header(
                channel="position",
                sender=self._id,
                sender_type="ad_vehicle",
                cc={header.sender},
                bcc={*self._friends},
                format="position",
            )
            reply_message = Message(
                content=obs["ego_vehicle_state"]["position"],
            )
            replies.append((reply_header, reply_message))
            print("Gossiper sent position: ", replies[0][1])

        driving_action = self._base_agent.act(obs)
        return (driving_action, replies)
83+
84+
85+
class SchemerAgent(Agent):
    """Agent that periodically asks everyone for their position.

    Driving is delegated to ``base_agent``. Whenever the completed step count
    is a multiple of ``request_freq``, a "position_request" is sent with
    "__all__" in bcc. Received "position" messages are printed.
    """

    def __init__(self, id_: str, base_agent: Agent, request_freq) -> None:
        self._id = id_
        self._request_freq = request_freq
        self._base_agent = base_agent

    def act(self, obs, **configs):
        """Return ``(base_action, requests)`` for the given observation.

        ``requests`` is the list of ``(Header, Message)`` pairs to transmit
        this step; it is empty on steps where no request is due.
        """
        requests = []

        # Report every position message that arrived this step.
        received_positions = (
            msg for header, msg in obs["transmissions"] if header.channel == "position"
        )
        for position_msg in received_positions:
            print()
            print("On step: ", obs["steps_completed"])
            print("Schemer received position: ", position_msg)

        request_due = obs["steps_completed"] % self._request_freq == 0
        if request_due:
            print()
            print("On step: ", obs["steps_completed"])
            request_header = Header(
                channel="position_request",
                sender=self._id,
                sender_type="ad_vehicle",
                cc=set(),
                bcc={"__all__"},
                format="position_request",
            )
            requests.append((request_header, Message(content=None)))
            print("Schemer requested position with: ", requests[0][0])

        driving_action = self._base_agent.act(obs)
        return (driving_action, requests)
121+
122+
123+
def main(scenarios, headless, num_episodes, max_episode_steps=None):
    """Run the two-agent message passing example.

    Builds a ``HiWayEnvV1`` environment with a "gossiper0" and a "schemer"
    agent, wraps it in ``MessagePasser`` so that actions and observations
    carry transmissions, and steps it through ``num_episodes`` episodes.
    The environment is closed on exit, including when an error is raised.

    Args:
        scenarios: Scenarios handed to ``HiWayEnvV1``.
        headless: Forwarded to ``HiWayEnvV1`` as ``headless``.
        num_episodes: Number of episodes to run.
        max_episode_steps: Forwarded to the agent interface; defaults to
            ``None``.
    """
    agent_interface = AgentInterface.from_type(
        AgentType.LanerWithSpeed, max_episode_steps=max_episode_steps
    )
    hiwayv1env = HiWayEnvV1(
        scenarios=scenarios,
        agent_interfaces={"gossiper0": agent_interface, "schemer": agent_interface},
        headless=headless,
        observation_options=ObservationOptions.multi_agent,
        action_options=ActionOptions.default,
    )

    # Start from the bare environment so the ``finally`` below always has
    # something to close, even if wrapping it fails.
    env = hiwayv1env
    try:
        # The wrapper is applied by hand here for now.
        env = MessagePasser(
            hiwayv1env,
            max_message_bytes=MESSAGE_BYTES,
            message_config={
                "gossiper0": (
                    V2XTransmitter(
                        bands=Bands.ALL,
                        range=100,
                        # available_channels=["position_request", "position"]
                    ),
                    V2XReceiver(
                        bands=Bands.ALL,
                        aliases=["tim"],
                        blacklist_channels={"self_control"},
                    ),
                ),
                "schemer": (
                    V2XTransmitter(
                        bands=Bands.ALL,
                        range=100,
                    ),
                    V2XReceiver(
                        bands=Bands.ALL,
                        aliases=[],
                    ),
                ),
            },
        )
        agents = {
            "gossiper0": GossiperAgent(
                "gossiper0",
                base_agent=LaneFollowerAgent(),
                filter_=filter_useless,
                friends={"schemer"},
            ),
            "schemer": SchemerAgent(
                "schemer", base_agent=LaneFollowerAgent(), request_freq=100
            ),
        }

        # then just the standard gym interface with no modifications
        for episode in episodes(n=num_episodes):
            observation, info = env.reset()
            episode.record_scenario(env.scenario_log)

            terminated = {"__all__": False}
            while not terminated["__all__"]:
                # Each agent present in the observation picks its own action.
                agent_action = {
                    agent_id: agents[agent_id].act(obs)
                    for agent_id, obs in observation.items()
                }
                observation, reward, terminated, truncated, info = env.step(
                    agent_action
                )
                episode.record_step(observation, reward, terminated, info)
    finally:
        # Closing the outermost environment is expected to close whatever it
        # wraps as well.
        env.close()
190+
191+
192+
if __name__ == "__main__":
    # Label the parser after this example; the previous "single-agent-example"
    # label was left over from another script and mislabels a two-agent run.
    parser = default_argument_parser("agent-communication-example")
    args = parser.parse_args()

    # With no scenario given, fall back to "scenarios/sumo/loop" resolved
    # against ``parents[2]`` of this file's absolute path.
    if not args.scenarios:
        args.scenarios = [
            str(Path(__file__).absolute().parents[2] / "scenarios" / "sumo" / "loop")
        ]

    # Scenarios are built before the environment is created.
    build_scenarios(scenarios=args.scenarios)

    main(
        scenarios=args.scenarios,
        headless=args.headless,
        num_episodes=args.episodes,
    )

smarts/env/gymnasium/wrappers/agent_communication.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,17 @@ def __init__(
128128
for a_id, (_, receiver) in message_config.items():
129129
for alias in receiver.aliases:
130130
self._alias_mapping[alias].append(a_id)
131+
self._alias_mapping[a_id].append(a_id)
132+
self._alias_mapping["__all__"].append(a_id)
131133

132-
assert isinstance(env, HiWayEnvV1)
134+
assert isinstance(env.unwrapped, HiWayEnvV1)
133135
o_action_space: gym.spaces.Dict = self.env.action_space
134-
msg_space = (
135-
gym.spaces.Box(low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8),
136+
msg_space = gym.spaces.Tuple(
137+
(
138+
gym.spaces.Box(
139+
low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8
140+
),
141+
)
136142
)
137143
self.action_space = gym.spaces.Dict(
138144
{
@@ -182,19 +188,21 @@ def resolve_alias(self, alias):
182188

183189
def step(self, action):
184190
"""Steps the environment using the given action."""
185-
std_actions = {a_id: act for a_id, (act, _) in action}
191+
std_actions = {a_id: act for a_id, (act, _) in action.items()}
186192
observations, rewards, terms, truncs, infos = self.env.step(std_actions)
187193

188194
msgs = defaultdict(list)
189195

190196
# pytype: disable=wrong-arg-types
191197
# filter recipients for active
192-
cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys()))
198+
cached_active_filter = lru_cache(
199+
lambda a: frozenset(a.intersection(observations.keys()))
200+
)
193201

194202
# filter recipients by band
195203
## compare transmitter
196204
cached_band_filter = lru_cache(
197-
lambda sender, recipients: (
205+
lambda sender, recipients: frozenset(
198206
r
199207
for r in recipients
200208
if self._message_config[sender][0].bands
@@ -211,7 +219,7 @@ def step(self, action):
211219
and channel not in self._message_config[recipient][1].blacklist_channels
212220
)
213221
cached_channel_filter = lru_cache(
214-
lambda channel, recipients: (
222+
lambda channel, recipients: frozenset(
215223
r for r in recipients if accepts_channel(channel, r)
216224
)
217225
)
@@ -231,7 +239,9 @@ def step(self, action):
231239
for recipients in map(self.resolve_alias, initial_recipients)
232240
for cc in cached_channel_filter(
233241
header.channel,
234-
cached_band_filter(header.sender, cached_active_filter(recipients)),
242+
cached_band_filter(
243+
header.sender, cached_active_filter(frozenset(recipients))
244+
),
235245
)
236246
)
237247

@@ -243,8 +253,8 @@ def step(self, action):
243253
message: Message = message
244254

245255
# expand the recipients
246-
cc_recipients = set(general_filter(header, header.cc))
247-
bcc_recipients = set(general_filter(header, header.bcc))
256+
cc_recipients = set(general_filter(header, frozenset(header.cc)))
257+
bcc_recipients = set(general_filter(header, frozenset(header.bcc)))
248258
cc_header = header._replace(cc=cc_recipients)
249259

250260
# associate the messages to the recipients
@@ -279,7 +289,7 @@ def reset(
279289
"""Resets the environment."""
280290
observations, info = super().reset(seed=seed, options=options)
281291
obs_with_msgs = {
282-
a_id: dict(**obs, transmissions=self._transmission_space.sample(0))
292+
a_id: dict(**obs, transmissions=self._transmission_space.sample((0, ())))
283293
for a_id, obs in observations.items()
284294
}
285295
return obs_with_msgs, info

0 commit comments

Comments
 (0)