Add SPR implementation to atari_100k lab #184

Open · wants to merge 3 commits into master
51 changes: 51 additions & 0 deletions dopamine/labs/atari_100k/configs/SPR.gin
@@ -0,0 +1,51 @@
# Data Regularized-Q (DrQ) from Kostrikov et al. (2020)
import dopamine.jax.agents.dqn.dqn_agent
import dopamine.jax.networks
import dopamine.discrete_domains.gym_lib
import dopamine.discrete_domains.run_experiment
import dopamine.replay_memory.prioritized_replay_buffer
import dopamine.labs.atari_100k.spr_networks
import dopamine.labs.atari_100k.spr_agent

# Parameters specific to DrQ are highlighted by comments
JaxDQNAgent.gamma = 0.99
JaxDQNAgent.update_horizon = 10 # DrQ (instead of 3)
JaxDQNAgent.min_replay_history = 2000 # DrQ (instead of 20000)
JaxDQNAgent.update_period = 1 # DrQ (rather than 4)
JaxDQNAgent.target_update_period = 1 # DrQ (rather than 8000)
JaxDQNAgent.epsilon_train = 0.00
JaxDQNAgent.epsilon_eval = 0.001
JaxDQNAgent.epsilon_decay_period = 2001 # DrQ
JaxDQNAgent.optimizer = 'adam'

SPRAgent.noisy = True
SPRAgent.dueling = True
SPRAgent.double_dqn = True
SPRAgent.distributional = True
SPRAgent.num_atoms = 51
SPRAgent.log_every = 100
SPRAgent.num_updates_per_train_step = 2
SPRAgent.spr_weight = 5
SPRAgent.jumps = 5
SPRAgent.data_augmentation = True
SPRAgent.replay_scheme = 'prioritized'
SPRAgent.network = @spr_networks.SPRNetwork
SPRAgent.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon

# Note these parameters are from DER (van Hasselt et al, 2019)
create_optimizer.learning_rate = 0.0001
create_optimizer.eps = 0.00015

atari_lib.create_atari_environment.game_name = 'Pong'
# Atari 100K benchmark doesn't use sticky actions.
atari_lib.create_atari_environment.sticky_actions = False
AtariPreprocessing.terminal_on_life_loss = True
Runner.num_iterations = 1
Runner.training_steps = 100000 # agent steps
MaxEpisodeEvalRunner.num_eval_episodes = 100 # agent episodes
Runner.max_steps_per_episode = 27000 # agent steps

DeterministicOutOfGraphPrioritizedTemporalReplayBuffer.replay_capacity = 200000
DeterministicOutOfGraphPrioritizedTemporalReplayBuffer.batch_size = 32
DeterministicOutOfGraphTemporalReplayBuffer.replay_capacity = 200000
DeterministicOutOfGraphTemporalReplayBuffer.batch_size = 32
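
For orientation, a hedged sketch of how a config like this is typically consumed in Dopamine: parse the gin file, then hand control to a runner. The output directory and, in particular, the assumption that the atari_100k lab wires SPRAgent into the runner's agent-creation hook are not shown in this diff; the lab's own train script may do this differently.

# Hedged sketch: how a gin config like SPR.gin is typically loaded and run.
# Assumes the atari_100k lab registers SPRAgent with the runner's
# agent-creation hook; the exact wiring is not part of this diff.
import gin
from dopamine.discrete_domains import run_experiment

gin.parse_config_files_and_bindings(
    ['dopamine/labs/atari_100k/configs/SPR.gin'],  # config added in this PR
    bindings=[])

# '/tmp/spr_pong' is an arbitrary base directory for logs and checkpoints.
runner = run_experiment.create_runner('/tmp/spr_pong')
runner.run_experiment()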
Empty file.
@@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2021 The Atari 100k Precipice Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A sum tree data structure that uses JAX for controlling randomness."""

from dopamine.replay_memory import sum_tree
import jax
from jax import numpy as jnp
import numpy as np
import time
import functools


@jax.jit
def step(i, args):
  """Performs one downward step of a sum-tree traversal for a query value."""
  query_value, index, nodes = args
  left_child = index * 2 + 1
  left_sum = nodes[left_child]
  # Descend into the left child if the query falls within its mass;
  # otherwise descend right and subtract the left subtree's mass.
  index = jax.lax.cond(query_value < left_sum, lambda x: x, lambda x: x + 1,
                       left_child)
  query_value = jax.lax.cond(query_value < left_sum, lambda x: x,
                             lambda x: x - left_sum, query_value)
  return query_value, index, nodes


@jax.jit
@functools.partial(jax.vmap, in_axes=(None, None, 0, None, None))
def parallel_stratified_sample(rng, nodes, i, n, depth):
  """Samples a tree index for stratum i of n, using a per-stratum PRNG key."""
  rng = jax.random.fold_in(rng, i)
  total_priority = nodes[0]
  upper_bound = (i + 1) / n
  lower_bound = i / n
  query = jax.random.uniform(rng, minval=lower_bound, maxval=upper_bound)
  _, index, _ = jax.lax.fori_loop(0, depth, step,
                                  (query * total_priority, 0, nodes))
  return index


class DeterministicSumTree(sum_tree.SumTree):
  """A sum tree data structure for storing replay priorities.

  In contrast to the original implementation, this uses JAX for handling
  randomness, which allows us to reproduce the same results when using the
  same seed.
  """

  def __init__(self, capacity):
    """Creates the sum tree data structure for the given replay capacity.

    Args:
      capacity: int, the maximum number of elements that can be stored in this
        data structure.

    Raises:
      ValueError: If requested capacity is not positive.
    """
    assert isinstance(capacity, int)
    if capacity <= 0:
      raise ValueError(
          'Sum tree capacity should be positive. Got: {}'.format(capacity))

    self.depth = int(np.ceil(np.log2(capacity)))
    self.low_idx = (2**self.depth) - 1  # pri_idx + low_idx -> tree_idx
    self.high_idx = capacity + self.low_idx
    self.nodes = np.zeros(2**(self.depth + 1) - 1)  # Double precision.

    self.max_recorded_priority = 1.0

  def _total_priority(self):
    """Returns the sum of all priorities stored in this sum tree.

    Returns:
      float, sum of priorities stored in this sum tree.
    """
    return self.nodes[0]

  def sample(self, rng, query_value=None):
    """Samples an element from the sum tree.

    This function is designed to be jitted, so it does not have the same
    checks as the original.
    """
    # Sample a value in range [0, R), where R is the value stored at the root.
    nodes = jnp.array(self.nodes)
    query_value = (
        jax.random.uniform(rng) if query_value is None else query_value)
    query_value *= self._total_priority()

    # Now traverse the sum tree.
    _, index, _ = jax.lax.fori_loop(0, self.depth, step,
                                    (query_value, 0, nodes))
    return index - self.low_idx

  def stratified_sample(self, batch_size, rng):
    """Performs stratified sampling using the sum tree."""
    if self._total_priority() == 0.0:
      raise Exception('Cannot sample from an empty sum tree.')
    indices = parallel_stratified_sample(rng, self.nodes,
                                         jnp.arange(batch_size), batch_size,
                                         self.depth)
    return indices - self.low_idx

  def get(self, node_index):
    """Returns the value of the leaf node corresponding to the index.

    Args:
      node_index: The index of the leaf node.

    Returns:
      The value of the leaf node.
    """
    return self.nodes[node_index + self.low_idx]

  def set(self, node_index, value):
    """Sets the value of a leaf node and updates internal nodes accordingly.

    This operation takes O(log(capacity)).

    Args:
      node_index: int, the index of the leaf node to be updated.
      value: float, the value which we assign to the node. This value must be
        nonnegative. Setting value = 0 will cause the element to never be
        sampled.

    Raises:
      ValueError: If the given value is negative.
    """
    if value < 0.0:
      raise ValueError(
          'Sum tree values should be nonnegative. Got {}'.format(value))
    node_index = node_index + self.low_idx
    self.max_recorded_priority = max(value, self.max_recorded_priority)

    delta_value = value - self.nodes[node_index]

    # Now traverse back the tree, adjusting all sums along the way.
    for _ in reversed(range(self.depth)):
      # Note: Adding a delta leads to some tolerable numerical inaccuracies.
      self.nodes[node_index] += delta_value
      node_index = (node_index - 1) // 2

    self.nodes[node_index] += delta_value
    assert node_index == 0, ('Sum tree traversal failed, final node index '
                             'is not 0.')
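
As a quick illustration of the reproducibility this class is after, a hypothetical usage sketch (the capacity, priorities, and batch size are assumptions, not values from this PR): with the same JAX PRNG key, stratified sampling returns the same batch of indices.

# Hypothetical usage sketch of DeterministicSumTree.
import jax

tree = DeterministicSumTree(capacity=8)
for i in range(8):
  tree.set(i, float(i + 1))  # assign priorities 1..8 to the 8 leaves

rng = jax.random.PRNGKey(0)
batch_a = tree.stratified_sample(4, rng)
batch_b = tree.stratified_sample(4, rng)
# batch_a and batch_b are identical: sampling is driven entirely by `rng`,
# not by global numpy random state, so a fixed seed reproduces the batch.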