Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Stream multiple games of pokemon onto a shared map!

## To broadcast:
Use the StreamWrapper gym environment wrapper:
https://github.com/PWhiddy/PokemonRedExperiments/blob/master/baselines/stream_agent_wrapper.py
https://github.com/PWhiddy/pokerl-map-viz/blob/main/examples/stream_agent_wrapper.py

And wrap your environment like this:
```python
Expand Down
22 changes: 22 additions & 0 deletions example/implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

from gymnasium import Env
from stream_agent_wrapper import StreamWrapper

def make_env():
def _init():

env = Env() # Replace with your own environment
env = StreamWrapper(
env,
stream_metadata = {
# All of this is part is optional
"user": "pw", # choose your own username
"env_id": id, # environment identifier
"color": "#0033ff", # choose your color :)
"extra": "", # any extra text you put here will be displayed
}
)
return env

return _init

67 changes: 67 additions & 0 deletions example/stream_agent_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import asyncio
import websockets
import json
import gymnasium as gym

X_POS_ADDRESS, Y_POS_ADDRESS = 0xD362, 0xD361 # Memory Addresses for X and Y position in Pokemon Gen 1 games
MAP_N_ADDRESS = 0xD35E # Memory Address for map number in Pokemon Gen 1 games

class StreamWrapper(gym.Wrapper):
def __init__(self, env, stream_metadata={}):
super().__init__(env)
self.ws_address = "wss://poke-ws-test-ulsjzjzwpa-ue.a.run.app/broadcast"
self.stream_metadata = stream_metadata
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.websocket = self.loop.run_until_complete(
self.establish_wc_connection()
)
self.upload_interval = 200 # How many steps between each upload
self.stream_step_counter = 0
self.coord_list = []
if hasattr(env, "pyboy"):
self.emulator = env.pyboy
elif hasattr(env, "game"):
self.emulator = env.game
else:
raise Exception("Could not find emulator!")

def step(self, action):

x_pos = self.emulator.get_memory_value(X_POS_ADDRESS)
y_pos = self.emulator.get_memory_value(Y_POS_ADDRESS)
map_n = self.emulator.get_memory_value(MAP_N_ADDRESS)
self.coord_list.append([x_pos, y_pos, map_n])

if self.stream_step_counter >= self.upload_interval:
self.loop.run_until_complete(
self.broadcast_ws_message(
json.dumps(
{
"metadata": self.stream_metadata,
"coords": self.coord_list
}
)
)
)
self.stream_step_counter = 0
self.coord_list = []

self.steam_step_counter += 1

return self.env.step(action)

async def broadcast_ws_message(self, message):
if self.websocket is None:
await self.establish_wc_connection()
if self.websocket is not None:
try:
await self.websocket.send(message)
except websockets.exceptions.WebSocketException as e:
self.websocket = None

async def establish_wc_connection(self):
try:
self.websocket = await websockets.connect(self.ws_address)
except:
self.websocket = None