Skip to content

Commit 67901fd

Browse files
authored
Merge pull request #57 from muupan/pcl-prioritized-replay
Use Prioritized Replay for PCL
2 parents f4cff28 + 5c83182 commit 67901fd

File tree

6 files changed

+233
-65
lines changed

6 files changed

+233
-65
lines changed

chainerrl/agents/pcl.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,11 @@ def update_from_replay(self):
270270

271271
episodes = self.replay_buffer.sample_episodes(
272272
self.batchsize, max_len=self.t_max)
273+
if isinstance(episodes, tuple):
274+
# Prioritized replay
275+
episodes, weights = episodes
276+
else:
277+
weights = [1] * len(episodes)
273278
sorted_episodes = list(reversed(sorted(episodes, key=len)))
274279
max_epi_len = len(sorted_episodes[0])
275280

@@ -326,7 +331,8 @@ def update_from_replay(self):
326331
values=e_values,
327332
next_values=e_next_values,
328333
log_probs=e_log_probs))
329-
loss = chainerrl.functions.sum_arrays(losses) / self.batchsize
334+
loss = chainerrl.functions.weighted_sum_arrays(
335+
losses, weights) / self.batchsize
330336
self.update(loss)
331337

332338
def update_on_policy(self, statevar):

chainerrl/misc/prioritized.py

Lines changed: 91 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,46 @@
1+
from __future__ import unicode_literals
2+
from __future__ import print_function
3+
from __future__ import division
4+
from __future__ import absolute_import
5+
from builtins import * # NOQA
6+
from future import standard_library
7+
standard_library.install_aliases()
18
import random
29

10+
import numpy as np
11+
312

413
class PrioritizedBuffer (object):
5-
def __init__(self, capacity=None):
14+
15+
def __init__(self, capacity=None, wait_priority_after_sampling=True):
616
self.capacity = capacity
717
self.data = []
818
self.priority_tree = SumTree()
919
self.data_inf = []
20+
self.wait_priority_after_sampling = wait_priority_after_sampling
1021
self.flag_wait_priority = False
1122

1223
def __len__(self):
1324
return len(self.data) + len(self.data_inf)
1425

15-
def append(self, value):
16-
# new values are the most prioritized
17-
self.data_inf.append(value)
18-
if self.capacity is not None and len(self) > self.capacity:
26+
def append(self, value, priority=None):
27+
if self.capacity is not None and len(self) == self.capacity:
1928
self.pop()
29+
if priority is not None:
30+
# Append with a given priority
31+
i = len(self.data)
32+
self.data.append(value)
33+
self.priority_tree[i] = priority
34+
else:
35+
# Append with the highest priority
36+
self.data_inf.append(value)
2037

2138
def _pop_random_data_inf(self):
2239
assert self.data_inf
2340
n = len(self.data_inf)
2441
i = random.randrange(n)
2542
ret = self.data_inf[i]
26-
self.data_inf[i] = self.data_inf[n-1]
43+
self.data_inf[i] = self.data_inf[n - 1]
2744
self.data_inf.pop()
2845
return ret
2946

