@@ -67,32 +67,39 @@ def getScreenGrayscale(self, screen): # pylint: disable=invalid-name
67
67
screen .fill (self .screen_value )
68
68
69
69
70
class MockEnvALEWrapper(object):
  """Mock wrapper object that exposes a mock ALE as its `ale` attribute.

  MockEnvironment stores an instance of this class as `self.env` and reaches
  the emulator through `self.env.ale`.
  """

  def __init__(self):
    # The wrapper's only state is the mock emulator itself.
    self.ale = MockALE()
75
+
76
+
70
77
class MockEnvironment (object ):
71
78
"""Mock environment for testing."""
72
79
73
80
def __init__(self, screen_size=10, max_steps=10):
  """Creates the mock environment.

  Args:
    screen_size: int, side length of the square observation array.
    max_steps: int, number of steps after which `step` reports a terminal
      state.
  """
  self.max_steps = max_steps
  self.screen_size = screen_size
  # The mock ALE is reached through a wrapper object (`self.env.ale`) rather
  # than being stored directly on the environment.
  self.env = MockEnvALEWrapper()
  # Only the shape of this array is meaningful; np.empty leaves the contents
  # uninitialized.
  self.observation_space = np.empty((screen_size, screen_size))
  self.game_over = False
79
86
80
87
def reset(self):
  """Resets the mock emulator state and the step counter.

  Returns:
    Whatever `self.get_observation()` returns for the freshly reset state.
  """
  # Every episode starts with the screen value at 10 and no steps taken.
  self.env.ale.screen_value = 10
  self.num_steps = 0
  return self.get_observation()
84
91
85
92
def get_observation(self):
  """Builds an observation by handing a fresh buffer to the mock ALE.

  Returns:
    The value returned by `self.env.ale.getScreenGrayscale` when given a
    newly allocated (screen_size, screen_size) array.
    NOTE(review): presumably the filled buffer itself -- confirm against
    MockALE.getScreenGrayscale.
  """
  # The buffer is uninitialized; the mock ALE is expected to fill it.
  observation = np.empty((self.screen_size, self.screen_size))
  return self.env.ale.getScreenGrayscale(observation)
88
95
89
96
def step(self, action):
  """Advances the mock environment by one step.

  Args:
    action: number; any value greater than 0 yields a reward of -1.0, and
      anything else yields 1.0.

  Returns:
    A 5-tuple `(observation, reward, is_terminal, False, unused)`, where
    `observation` comes from `self.get_observation()`, `is_terminal` becomes
    True once `max_steps` steps have been taken, the fourth element is always
    False, and `unused` is always 0.
  """
  reward = -1.0 if action > 0 else 1.0
  self.num_steps += 1
  is_terminal = self.num_steps >= self.max_steps

  unused = 0
  # The screen value drops by 2 per step and is floored at 0.
  self.env.ale.screen_value = max(0, self.env.ale.screen_value - 2)
  return (self.get_observation(), reward, is_terminal, False, unused)
97
104
98
105
def render (self , mode ):
0 commit comments