Skip to content

Commit 687351c

Browse files
Move gym (and gymnasium) env to root and rename gym.py for registration to gym_registration.py (#509)
1 parent 802947e commit 687351c

File tree

8 files changed

+203
-11
lines changed

8 files changed

+203
-11
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,13 @@ changelog = "https://github.com/mgbellemare/Arcade-Learning-Environment/blob/mas
5454
ale-import-roms = "ale_py.scripts.import_roms:main"
5555

5656
[project.entry-points."gym.envs"]
57-
ALE = "ale_py.gym:register_gym_envs"
58-
__internal__ = "ale_py.gym:register_legacy_gym_envs"
57+
ALE = "ale_py.gym_registration:register_gym_envs"
58+
__internal__ = "ale_py.gym_registration:register_legacy_gym_envs"
5959

6060
[tool.setuptools]
6161
packages = [
6262
"ale_py",
6363
"ale_py.roms",
64-
"ale_py.env",
6564
"ale_py.scripts"
6665
]
6766
package-dir = {ale_py = "src/python", gym = "src/gym"}

src/python/env/__init__.py

Whitespace-only changes.
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import sys
24
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
35

@@ -7,8 +9,7 @@
79
import numpy as np
810

911
import gym
10-
import gym.logger as logger
11-
from gym import error, spaces, utils
12+
from gym import error, spaces, utils, logger
1213

1314
if sys.version_info < (3, 11):
1415
from typing_extensions import NotRequired, TypedDict

src/python/gym_registration.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
from __future__ import annotations
2+
3+
from collections import defaultdict
4+
from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union
5+
6+
import ale_py.roms as roms
7+
from ale_py.roms import utils as rom_utils
8+
9+
from gym.envs.registration import register
10+
11+
12+
class GymFlavour(NamedTuple):
    """A named variant of an environment configuration.

    Each flavour contributes a suffix to the registered environment id and a
    set of keyword arguments forwarded to the environment constructor.
    """

    # Text placed between the game name and the version in the environment id
    # (e.g. "", "Deterministic", "NoFrameskip").
    suffix: str
    # Constructor kwargs for this flavour: either a fixed mapping, or a
    # callable that receives the ROM id and returns the mapping.
    kwargs: Union[Mapping[Text, Any], Callable[[str], Mapping[Text, Any]]]
15+
16+
17+
class GymConfig(NamedTuple):
    """A versioned group of environment settings shared by several flavours."""

    # Version tag appended to the environment id (e.g. "v0", "v4", "v5").
    version: str
    # Constructor kwargs common to every flavour of this version. The
    # registration helper also accepts a callable taking the ROM id, so the
    # annotation allows both forms.
    kwargs: Union[Mapping[Text, Any], Callable[[str], Mapping[Text, Any]]]
    # Flavours registered under this version; quoted so the class does not
    # depend on evaluation order of annotations.
    flavours: Sequence["GymFlavour"]
21+
22+
23+
def _register_gym_configs(
24+
roms: Sequence[str],
25+
obs_types: Sequence[str],
26+
configs: Sequence[GymConfig],
27+
prefix: str = "",
28+
) -> None:
29+
if len(prefix) > 0 and prefix[-1] != "/":
30+
prefix += "/"
31+
32+
for rom in roms:
33+
for obs_type in obs_types:
34+
for config in configs:
35+
for flavour in config.flavours:
36+
name = rom_utils.rom_id_to_name(rom)
37+
name = f"{name}-ram" if obs_type == "ram" else name
38+
39+
# Parse config kwargs
40+
config_kwargs = (
41+
config.kwargs(rom) if callable(config.kwargs) else config.kwargs
42+
)
43+
# Parse flavour kwargs
44+
flavour_kwargs = (
45+
flavour.kwargs(rom)
46+
if callable(flavour.kwargs)
47+
else flavour.kwargs
48+
)
49+
50+
# Register the environment
51+
register(
52+
id=f"{prefix}{name}{flavour.suffix}-{config.version}",
53+
entry_point="ale_py.gym_env:AtariEnv",
54+
kwargs=dict(
55+
game=rom,
56+
obs_type=obs_type,
57+
**config_kwargs,
58+
**flavour_kwargs,
59+
),
60+
)
61+
62+
63+
def register_legacy_gym_envs() -> None:
    """Register the legacy ``v0`` and ``v4`` environment ids.

    For every game in the fixed legacy list, and for both the ``rgb`` and
    ``ram`` observation types, three flavours are registered per version:
    the default (frameskip sampled from ``(2, 5)``), ``Deterministic``
    (a fixed per-game frameskip) and ``NoFrameskip`` (frameskip of 1).
    ``v0`` uses a repeat action probability of 0.25 and ``v4`` uses 0.0.
    """
    legacy_games = [
        "adventure",
        "air_raid",
        "alien",
        "amidar",
        "assault",
        "asterix",
        "asteroids",
        "atlantis",
        "bank_heist",
        "battle_zone",
        "beam_rider",
        "berzerk",
        "bowling",
        "boxing",
        "breakout",
        "carnival",
        "centipede",
        "chopper_command",
        "crazy_climber",
        "defender",
        "demon_attack",
        "double_dunk",
        "elevator_action",
        "enduro",
        "fishing_derby",
        "freeway",
        "frostbite",
        "gopher",
        "gravitar",
        "hero",
        "ice_hockey",
        "jamesbond",
        "journey_escape",
        "kangaroo",
        "krull",
        "kung_fu_master",
        "montezuma_revenge",
        "ms_pacman",
        "name_this_game",
        "phoenix",
        "pitfall",
        "pong",
        "pooyan",
        "private_eye",
        "qbert",
        "riverraid",
        "road_runner",
        "robotank",
        "seaquest",
        "skiing",
        "solaris",
        "space_invaders",
        "star_gunner",
        "tennis",
        "time_pilot",
        "tutankham",
        "up_n_down",
        "venture",
        "video_pinball",
        "wizard_of_wor",
        "yars_revenge",
        "zaxxon",
    ]
    obs_types = ["rgb", "ram"]
    # Deterministic frameskip is 4 for every game except space_invaders (3).
    frameskip = defaultdict(lambda: 4, [("space_invaders", 3)])

    def deterministic_kwargs(rom):
        # Resolved per ROM at registration time.
        return {"frameskip": frameskip[rom]}

    versions = [
        GymConfig(
            version="v0",
            kwargs={
                "repeat_action_probability": 0.25,
                "full_action_space": False,
                "max_num_frames_per_episode": 108_000,
            },
            flavours=[
                # Default for v0 has 10k steps, no idea why...
                GymFlavour("", {"frameskip": (2, 5)}),
                # Deterministic has 100k steps, close to the standard of 108k
                # (30 mins gameplay)
                GymFlavour("Deterministic", deterministic_kwargs),
                # NoFrameskip imposes a max episode steps of
                # frameskip * 100k, weird...
                GymFlavour("NoFrameskip", {"frameskip": 1}),
            ],
        ),
        GymConfig(
            version="v4",
            kwargs={
                "repeat_action_probability": 0.0,
                "full_action_space": False,
                "max_num_frames_per_episode": 108_000,
            },
            flavours=[
                # Unlike v0, v4 has 100k max episode steps
                GymFlavour("", {"frameskip": (2, 5)}),
                GymFlavour("Deterministic", deterministic_kwargs),
                # Same weird frameskip * 100k max steps for v4?
                GymFlavour("NoFrameskip", {"frameskip": 1}),
            ],
        ),
    ]

    _register_gym_configs(legacy_games, obs_types, versions)
166+
167+
168+
def register_gym_envs() -> None:
    """Register the ``v5`` environment ids for every name exposed by ``roms``.

    Each game is registered for both the ``rgb`` and ``ram`` observation
    types with a single, suffix-less flavour.
    """
    # dir(roms) supplies the names to register; presumably the package limits
    # this to ROM names -- verify against ale_py.roms.
    all_games = [rom_utils.rom_name_to_id(rom_name) for rom_name in dir(roms)]
    obs_types = ["rgb", "ram"]

    # max_episode_steps is 108k frames which is 30 mins of gameplay.
    # This corresponds to 108k / 4 = 27,000 steps
    versions = [
        GymConfig(
            version="v5",
            kwargs={
                "repeat_action_probability": 0.25,
                "full_action_space": False,
                "frameskip": 4,
                "max_num_frames_per_episode": 108_000,
            },
            flavours=[GymFlavour("", {})],
        )
    ]

    _register_gym_configs(all_games, obs_types, versions)
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
2+
13
from collections import defaultdict
24
from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union
35

46
import ale_py.roms as roms
57
from ale_py.roms import utils as rom_utils
68

7-
from gym.envs.registration import register
9+
import gymnasium
810

911

1012
class GymFlavour(NamedTuple):
@@ -46,7 +48,7 @@ def _register_gym_configs(
4648
)
4749

4850
# Register the environment
49-
register(
51+
gymnasium.register(
5052
id=f"{prefix}{name}{flavour.suffix}-{config.version}",
5153
entry_point="ale_py.env.gym:AtariEnv",
5254
kwargs=dict(

tests/fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def tetris_gym(request, test_rom_path):
3333
):
3434
register(
3535
id="TetrisTest-v0",
36-
entry_point="ale_py.env.gym:AtariEnv",
36+
entry_point="ale_py.gym_env:AtariEnv",
3737
kwargs={"game": "tetris_test"},
3838
)
3939

tests/python/gym/test_gym_interface.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from unittest.mock import patch
88

99
import numpy as np
10-
from ale_py.env.gym import AtariEnv
11-
from ale_py.gym import (
10+
from ale_py.gym_env import AtariEnv
11+
from ale_py.gym_registration import (
1212
_register_gym_configs,
1313
register_gym_envs,
1414
register_legacy_gym_envs,
@@ -30,7 +30,10 @@ def test_register_legacy_env_id():
3030
def _mocked_register_gym_configs(*args, **kwargs):
3131
return _original_register_gym_configs(*args, **kwargs, prefix=prefix)
3232

33-
with patch("ale_py.gym._register_gym_configs", new=_mocked_register_gym_configs):
33+
with patch(
34+
"ale_py.gym_registration._register_gym_configs",
35+
new=_mocked_register_gym_configs,
36+
):
3437
# Register internal IDs
3538
register_legacy_gym_envs()
3639

0 commit comments

Comments
 (0)