@@ -33,47 +50,96 @@ def pop(self):
3350
Not prioritized.
3451
"""
3552
assert len(self) > 0
36-
assert not self.flag_wait_priority
53+
assert (not self.wait_priority_after_sampling or
54+
not self.flag_wait_priority)
3755
n = len(self.data)
3856
if n == 0:
3957
return self._pop_random_data_inf()
4058
i = random.randrange(0, n)
4159
# remove i-th
42-
self.priority_tree[i] = self.priority_tree[n-1]
43-
del self.priority_tree[n-1]
60+
self.priority_tree[i] = self.priority_tree[n - 1]
61+
del self.priority_tree[n - 1]
4462
ret = self.data[i]
45-
self.data[i] = self.data[n-1]
46-
del self.data[n-1]
63+
self.data[i] = self.data[n - 1]
64+
del self.data[n - 1]
4765
return ret
4866

49-
def sample(self, n):
50-
"""Sample n distinct elements"""
67+
def _prioritized_sample_indices_and_probabilities(self, n):
5168
assert 0 <= n <= len(self)
52-
assert not self.flag_wait_priority
5369
indices, probabilities = self.priority_tree.prioritized_sample(
54-
max(0, n - len(self.data_inf)), remove=True)
55-
sampled = []
56-
for i in indices:
57-
sampled.append(self.data[i])
58-
while len(sampled) < n and len(self.data_inf) > 0:
70+
max(0, n - len(self.data_inf)),
71+
remove=self.wait_priority_after_sampling)
72+
while len(indices) < n:
5973
i = len(self.data)
6074
e = self._pop_random_data_inf()
6175
self.data.append(e)
6276
del self.priority_tree[i]
6377
indices.append(i)
6478
probabilities.append(None)
65-
sampled.append(self.data[i])
79+
return indices, probabilities
80+
81+
def _sample_indices_and_probabilities(self, n, uniform_ratio):
82+
if uniform_ratio > 0:
83+
# Mix uniform samples and prioritized samples
84+
n_uniform = np.random.binomial(n, uniform_ratio)
85+
n_prioritized = n - n_uniform
86+
pr_indices, pr_probs = \
87+
self._prioritized_sample_indices_and_probabilities(
88+
n_prioritized)
89+
un_indices, un_probs = \
90+
self._uniform_sample_indices_and_probabilities(
91+
n_uniform)
92+
indices = pr_indices + un_indices
93+
# Note: when uniform samples and prioritized samples are mixed,
94+
# resulting probabilities are not the true probabilities for each
95+
# entry to be sampled.
96+
probabilities = pr_probs + un_probs
97+
return indices, probabilities
98+
else:
99+
# Only prioritized samples
100+
return self._prioritized_sample_indices_and_probabilities(n)
101+
102+
def sample(self, n, uniform_ratio=0):
103+
"""Sample data along with their corresponding probabilities.
104+
105+
Args:
106+
n (int): Number of data to sample.
107+
uniform_ratio (float): Ratio of uniformly sampled data.
108+
Returns:
109+
sampled data (list)
110+
probabilities (list)
111+
"""
112+
assert (not self.wait_priority_after_sampling or
113+
not self.flag_wait_priority)
114+
indices, probabilities = self._sample_indices_and_probabilities(
115+
n, uniform_ratio=uniform_ratio)
116+
sampled = [self.data[i] for i in indices]
66117
self.sampled_indices = indices
67118
self.flag_wait_priority = True
68119
return sampled, probabilities
69120

70121
def set_last_priority(self, priority):
71-
assert self.flag_wait_priority
122+
assert (not self.wait_priority_after_sampling or
123+
self.flag_wait_priority)
72124
assert all([p > 0.0 for p in priority])
73125
assert len(self.sampled_indices) == len(priority)
74126
for i, p in zip(self.sampled_indices, priority):
75127
self.priority_tree[i] = p
76128
self.flag_wait_priority = False
129+
self.sampled_indices = []
130+
131+
def _uniform_sample_indices_and_probabilities(self, n):
132+
indices = random.sample(range(len(self.data)),
133+
max(0, n - len(self.data_inf)))
134+
probabilities = [1 / len(self)] * len(indices)
135+
while len(indices) < n:
136+
i = len(self.data)
137+
e = self._pop_random_data_inf()
138+
self.data.append(e)
139+
del self.priority_tree[i]
140+
indices.append(i)
141+
probabilities.append(None)
142+
return indices, probabilities
77143

78144

79145
class SumTree (object):
@@ -119,17 +185,18 @@ def _center(self):
119185

120186
def _allocindex(self, ix):
121187
if self.bd is None:
122-
self.bd = (ix, ix+1)
188+
self.bd = (ix, ix + 1)
123189
while ix >= self.bd[1]:
124-
r_bd = (self.bd[1], self.bd[1]*2 - self.bd[0])
190+
r_bd = (self.bd[1], self.bd[1] * 2 - self.bd[0])
125191
l = SumTree(self.bd, self.l, self.r, self.s)
192+
126193
r = SumTree(bd=r_bd)._initdescendant()
127194
self.bd = (l.bd[0], r.bd[1])
128195
self.l = l
129196
self.r = r
130197
# no need to update self.s because self.r.s == 0
131198
while ix < self.bd[0]:
132-
l_bd = (self.bd[0]*2 - self.bd[1], self.bd[0])
199+
l_bd = (self.bd[0] * 2 - self.bd[1], self.bd[0])
133200
l = SumTree(bd=l_bd)._initdescendant()
134201
r = SumTree(self.bd, self.l, self.r, self.s)
135202
self.bd = (l.bd[0], r.bd[1])

chainerrl/replay_buffer.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max):
7575
assert 0.0 <= beta0 <= 1.0
7676
self.alpha = alpha
7777
self.beta = beta0
78-
self.beta_add = (1.0 - beta0) / betasteps
78+
if betasteps is None:
79+
self.beta_add = 0
80+
else:
81+
self.beta_add = (1.0 - beta0) / betasteps
7982
self.eps = eps
8083
self.normalize_by_max = normalize_by_max
8184

@@ -206,31 +209,49 @@ class PrioritizedEpisodicReplayBuffer (
206209

207210
def __init__(self, capacity=None,
208211
alpha=0.6, beta0=0.4, betasteps=2e5, eps=1e-8,
209-
normalize_by_max=True):
212+
normalize_by_max=True,
213+
default_priority_func=None,
214+
uniform_ratio=0,
215+
wait_priority_after_sampling=True,
216+
return_sample_weights=True):
210217
self.current_episode = []
211-
self.episodic_memory = PrioritizedBuffer(capacity=None)
218+
self.episodic_memory = PrioritizedBuffer(
219+
capacity=None,
220+
wait_priority_after_sampling=wait_priority_after_sampling)
212221
self.memory = deque(maxlen=capacity)
213222
self.capacity_left = capacity
223+
self.default_priority_func = default_priority_func
224+
self.uniform_ratio = uniform_ratio
225+
self.return_sample_weights = return_sample_weights
214226
PriorityWeightError.__init__(
215227
self, alpha, beta0, betasteps, eps, normalize_by_max)
216228

217229
def sample_episodes(self, n_episodes, max_len=None):
218230
"""Sample n unique samples from this replay buffer"""
219231
assert len(self.episodic_memory) >= n_episodes
220-
episodes, probabilities = self.episodic_memory.sample(n_episodes)
221-
weights = self.weights_from_probabilities(probabilities)
232+
episodes, probabilities = self.episodic_memory.sample(
233+
n_episodes, uniform_ratio=self.uniform_ratio)
222234
if max_len is not None:
223235
episodes = [random_subseq(ep, max_len) for ep in episodes]
224-
return episodes, weights
236+
if self.return_sample_weights:
237+
weights = self.weights_from_probabilities(probabilities)
238+
return episodes, weights
239+
else:
240+
return episodes
225241

226242
def update_errors(self, errors):
227243
self.episodic_memory.set_last_priority(
228244
self.priority_from_errors(errors))
229245

230246
def stop_current_episode(self):
231247
if self.current_episode:
248+
if self.default_priority_func is not None:
249+
priority = self.default_priority_func(self.current_episode)
250+
else:
251+
priority = None
232252
self.memory.extend(self.current_episode)
233-
self.episodic_memory.append(self.current_episode)
253+
self.episodic_memory.append(self.current_episode,
254+
priority=priority)
234255
if self.capacity_left is not None:
235256
self.capacity_left -= len(self.current_episode)
236257
self.current_episode = []

examples/gym/train_pcl_gym.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
from chainerrl.optimizers import rmsprop_async
3030

3131

32+
def exp_return_of_episode(episode):
33+
return np.exp(sum(x['reward'] for x in episode))
34+
35+
3236
def main():
3337
import logging
3438

@@ -58,6 +62,8 @@ def main():
5862
parser.add_argument('--logger-level', type=int, default=logging.DEBUG)
5963
parser.add_argument('--monitor', action='store_true')
6064
parser.add_argument('--train-async', action='store_true', default=False)
65+
parser.add_argument('--prioritized-replay', action='store_true',
66+
default=False)
6167
parser.add_argument('--disable-online-update', action='store_true',
6268
default=False)
6369
parser.add_argument('--backprop-future-values', action='store_true',
@@ -126,6 +132,7 @@ def make_env(process_idx, test):
126132
)
127133

128134
if not args.train_async and args.gpu >= 0:
135+
chainer.cuda.get_device(args.gpu).use()
129136
model.to_gpu(args.gpu)
130137

131138
if args.train_async:
@@ -134,7 +141,18 @@ def make_env(process_idx, test):
134141
opt = chainer.optimizers.Adam(alpha=args.lr)
135142
opt.setup(model)
136143

137-
replay_buffer = chainerrl.replay_buffer.EpisodicReplayBuffer(10 ** 5)
144+
if args.prioritized_replay:
145+
replay_buffer = \
146+
chainerrl.replay_buffer.PrioritizedEpisodicReplayBuffer(
147+
capacity=5 * 10 ** 3,
148+
uniform_ratio=0.1,
149+
default_priority_func=exp_return_of_episode,
150+
wait_priority_after_sampling=False,
151+
return_sample_weights=False)
152+
else:
153+
replay_buffer = chainerrl.replay_buffer.EpisodicReplayBuffer(
154+
capacity=5 * 10 ** 3)
155+
138156
agent = chainerrl.agents.PCL(
139157
model, opt, replay_buffer=replay_buffer,
140158
t_max=args.t_max, gamma=0.99,

0 commit comments

Comments
 (0)