Commit bc66570

Fix the handling of custom observation shapes and types. This includes enforcing shapes to be passed in as tuples.
PiperOrigin-RevId: 222560771
1 parent 06d64fc commit bc66570
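
After this change, observation_shape must be passed as a tuple of ints wherever an agent or replay buffer is constructed; a bare int no longer means a square frame and now trips the new asserts. A minimal sketch of the new calling convention (argument values are illustrative; only the keyword names come from the diff below):

import tensorflow as tf
from dopamine.agents.dqn import dqn_agent

with tf.Session() as sess:
  # Before this commit, observation_shape=84 was silently expanded to (84, 84).
  # It must now be an explicit tuple, and the observation dtype is threaded
  # through to the replay buffer as well.
  agent = dqn_agent.DQNAgent(
      sess,
      num_actions=4,
      observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,  # (84, 84)
      observation_dtype=dqn_agent.NATURE_DQN_DTYPE,              # tf.uint8
      stack_size=dqn_agent.NATURE_DQN_STACK_SIZE)                # 4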

File tree

8 files changed: +153 / -102 lines changed

dopamine/agents/dqn/dqn_agent.py

Lines changed: 15 additions & 15 deletions
@@ -34,7 +34,7 @@
 slim = tf.contrib.slim


-NATURE_DQN_OBSERVATION_SHAPE = 84  # Size of a downscaled Atari 2600 frame.
+NATURE_DQN_OBSERVATION_SHAPE = (84, 84)  # Size of downscaled Atari 2600 frame.
 NATURE_DQN_DTYPE = tf.uint8  # DType of Atari 2600 observations.
 NATURE_DQN_STACK_SIZE = 4  # Number of frames in the state stack.

@@ -98,8 +98,7 @@ def __init__(self,
     Args:
       sess: `tf.Session`, for executing ops.
       num_actions: int, number of actions the agent can take at any state.
-      observation_shape: tuple of ints or an int. If single int, the observation
-        is assumed to be a 2D square.
+      observation_shape: tuple of ints describing the observation shape.
       observation_dtype: tf.DType, specifies the type of the observations. Note
         that if your inputs are continuous, you should set this to tf.float32.
       stack_size: int, number of frames to use in state stack.
@@ -128,7 +127,7 @@ def __init__(self,
       summary_writing_frequency: int, frequency with which summaries will be
         written. Lower values will result in slower training.
     """
-
+    assert isinstance(observation_shape, tuple)
     tf.logging.info('Creating %s agent with the following parameters:',
                     self.__class__.__name__)
     tf.logging.info('\t gamma: %f', gamma)
@@ -144,11 +143,8 @@ def __init__(self,
     tf.logging.info('\t optimizer: %s', optimizer)

     self.num_actions = num_actions
-    if (isinstance(observation_shape, tuple) or
-        isinstance(observation_shape, list)):
-      self.observation_shape = tuple(observation_shape)
-    else:
-      self.observation_shape = (observation_shape, observation_shape)
+    self.observation_shape = tuple(observation_shape)
+    self.observation_dtype = observation_dtype
     self.stack_size = stack_size
     self.gamma = gamma
     self.update_horizon = update_horizon
@@ -171,7 +167,7 @@ def __init__(self,
     # The last axis indicates the number of consecutive frames stacked.
     state_shape = (1,) + self.observation_shape + (stack_size,)
     self.state = np.zeros(state_shape)
-    self.state_ph = tf.placeholder(observation_dtype, state_shape,
+    self.state_ph = tf.placeholder(self.observation_dtype, state_shape,
                                    name='state_ph')
     self._replay = self._build_replay_buffer(use_staging)
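
Because the shape is now guaranteed to be a tuple, the state shape is built by plain tuple concatenation instead of assuming a square frame, which is what lets non-image observations through. An illustrative shape calculation (the (6, 1) observation shape is only an example, mirroring one of the shapes exercised in the new tests):

import numpy as np

observation_shape = (6, 1)  # e.g. a 6-element column observation
stack_size = 4
state_shape = (1,) + observation_shape + (stack_size,)  # -> (1, 6, 1, 4)
state = np.zeros(state_shape)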

@@ -260,7 +256,8 @@ def _build_replay_buffer(self, use_staging):
         stack_size=self.stack_size,
         use_staging=use_staging,
         update_horizon=self.update_horizon,
-        gamma=self.gamma)
+        gamma=self.gamma,
+        observation_dtype=self.observation_dtype.as_numpy_dtype)

   def _build_target_q_op(self):
     """Build an op used as a target for the Q-value.
@@ -428,11 +425,14 @@ def _record_observation(self, observation):
     Args:
      observation: numpy array, an observation from the environment.
     """
-    # Set current observation. Represents an 84 x 84 x 1 image frame.
-    self._observation = observation[:, :, 0]
+    # Set current observation. We do the reshaping to handle environments
+    # without frame stacking.
+    observation = np.reshape(observation, self.observation_shape)
+    self._observation = observation[..., 0]
+    self._observation = np.reshape(observation, self.observation_shape)
     # Swap out the oldest frame with the current frame.
-    self.state = np.roll(self.state, -1, axis=3)
-    self.state[0, :, :, -1] = self._observation
+    self.state = np.roll(self.state, -1, axis=-1)
+    self.state[0, ..., -1] = self._observation

   def _store_transition(self, last_observation, action, reward, is_terminal):
     """Stores an experienced transition.

dopamine/replay_memory/circular_replay_buffer.py

Lines changed: 5 additions & 8 deletions
@@ -107,8 +107,7 @@ def __init__(self,
     """Initializes OutOfGraphReplayBuffer.

     Args:
-      observation_shape: tuple or int. If int, the observation is
-        assumed to be a 2D square.
+      observation_shape: tuple of ints.
       stack_size: int, number of frames to use in state stack.
       replay_capacity: int, number of transitions to keep in memory.
       batch_size: int.
@@ -125,6 +124,7 @@ def __init__(self,
       ValueError: If replay_capacity is too small to hold at least one
         transition.
     """
+    assert isinstance(observation_shape, tuple)
     if replay_capacity < update_horizon + stack_size:
       raise ValueError('There is not enough capacity to cover '
                        'update_horizon and stack_size.')
@@ -133,16 +133,14 @@ def __init__(self,
         'Creating a %s replay memory with the following parameters:',
         self.__class__.__name__)
     tf.logging.info('\t observation_shape: %s', str(observation_shape))
+    tf.logging.info('\t observation_dtype: %s', str(observation_dtype))
     tf.logging.info('\t stack_size: %d', stack_size)
     tf.logging.info('\t replay_capacity: %d', replay_capacity)
     tf.logging.info('\t batch_size: %d', batch_size)
     tf.logging.info('\t update_horizon: %d', update_horizon)
     tf.logging.info('\t gamma: %f', gamma)

-    if isinstance(observation_shape, tuple):
-      self._observation_shape = observation_shape
-    else:
-      self._observation_shape = (observation_shape, observation_shape)
+    self._observation_shape = observation_shape
     self._stack_size = stack_size
     self._state_shape = self._observation_shape + (self._stack_size,)
     self._replay_capacity = replay_capacity
@@ -663,8 +661,7 @@ def __init__(self,
     """Initializes WrappedReplayBuffer.

     Args:
-      observation_shape: tuple or int. If int, the observation is
-        assumed to be a 2D square.
+      observation_shape: tuple of ints.
       stack_size: int, number of frames to use in state stack.
       use_staging: bool, when True it would use a staging area to prefetch
         the next sampling batch.
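
With observation_dtype now forwarded by the agent (as a numpy dtype via as_numpy_dtype), the out-of-graph buffer can hold non-uint8 observations. A hedged construction sketch; the capacity and batch-size values are arbitrary, and only the tuple shape plus dtype keywords are the point of the example:

import numpy as np
from dopamine.replay_memory import circular_replay_buffer

memory = circular_replay_buffer.OutOfGraphReplayBuffer(
    observation_shape=(6, 1),        # must now be a tuple of ints
    stack_size=1,
    replay_capacity=1000,
    batch_size=32,
    observation_dtype=np.float32)    # e.g. tf.float32.as_numpy_dtype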

dopamine/replay_memory/prioritized_replay_buffer.py

Lines changed: 2 additions & 4 deletions
@@ -53,8 +53,7 @@ def __init__(self,
     """Initializes OutOfGraphPrioritizedReplayBuffer.

     Args:
-      observation_shape: tuple or int. If int, the observation is
-        assumed to be a 2D square with sides equal to observation_shape.
+      observation_shape: tuple of ints.
       stack_size: int, number of frames to use in state stack.
       replay_capacity: int, number of transitions to keep in memory.
       batch_size: int.
@@ -264,8 +263,7 @@ def __init__(self,
     """Initializes WrappedPrioritizedReplayBuffer.

     Args:
-      observation_shape: tuple or int. If int, the observation is
-        assumed to be a 2D square with sides equal to observation_shape.
+      observation_shape: tuple of ints.
       stack_size: int, number of frames to use in state stack.
       use_staging: bool, when True it would use a staging area to prefetch
         the next sampling batch.

tests/dopamine/agents/dqn/dqn_agent_test.py

Lines changed: 86 additions & 19 deletions
@@ -50,7 +50,7 @@ def setUp(self):
     self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
     self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
     self.zero_state = np.zeros(
-        [1, self.observation_shape, self.observation_shape, self.stack_size])
+        (1,) + self.observation_shape + (self.stack_size,))

   def _create_test_agent(self, sess):
     stack_size = self.stack_size
@@ -76,6 +76,9 @@ def _network_template(self, state):

     agent = MockDQNAgent(
         sess=sess,
+        observation_shape=self.observation_shape,
+        observation_dtype=self.observation_dtype,
+        stack_size=self.stack_size,
         num_actions=self.num_actions,
         min_replay_history=self.min_replay_history,
         epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
@@ -108,14 +111,12 @@ def testBeginEpisode(self):
       # We fill up the state with 9s. On calling agent.begin_episode the state
       # should be reset to all 0s.
       agent.state.fill(9)
-      first_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1])
+      first_observation = np.ones(self.observation_shape + (1,))
       self.assertEqual(agent.begin_episode(first_observation), 0)
       # When the all-1s observation is received, it will be placed at the end of
       # the state.
       expected_state = self.zero_state
-      expected_state[:, :, :, -1] = np.ones(
-          [1, self.observation_shape, self.observation_shape])
+      expected_state[:, :, :, -1] = np.ones((1,) + self.observation_shape)
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(agent._observation, first_observation[:, :, 0])
       # No training happens in eval mode.
@@ -126,13 +127,11 @@ def testBeginEpisode(self):
       # Having a low replay memory add_count will prevent any of the
       # train/prefetch/sync ops from being called.
       agent._replay.memory.add_count = 0
-      second_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1]) * 2
+      second_observation = np.ones(self.observation_shape + (1,)) * 2
       agent.begin_episode(second_observation)
       # The agent's state will be reset, so we will only be left with the all-2s
       # observation.
-      expected_state[:, :, :, -1] = np.full(
-          (1, self.observation_shape, self.observation_shape), 2)
+      expected_state[:, :, :, -1] = np.full((1,) + self.observation_shape, 2)
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(agent._observation, second_observation[:, :, 0])
       # training_steps is incremented since we set eval_mode to False.
@@ -145,8 +144,7 @@ def testStepEval(self):
     """
     with tf.Session() as sess:
       agent = self._create_test_agent(sess)
-      base_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1])
+      base_observation = np.ones(self.observation_shape + (1,))
       # This will reset state and choose a first action.
       agent.begin_episode(base_observation)
       # We mock the replay buffer to verify how the agent interacts with it.
@@ -163,12 +161,11 @@ def testStepEval(self):
         stack_pos = step - num_steps - 1
         if stack_pos >= -self.stack_size:
           expected_state[:, :, :, stack_pos] = np.full(
-              (1, self.observation_shape, self.observation_shape), step)
+              (1,) + self.observation_shape, step)
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(
           agent._last_observation,
-          np.ones([self.observation_shape, self.observation_shape]) *
-          (num_steps - 1))
+          np.ones(self.observation_shape) * (num_steps - 1))
       self.assertAllEqual(agent._observation, observation[:, :, 0])
       # No training happens in eval mode.
       self.assertEqual(agent.training_steps, 0)
@@ -183,8 +180,7 @@ def testStepTrain(self):
     with tf.Session() as sess:
       agent = self._create_test_agent(sess)
       agent.eval_mode = False
-      base_observation = np.ones(
-          [self.observation_shape, self.observation_shape, 1])
+      base_observation = np.ones(self.observation_shape + (1,))
       # We mock the replay buffer to verify how the agent interacts with it.
       agent._replay = test_utils.MockReplayBuffer()
       self.evaluate(tf.global_variables_initializer())
@@ -203,7 +199,7 @@ def testStepTrain(self):
         stack_pos = step - num_steps - 1
         if stack_pos >= -self.stack_size:
           expected_state[:, :, :, stack_pos] = np.full(
-              (1, self.observation_shape, self.observation_shape), step)
+              (1,) + self.observation_shape, step)
         self.assertEqual(agent._replay.add.call_count, step)
         mock_args, _ = agent._replay.add.call_args
         self.assertAllEqual(last_observation[:, :, 0], mock_args[0])
@@ -213,8 +209,7 @@ def testStepTrain(self):
       self.assertAllEqual(agent.state, expected_state)
       self.assertAllEqual(
           agent._last_observation,
-          np.full((self.observation_shape, self.observation_shape),
-                  num_steps - 1))
+          np.full(self.observation_shape, num_steps - 1))
       self.assertAllEqual(agent._observation, observation[:, :, 0])
       # We expect one more than num_steps because of the call to begin_episode.
       self.assertEqual(agent.training_steps, num_steps + 1)
@@ -228,6 +223,78 @@ def testStepTrain(self):
       self.assertAllEqual(1, mock_args[2])  # Reward received.
       self.assertTrue(mock_args[3])  # is_terminal

+  def testNonTupleObservationShape(self):
+    with self.assertRaises(AssertionError):
+      self.observation_shape = 84
+      with tf.Session() as sess:
+        _ = self._create_test_agent(sess)
+
+  def _testCustomShapes(self, shape, dtype, stack_size):
+    self.observation_shape = shape
+    self.observation_dtype = dtype
+    self.stack_size = stack_size
+    self.zero_state = np.zeros((1,) + shape + (stack_size,))
+    with tf.Session() as sess:
+      agent = self._create_test_agent(sess)
+      agent.eval_mode = False
+      base_observation = np.ones(self.observation_shape + (1,))
+      # We mock the replay buffer to verify how the agent interacts with it.
+      agent._replay = test_utils.MockReplayBuffer()
+      self.evaluate(tf.global_variables_initializer())
+      # This will reset state and choose a first action.
+      agent.begin_episode(base_observation)
+      observation = base_observation
+
+      expected_state = self.zero_state
+      num_steps = 10
+      for step in range(1, num_steps + 1):
+        # We make observation a multiple of step for testing purposes (to
+        # uniquely identify each observation).
+        last_observation = observation
+        observation = base_observation * step
+        self.assertEqual(agent.step(reward=1, observation=observation), 0)
+        stack_pos = step - num_steps - 1
+        if stack_pos >= -self.stack_size:
+          expected_state[..., stack_pos] = np.full(
+              (1,) + self.observation_shape, step)
+        self.assertEqual(agent._replay.add.call_count, step)
+        mock_args, _ = agent._replay.add.call_args
+        self.assertAllEqual(last_observation[..., 0], mock_args[0])
+        self.assertAllEqual(0, mock_args[1])  # Action selected.
+        self.assertAllEqual(1, mock_args[2])  # Reward received.
+        self.assertFalse(mock_args[3])  # is_terminal
+      self.assertAllEqual(agent.state, expected_state)
+      self.assertAllEqual(
+          agent._last_observation,
+          np.full(self.observation_shape, num_steps - 1))
+      self.assertAllEqual(agent._observation, observation[..., 0])
+      # We expect one more than num_steps because of the call to begin_episode.
+      self.assertEqual(agent.training_steps, num_steps + 1)
+      self.assertEqual(agent._replay.add.call_count, num_steps)
+
+      agent.end_episode(reward=1)
+      self.assertEqual(agent._replay.add.call_count, num_steps + 1)
+      mock_args, _ = agent._replay.add.call_args
+      self.assertAllEqual(observation[..., 0], mock_args[0])
+      self.assertAllEqual(0, mock_args[1])  # Action selected.
+      self.assertAllEqual(1, mock_args[2])  # Reward received.
+      self.assertTrue(mock_args[3])  # is_terminal
+
+  def testStepTrainCustomObservationShapes(self):
+    custom_shapes = [(1,), (4, 4), (6, 1), (1, 6), (1, 1, 6), (6, 6, 6, 6)]
+    for shape in custom_shapes:
+      self._testCustomShapes(shape, tf.uint8, 1)
+
+  def testStepTrainCustomTypes(self):
+    custom_types = [tf.float32, tf.uint8, tf.int64]
+    for dtype in custom_types:
+      self._testCustomShapes((4, 4), dtype, 1)
+
+  def testStepTrainCustomStackSizes(self):
+    custom_stack_sizes = [1, 4, 8]
+    for stack_size in custom_stack_sizes:
+      self._testCustomShapes((4, 4), tf.uint8, stack_size)
+
   def testLinearlyDecayingEpsilon(self):
     """Test the functionality of the linearly_decaying_epsilon function."""
     decay_period = 100

tests/dopamine/agents/implicit_quantile/implicit_quantile_agent_test.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def setUp(self):
     self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
     self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
     self.ones_state = np.ones(
-        [1, self.observation_shape, self.observation_shape, self.stack_size])
+        (1,) + self.observation_shape + (self.stack_size,))

   def _create_test_agent(self, sess):
