Skip to content

Commit 5a9c4ee

Browse files
committed
Upgrade ALE usage for the latest version of Gymnasium.
PiperOrigin-RevId: 691514007
1 parent 02a3547 commit 5a9c4ee

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

dopamine/discrete_domains/atari_lib.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def reset(self):
553553
environment.
554554
"""
555555
self.environment.reset()
556-
self.lives = self.environment.ale.lives()
556+
self.lives = self.environment.env.ale.lives()
557557
self._fetch_grayscale_observation(self.screen_buffer[0])
558558
self.screen_buffer[1].fill(0)
559559
return self._pool_and_resize()
@@ -608,7 +608,7 @@ def step(self, action):
608608
accumulated_reward += reward
609609

610610
if self.terminal_on_life_loss:
611-
new_lives = self.environment.ale.lives()
611+
new_lives = self.environment.env.ale.lives()
612612
is_terminal = game_over or new_lives < self.lives
613613
self.lives = new_lives
614614
else:
@@ -639,7 +639,7 @@ def _fetch_grayscale_observation(self, output):
639639
Returns:
640640
observation: numpy array, the current observation in grayscale.
641641
"""
642-
self.environment.ale.getScreenGrayscale(output)
642+
self.environment.env.ale.getScreenGrayscale(output)
643643
return output
644644

645645
def _pool_and_resize(self):

tests/dopamine/discrete_domains/atari_lib_test.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -67,32 +67,39 @@ def getScreenGrayscale(self, screen): # pylint: disable=invalid-name
6767
screen.fill(self.screen_value)
6868

6969

70+
class MockEnvALEWrapper(object):
  """Mock of the inner Gymnasium env that exposes the `.ale` attribute.

  Newer Gymnasium Atari envs nest the ALE handle one level deeper
  (`env.env.ale` rather than `env.ale`); this stand-in mirrors that layer
  so tests can exercise the `environment.env.ale` access path.
  """

  def __init__(self):
    # Fake ALE backing object shared with the outer mock environment.
    self.ale = MockALE()
75+
76+
7077
class MockEnvironment(object):
7178
"""Mock environment for testing."""
7279

7380
def __init__(self, screen_size=10, max_steps=10):
  """Build a minimal fake Atari environment for tests.

  Args:
    screen_size: int, side length of the square fake screen.
    max_steps: int, number of steps before an episode terminates.
  """
  # Episode/geometry configuration.
  self.screen_size = screen_size
  self.max_steps = max_steps
  # The ALE handle lives behind an inner wrapper, matching the nesting
  # used by recent Gymnasium Atari environments (env.env.ale).
  self.env = MockEnvALEWrapper()
  # Gym-style metadata consulted by the code under test.
  self.observation_space = np.empty((screen_size, screen_size))
  self.game_over = False
7986

8087
def reset(self):
  """Reset the fake episode and return the first observation."""
  # Restore the fake screen brightness and restart the step counter.
  self.env.ale.screen_value = 10
  self.num_steps = 0
  return self.get_observation()
8491

8592
def get_observation(self):
  """Return the current fake grayscale screen.

  A scratch buffer is allocated and handed to the mock ALE, which fills
  it with the current `screen_value` — mimicking the real
  `ale.getScreenGrayscale(buffer)` in-place API.
  """
  buffer = np.empty((self.screen_size, self.screen_size))
  return self.env.ale.getScreenGrayscale(buffer)
8895

8996
def step(self, action):
  """Advance the fake environment by one step.

  Args:
    action: int, actions > 0 are "penalized" (reward -1.0), others
      rewarded (+1.0) — an arbitrary rule so tests can distinguish them.

  Returns:
    Gymnasium 5-tuple: (observation, reward, terminated, truncated, info).
  """
  if action > 0:
    reward = -1.0
  else:
    reward = 1.0

  self.num_steps += 1
  # Terminate once the configured episode length is exhausted.
  is_terminal = self.num_steps >= self.max_steps

  unused = 0
  # Darken the fake screen each step, clamped at zero.
  self.env.ale.screen_value = max(0, self.env.ale.screen_value - 2)
  return (self.get_observation(), reward, is_terminal, False, unused)
97104

98105
def render(self, mode):

0 commit comments

Comments
 (0)