@@ -50,7 +50,7 @@ def setUp(self):
     self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
     self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
     self.zero_state = np.zeros(
-        [1, self.observation_shape, self.observation_shape, self.stack_size])
+        (1,) + self.observation_shape + (self.stack_size,))

   def _create_test_agent(self, sess):
     stack_size = self.stack_size
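Note on the change above: the old code assumed a square frame and built the state shape as [1, side, side, stack_size] from a scalar observation_shape, while the new code treats observation_shape as a tuple of arbitrary rank and concatenates the batch and stack dimensions onto it. A minimal sketch of the idiom (the concrete shapes here are illustrative only, not taken from the diff):

    import numpy as np

    observation_shape = (84, 84)  # any rank works, e.g. (6,) or (1, 1, 6)
    stack_size = 4
    # Batch dimension in front, frame-stack dimension at the end.
    zero_state = np.zeros((1,) + observation_shape + (stack_size,))
    assert zero_state.shape == (1, 84, 84, 4)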
@@ -76,6 +76,9 @@ def _network_template(self, state):

     agent = MockDQNAgent(
         sess=sess,
+        observation_shape=self.observation_shape,
+        observation_dtype=self.observation_dtype,
+        stack_size=self.stack_size,
         num_actions=self.num_actions,
         min_replay_history=self.min_replay_history,
         epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
@@ -108,14 +111,12 @@ def testBeginEpisode(self):
       # We fill up the state with 9s. On calling agent.begin_episode the state
       # should be reset to all 0s.
       agent.state.fill(9)
-      first_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1])
+      first_observation = np.ones(self.observation_shape + (1,))
       self.assertEqual(agent.begin_episode(first_observation), 0)
       # When the all-1s observation is received, it will be placed at the end of
       # the state.
       expected_state = self.zero_state
-      expected_state[:, :, :, -1] = np.ones(
-          [1, self.observation_shape, self.observation_shape])
+      expected_state[:, :, :, -1] = np.ones((1,) + self.observation_shape)
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(agent._observation, first_observation[:, :, 0])
       # No training happens in eval mode.
@@ -126,13 +127,11 @@ def testBeginEpisode(self):
       # Having a low replay memory add_count will prevent any of the
       # train/prefetch/sync ops from being called.
       agent._replay.memory.add_count = 0
-      second_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1]) * 2
+      second_observation = np.ones(self.observation_shape + (1,)) * 2
       agent.begin_episode(second_observation)
       # The agent's state will be reset, so we will only be left with the all-2s
       # observation.
-      expected_state[:, :, :, -1] = np.full(
-          (1, self.observation_shape, self.observation_shape), 2)
+      expected_state[:, :, :, -1] = np.full((1,) + self.observation_shape, 2)
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(agent._observation, second_observation[:, :, 0])
       # training_steps is incremented since we set eval_mode to False.
@@ -145,8 +144,7 @@ def testStepEval(self):
     """
     with tf.Session() as sess:
      agent = self._create_test_agent(sess)
-      base_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1])
+      base_observation = np.ones(self.observation_shape + (1,))
       # This will reset state and choose a first action.
       agent.begin_episode(base_observation)
       # We mock the replay buffer to verify how the agent interacts with it.
@@ -163,12 +161,11 @@ def testStepEval(self):
         stack_pos = step - num_steps - 1
         if stack_pos >= -self.stack_size:
           expected_state[:, :, :, stack_pos] = np.full(
-              (1, self.observation_shape, self.observation_shape), step)
+              (1,) + self.observation_shape, step)
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(
           agent._last_observation,
-          np.ones([self.observation_shape, self.observation_shape]) *
-          (num_steps - 1))
+          np.ones(self.observation_shape) * (num_steps - 1))
       self.assertAllEqual(agent._observation, observation[:, :, 0])
       # No training happens in eval mode.
       self.assertEqual(agent.training_steps, 0)
@@ -183,8 +180,7 @@ def testStepTrain(self):
     with tf.Session() as sess:
       agent = self._create_test_agent(sess)
       agent.eval_mode = False
-      base_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1])
+      base_observation = np.ones(self.observation_shape + (1,))
       # We mock the replay buffer to verify how the agent interacts with it.
       agent._replay = test_utils.MockReplayBuffer()
       self.evaluate(tf.global_variables_initializer())
@@ -203,7 +199,7 @@ def testStepTrain(self):
         stack_pos = step - num_steps - 1
         if stack_pos >= -self.stack_size:
           expected_state[:, :, :, stack_pos] = np.full(
-              (1, self.observation_shape, self.observation_shape), step)
+              (1,) + self.observation_shape, step)
         self.assertEqual(agent._replay.add.call_count, step)
         mock_args, _ = agent._replay.add.call_args
         self.assertAllEqual(last_observation[:, :, 0], mock_args[0])
@@ -213,8 +209,7 @@ def testStepTrain(self):
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(
           agent._last_observation,
-          np.full((self.observation_shape, self.observation_shape),
-                  num_steps - 1))
+          np.full(self.observation_shape, num_steps - 1))
       self.assertAllEqual(agent._observation, observation[:, :, 0])
       # We expect one more than num_steps because of the call to begin_episode.
       self.assertEqual(agent.training_steps, num_steps + 1)
@@ -228,6 +223,78 @@ def testStepTrain(self):
       self.assertAllEqual(1, mock_args[2])  # Reward received.
       self.assertTrue(mock_args[3])  # is_terminal

+  def testNonTupleObservationShape(self):
+    with self.assertRaises(AssertionError):
+      self.observation_shape = 84
+      with tf.Session() as sess:
+        _ = self._create_test_agent(sess)
+
+  def _testCustomShapes(self, shape, dtype, stack_size):
+    self.observation_shape = shape
+    self.observation_dtype = dtype
+    self.stack_size = stack_size
+    self.zero_state = np.zeros((1,) + shape + (stack_size,))
+    with tf.Session() as sess:
+      agent = self._create_test_agent(sess)
+      agent.eval_mode = False
+      base_observation = np.ones(self.observation_shape + (1,))
+      # We mock the replay buffer to verify how the agent interacts with it.
+      agent._replay = test_utils.MockReplayBuffer()
+      self.evaluate(tf.global_variables_initializer())
+      # This will reset state and choose a first action.
+      agent.begin_episode(base_observation)
+      observation = base_observation
+
+      expected_state = self.zero_state
+      num_steps = 10
+      for step in range(1, num_steps + 1):
+        # We make observation a multiple of step for testing purposes (to
+        # uniquely identify each observation).
+        last_observation = observation
+        observation = base_observation * step
+        self.assertEqual(agent.step(reward=1, observation=observation), 0)
+        stack_pos = step - num_steps - 1
+        if stack_pos >= -self.stack_size:
+          expected_state[..., stack_pos] = np.full(
+              (1,) + self.observation_shape, step)
+        self.assertEqual(agent._replay.add.call_count, step)
+        mock_args, _ = agent._replay.add.call_args
+        self.assertAllEqual(last_observation[..., 0], mock_args[0])
+        self.assertAllEqual(0, mock_args[1])  # Action selected.
+        self.assertAllEqual(1, mock_args[2])  # Reward received.
+        self.assertFalse(mock_args[3])  # is_terminal
+      self.assertAllEqual(agent.state, expected_state)
+      self.assertAllEqual(
+          agent._last_observation,
+          np.full(self.observation_shape, num_steps - 1))
+      self.assertAllEqual(agent._observation, observation[..., 0])
+      # We expect one more than num_steps because of the call to begin_episode.
+      self.assertEqual(agent.training_steps, num_steps + 1)
+      self.assertEqual(agent._replay.add.call_count, num_steps)
+
+      agent.end_episode(reward=1)
+      self.assertEqual(agent._replay.add.call_count, num_steps + 1)
+      mock_args, _ = agent._replay.add.call_args
+      self.assertAllEqual(observation[..., 0], mock_args[0])
+      self.assertAllEqual(0, mock_args[1])  # Action selected.
+      self.assertAllEqual(1, mock_args[2])  # Reward received.
+      self.assertTrue(mock_args[3])  # is_terminal
+
+  def testStepTrainCustomObservationShapes(self):
+    custom_shapes = [(1,), (4, 4), (6, 1), (1, 6), (1, 1, 6), (6, 6, 6, 6)]
+    for shape in custom_shapes:
+      self._testCustomShapes(shape, tf.uint8, 1)
+
+  def testStepTrainCustomTypes(self):
+    custom_types = [tf.float32, tf.uint8, tf.int64]
+    for dtype in custom_types:
+      self._testCustomShapes((4, 4), dtype, 1)
+
+  def testStepTrainCustomStackSizes(self):
+    custom_stack_sizes = [1, 4, 8]
+    for stack_size in custom_stack_sizes:
+      self._testCustomShapes((4, 4), tf.uint8, stack_size)
+
   def testLinearlyDecayingEpsilon(self):
     """Test the functionality of the linearly_decaying_epsilon function."""
     decay_period = 100
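The new tests above exercise the keyword arguments this diff threads through the constructor (observation_shape, which must be a tuple, plus observation_dtype and stack_size). For context, here is a hypothetical usage sketch, not part of this commit: it assumes dqn_agent.DQNAgent accepts these keyword arguments the same way MockDQNAgent does here, and that a _network_template override must return an object exposing a q_values field, as this test file's mock template does.

    import collections
    import numpy as np
    import tensorflow as tf
    from dopamine.agents.dqn import dqn_agent

    class FlatObservationDQNAgent(dqn_agent.DQNAgent):
      """Hypothetical agent with a dense network for 1-D observations."""

      def _network_template(self, state):
        # Flatten the (batch,) + observation_shape + (stack_size,) state
        # tensor and map it to one Q-value per action.
        net = tf.layers.flatten(tf.cast(state, tf.float32))
        q_values = tf.layers.dense(net, self.num_actions)
        return collections.namedtuple('DQNNetwork', ['q_values'])(q_values)

    with tf.Session() as sess:
      agent = FlatObservationDQNAgent(
          sess=sess,
          num_actions=4,
          observation_shape=(6,),  # a tuple; a bare int raises AssertionError
          observation_dtype=tf.uint8,
          stack_size=1)
      sess.run(tf.global_variables_initializer())
      agent.eval_mode = True
      # Observations carry a trailing channel axis, matching the tests above.
      action = agent.begin_episode(np.ones((6, 1)))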