
Commit 8f3fb8e

Add PER
1 parent e1c576f commit 8f3fb8e

File tree

18 files changed: +337 -35 lines changed


examples/torch/dqn_atari.py

Lines changed: 12 additions & 4 deletions

@@ -23,7 +23,7 @@
 from garage.envs.wrappers.stack_frames import StackFrames
 from garage.experiment.deterministic import set_seed
 from garage.np.exploration_policies import EpsilonGreedyPolicy
-from garage.replay_buffer import PathBuffer
+from garage.replay_buffer import PERReplayBuffer
 from garage.sampler import FragmentWorker, LocalSampler
 from garage.torch import set_gpu_mode
 from garage.torch.algos import DQN
@@ -40,6 +40,9 @@
                    n_train_steps=125,
                    target_update_freq=2,
                    buffer_batch_size=32,
+                   double_q=True,
+                   per_beta_init=0.4,
+                   per_alpha=0.6,
                    max_epsilon=1.0,
                    min_epsilon=0.01,
                    decay_ratio=0.1,
@@ -104,7 +107,7 @@ def main(env=None,


 # pylint: disable=unused-argument
-@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30)
+@wrap_experiment(snapshot_mode='none')
 def dqn_atari(ctxt=None,
               env=None,
               seed=24,
@@ -150,8 +153,12 @@ def dqn_atari(ctxt=None,
     steps_per_epoch = hyperparams['steps_per_epoch']
     sampler_batch_size = hyperparams['sampler_batch_size']
     num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
-    replay_buffer = PathBuffer(
-        capacity_in_transitions=hyperparams['buffer_size'])
+
+    replay_buffer = PERReplayBuffer(hyperparams['buffer_size'],
+                                    num_timesteps,
+                                    env.spec,
+                                    alpha=hyperparams['per_alpha'],
+                                    beta_init=hyperparams['per_beta_init'])

     qf = DiscreteCNNQFunction(
         env_spec=env.spec,
@@ -179,6 +186,7 @@ def dqn_atari(ctxt=None,
                replay_buffer=replay_buffer,
                steps_per_epoch=steps_per_epoch,
                qf_lr=hyperparams['lr'],
+               double_q=hyperparams['double_q'],
                clip_gradient=hyperparams['clip_gradient'],
                discount=hyperparams['discount'],
                min_buffer_size=hyperparams['min_buffer_size'],
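
For orientation, the constructor contract changes here: PathBuffer only needs a capacity, while PERReplayBuffer also needs the total number of timesteps (used to anneal beta) and the environment spec. A minimal sketch, assuming `env` is any garage-wrapped environment; the concrete numbers below are illustrative and not part of this commit:

    from garage.replay_buffer import PERReplayBuffer

    # 100k-transition buffer, beta annealed over 500k sampled timesteps,
    # with alpha=0.6 and an initial beta of 0.4 as in the example above.
    replay_buffer = PERReplayBuffer(int(1e5),
                                    int(5e5),
                                    env.spec,
                                    alpha=0.6,
                                    beta_init=0.4)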

src/garage/envs/gym_env.py

Lines changed: 0 additions & 1 deletion

@@ -13,7 +13,6 @@
 # entry points don't close their viewer windows.
 KNOWN_GYM_NOT_CLOSE_VIEWER = [
     # Please keep alphabetized
-    'gym.envs.atari',
     'gym.envs.box2d',
     'gym.envs.classic_control'
 ]

src/garage/replay_buffer/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -4,6 +4,7 @@
 """
 from garage.replay_buffer.her_replay_buffer import HERReplayBuffer
 from garage.replay_buffer.path_buffer import PathBuffer
+from garage.replay_buffer.per_replay_buffer import PERReplayBuffer
 from garage.replay_buffer.replay_buffer import ReplayBuffer

-__all__ = ['ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
+__all__ = ['PERReplayBuffer', 'ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']

src/garage/replay_buffer/path_buffer.py

Lines changed: 11 additions & 3 deletions

@@ -119,10 +119,15 @@ def sample_transitions(self, batch_size):

         Returns:
             dict: A dict of arrays of shape (batch_size, flat_dim).
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.

         """
         idx = np.random.randint(self._transitions_stored, size=batch_size)
-        return {key: buf_arr[idx] for key, buf_arr in self._buffer.items()}
+        w = np.ones(batch_size)
+        data = {key: buf_arr[idx] for key, buf_arr in self._buffer.items()}
+        return data, w, idx

     def sample_timesteps(self, batch_size):
         """Sample a batch of timesteps from the buffer.
@@ -132,9 +137,12 @@ def sample_timesteps(self, batch_size):

         Returns:
             TimeStepBatch: The batch of timesteps.
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.

         """
-        samples = self.sample_transitions(batch_size)
+        samples, w, idx = self.sample_transitions(batch_size)
         step_types = np.array([
             StepType.TERMINAL if terminal else StepType.MID
             for terminal in samples['terminals'].reshape(-1)
@@ -147,7 +155,7 @@ def sample_timesteps(self, batch_size):
                              next_observations=samples['next_observations'],
                              step_types=step_types,
                              env_infos={},
-                             agent_infos={})
+                             agent_infos={}), w, idx

     def _next_path_segments(self, n_indices):
         """Compute where the next path should be stored.

src/garage/replay_buffer/per_replay_buffer.py

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+"""Prioritized Experience Replay."""
+
+import numpy as np
+
+from garage import StepType, TimeStepBatch
+from garage.replay_buffer.path_buffer import PathBuffer
+
+
+class PERReplayBuffer(PathBuffer):
+    """Replay buffer for PER (Prioritized Experience Replay).
+
+    PER assigns priorities to transitions in the buffer. Typically
+    the priority of each transition is proportional to the corresponding
+    loss computed at each update step. The priorities are then used to create
+    a probability distribution when sampling such that higher priority
+    transitions are sampled more frequently. For more see
+    https://arxiv.org/abs/1511.05952.
+
+    Args:
+        capacity_in_transitions (int): Total size of transitions in the
+            buffer.
+        env_spec (EnvSpec): Environment specification.
+        total_timesteps (int): Total timesteps the experiment will run for.
+            This is used to calculate the beta parameter when sampling.
+        alpha (float): Hyperparameter that controls the degree of
+            prioritization. Typically between [0, 1], where 0 corresponds to
+            no prioritization (uniform sampling).
+        beta_init (float): Initial value of the beta exponent in importance
+            sampling. Beta is linearly annealed from beta_init to 1
+            over total_timesteps.
+    """
+
+    def __init__(self,
+                 capacity_in_transitions,
+                 total_timesteps,
+                 env_spec,
+                 alpha=0.6,
+                 beta_init=0.5):
+        self._alpha = alpha
+        self._beta_init = beta_init
+        self._total_timesteps = total_timesteps
+        self._curr_timestep = 0
+        self._priorities = np.zeros((capacity_in_transitions, ), np.float32)
+        self._rng = np.random.default_rng()
+        super().__init__(capacity_in_transitions, env_spec)
+
+    def sample_timesteps(self, batch_size):
+        """Sample a batch of timesteps from the buffer.
+
+        Args:
+            batch_size (int): Number of timesteps to sample.
+
+        Returns:
+            TimeStepBatch: The batch of timesteps.
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.
+
+        """
+        samples, w, idx = self.sample_transitions(batch_size)
+        step_types = np.array([
+            StepType.TERMINAL if terminal else StepType.MID
+            for terminal in samples['terminals'].reshape(-1)
+        ],
+                              dtype=StepType)
+        return TimeStepBatch(env_spec=self._env_spec,
+                             observations=samples['observations'],
+                             actions=samples['actions'],
+                             rewards=samples['rewards'],
+                             next_observations=samples['next_observations'],
+                             step_types=step_types,
+                             env_infos={},
+                             agent_infos={}), w, idx
+
+    def sample_transitions(self, batch_size):
+        """Sample a batch of transitions from the buffer.
+
+        Args:
+            batch_size (int): Number of transitions to sample.
+
+        Returns:
+            dict: A dict of arrays of shape (batch_size, flat_dim).
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.
+
+        """
+        priorities = self._priorities
+        if self._transitions_stored < self._capacity:
+            priorities = self._priorities[:self._transitions_stored]
+        probs = priorities**self._alpha
+        probs /= probs.sum()
+        idx = self._rng.choice(self._transitions_stored,
+                               size=batch_size,
+                               p=probs)
+
+        beta = self._beta_init + self._curr_timestep * (
+            1.0 - self._beta_init) / self._total_timesteps
+        beta = min(1.0, beta)
+        transitions = {
+            key: buf_arr[idx]
+            for key, buf_arr in self._buffer.items()
+        }
+
+        w = (self._transitions_stored * probs[idx])**(-beta)
+        w /= w.max()
+        w = np.array(w)
+
+        return transitions, w, idx
+
+    def update_priorities(self, indices, priorities):
+        """Update priorities of timesteps.
+
+        Args:
+            indices (np.ndarray): Array of indices corresponding to the
+                timesteps/priorities to update.
+            priorities (list[float]): New priorities to set.
+
+        """
+        for idx, priority in zip(indices, priorities):
+            self._priorities[int(idx)] = priority
+
+    def add_path(self, path):
+        """Add a path to the buffer.
+
+        This differs from the underlying buffer's add_path method
+        in that the priorities for the new samples are set to
+        the maximum of all priorities in the buffer.
+
+        Args:
+            path (dict): A dict of arrays of shape (path_len, flat_dim).
+
+        """
+        path_len = len(path['observations'])
+        self._curr_timestep += path_len
+
+        # find the indices where the path will be stored
+        first_seg, second_seg = self._next_path_segments(path_len)
+
+        # set priorities for new timesteps = max(self._priorities)
+        # or 1 if the buffer is empty
+        max_priority = self._priorities.max() or 1.
+        self._priorities[first_seg.start:first_seg.stop] = max_priority
+        if second_seg != range(0, 0):
+            self._priorities[second_seg.start:second_seg.stop] = max_priority
+        super().add_path(path)
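
To make the sampling and weighting logic above concrete: the buffer samples index i with probability P(i) = p_i**alpha / sum_k p_k**alpha and corrects for the induced bias with importance-sampling weights w_i = (N * P(i))**(-beta), normalized by the largest weight, as in the PER paper linked in the docstring. A small self-contained numpy illustration with made-up priorities (not part of the commit):

    import numpy as np

    priorities = np.array([1.0, 1.0, 4.0, 0.5])  # hypothetical per-transition priorities
    alpha, beta = 0.6, 0.4
    n = len(priorities)

    probs = priorities**alpha
    probs /= probs.sum()        # sampling distribution P(i)

    w = (n * probs)**(-beta)    # importance-sampling correction
    w /= w.max()                # normalize so the largest weight is 1

    print(probs.round(3), w.round(3))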

src/garage/replay_buffer/replay_buffer.py

Lines changed: 3 additions & 0 deletions

@@ -55,6 +55,9 @@ def sample(self, batch_size):

         Args:
             batch_size(int): The number of transitions to be sampled.
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.

         """
         raise NotImplementedError

src/garage/tf/algos/ddpg.py

Lines changed: 1 addition & 1 deletion

@@ -350,7 +350,7 @@ def _optimize_policy(self):
             float: Q value predicted by the q network.

         """
-        timesteps = self._replay_buffer.sample_timesteps(
+        timesteps, _, _ = self._replay_buffer.sample_timesteps(
             self._buffer_batch_size)

         observations = timesteps.observations

src/garage/tf/algos/dqn.py

Lines changed: 1 addition & 1 deletion

@@ -258,7 +258,7 @@ def _optimize_policy(self):
             numpy.float64: Loss of policy.

         """
-        timesteps = self._replay_buffer.sample_timesteps(
+        timesteps, _, _ = self._replay_buffer.sample_timesteps(
             self._buffer_batch_size)

         observations = timesteps.observations

src/garage/tf/algos/td3.py

Lines changed: 1 addition & 1 deletion

@@ -371,7 +371,7 @@ def _optimize_policy(self, itr):
             float: Q value predicted by the q network.

         """
-        timesteps = self._replay_buffer.sample_timesteps(
+        timesteps, _, _ = self._replay_buffer.sample_timesteps(
            self._buffer_batch_size)

         observations = timesteps.observations

src/garage/torch/algos/ddpg.py

Lines changed: 2 additions & 4 deletions

@@ -6,9 +6,7 @@
 import numpy as np
 import torch

-from garage import (_Default,
-                    log_performance,
-                    make_optimizer,
+from garage import (_Default, log_performance, make_optimizer,
                     obtain_evaluation_episodes)
 from garage.np.algos import RLAlgorithm
 from garage.sampler import FragmentWorker, LocalSampler
@@ -188,7 +186,7 @@ def train_once(self, itr, episodes):
         for _ in range(self._n_train_steps):
             if (self.replay_buffer.n_transitions_stored >=
                     self._min_buffer_size):
-                samples = self.replay_buffer.sample_transitions(
+                samples, _, _ = self.replay_buffer.sample_transitions(
                     self._buffer_batch_size)
                 samples['rewards'] *= self._reward_scale
                 qf_loss, y, q, policy_loss = torch_to_np(

src/garage/torch/algos/dqn.py

Lines changed: 25 additions & 7 deletions

@@ -10,8 +10,9 @@
 from garage import _Default, log_performance, make_optimizer
 from garage._functions import obtain_evaluation_episodes
 from garage.np.algos import RLAlgorithm
+from garage.replay_buffer import PERReplayBuffer
 from garage.sampler import FragmentWorker
-from garage.torch import global_device, np_to_torch
+from garage.torch import global_device, np_to_torch, torch_to_np


 class DQN(RLAlgorithm):
@@ -122,6 +123,9 @@ def __init__(
         self._qf_optimizer = make_optimizer(qf_optimizer,
                                             module=self._qf,
                                             lr=qf_lr)
+
+        self._prioritized_replay = isinstance(self.replay_buffer,
+                                              PERReplayBuffer)
         self._eval_env = eval_env

     def train(self, trainer):
@@ -192,10 +196,12 @@ def _train_once(self, itr, episodes):
         for _ in range(self._n_train_steps):
             if (self.replay_buffer.n_transitions_stored >=
                     self._min_buffer_size):
-                timesteps = self.replay_buffer.sample_timesteps(
-                    self._buffer_batch_size)
-                qf_loss, y, q = tuple(v.cpu().numpy()
-                                      for v in self._optimize_qf(timesteps))
+                timesteps, weights, indices = (
+                    self.replay_buffer.sample_timesteps(
+                        self._buffer_batch_size))
+                qf_loss, y, q = tuple(
+                    v.cpu().numpy()
+                    for v in self._optimize_qf(timesteps, weights, indices))

                 self._episode_qf_losses.append(qf_loss)
                 self._epoch_ys.append(y)
@@ -228,11 +234,15 @@ def _log_eval_results(self, epoch):
         tabular.record('QFunction/AverageAbsY',
                        np.mean(np.abs(self._epoch_ys)))

-    def _optimize_qf(self, timesteps):
+    def _optimize_qf(self, timesteps, weights=None, indices=None):
         """Perform algorithm optimizing.

         Args:
             timesteps (TimeStepBatch): Processed batch data.
+            weights (np.ndarray[float]): Weights used by PER when updating
+                the network.
+            indices (list[int or float]): Indices of the sampled
+                timesteps in the replay buffer.

         Returns:
             qval_loss: Loss of Q-value predicted by the Q-network.
@@ -274,7 +284,15 @@ def _optimize_qf(self, timesteps):
         # optimize qf
         qvals = self._qf(inputs)
         selected_qs = torch.sum(qvals * actions, axis=1)
-        qval_loss = F.smooth_l1_loss(selected_qs, y_target)
+        qval_loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none')
+
+        if self._prioritized_replay:
+            qval_loss *= np_to_torch(weights)
+            priorities = qval_loss + 1e-5  # offset to avoid 0 priorities
+            priorities = torch_to_np(priorities.data.cpu())
+            self.replay_buffer.update_priorities(indices, priorities)
+
+        qval_loss = qval_loss.mean()

         self._qf_optimizer.zero_grad()
         qval_loss.backward()
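
The core of the PER integration is in the last hunk: the Huber loss is kept per-sample, scaled by the importance-sampling weights, and the resulting per-sample values (plus a small offset) are pushed back into the buffer as new priorities before the loss is reduced to a mean. A standalone sketch of that pattern with hypothetical tensors, outside the DQN class:

    import torch
    import torch.nn.functional as F

    selected_qs = torch.randn(32)   # stand-ins for Q(s, a) and the TD targets
    y_target = torch.randn(32)
    weights = torch.rand(32)        # importance-sampling weights from the buffer

    loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none')
    loss = loss * weights           # down-weight over-sampled transitions
    new_priorities = (loss + 1e-5).detach().cpu().numpy()  # offset avoids zero priority
    # replay_buffer.update_priorities(indices, new_priorities)  # as in the commit
    loss = loss.mean()              # reduce only after priorities are recorded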
