Commit 3976594

[Feature] Improve performance of compiled ReplayBuffer
1 parent 6799a7f commit 3976594

6 files changed: +230 −102 lines changed
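For orientation, here is a minimal sketch of the usage pattern this commit targets: a ReplayBuffer whose storage and writer are marked `compilable`, with the hot extend/sample loop wrapped in `torch.compile`. The storage size, batch size, and data shape below are illustrative, not taken from the diff.

import torch
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer

# Sketch only: build a compilable buffer, then compile the extend+sample loop.
rb = ReplayBuffer(
    storage=LazyTensorStorage(10_000, compilable=True),
    sampler=RandomSampler(),
    batch_size=32,
    compilable=True,
)

@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample()

sample = extend_and_sample(torch.randn(1_000, 1))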

benchmarks/test_replaybuffer_benchmark.py

Lines changed: 29 additions & 12 deletions
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
     )


-class create_tensor_rb:
-    def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
+class create_compiled_tensor_rb:
+    def __init__(
+        self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
+    ):
         self.storage = storage
         self.rb = rb
         self.sampler = sampler
-        self.size = size
+        self.storage_size = storage_size
+        self.data_size = data_size
         self.iters = iters
+        self.compilable = compilable

     def __call__(self):
         kwargs = {}
         if self.sampler is not None:
             kwargs["sampler"] = self.sampler()
         if self.storage is not None:
-            kwargs["storage"] = self.storage(10 * self.size)
+            kwargs["storage"] = self.storage(
+                self.storage_size, compilable=self.compilable
+            )

-        rb = self.rb(batch_size=3, **kwargs)
-        data = torch.randn(self.size, 1)
+        rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
+        data = torch.randn(self.data_size, 1)
         return ((rb, data, self.iters), {})


@@ -210,21 +216,32 @@ def fn(td):


 @pytest.mark.parametrize(
-    "rb,storage,sampler,size,iters,compiled",
+    "rb,storage,sampler,storage_size,data_size,iters,compiled",
     [
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
     ],
 )
-def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
+def test_rb_extend_sample(
+    benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
+):
+    if compiled:
+        torch._dynamo.reset_code_caches()
+
     benchmark.pedantic(
         extend_and_sample_compiled if compiled else extend_and_sample,
-        setup=create_tensor_rb(
+        setup=create_compiled_tensor_rb(
             rb=rb,
             storage=storage,
             sampler=sampler,
-            size=size,
+            storage_size=storage_size,
+            data_size=data_size,
             iters=iters,
+            compilable=compiled,
         ),
         iterations=1,
         warmup_rounds=10,
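The parametrized test dispatches to `extend_and_sample` or `extend_and_sample_compiled`, which live elsewhere in the benchmark file and are not part of these hunks. A hedged sketch of what such helpers could look like, assuming only that they receive the `(rb, data, iters)` tuple produced by `create_compiled_tensor_rb.__call__`; the actual helpers may differ.

import torch

# Assumed shape of the benchmark helpers (not shown in this diff).
def extend_and_sample(rb, data, iters):
    for _ in range(iters):
        rb.extend(data)
        rb.sample()

def extend_and_sample_compiled(rb, data, iters):
    @torch.compile
    def fn(d):
        rb.extend(d)
        return rb.sample()

    for _ in range(iters):
        fn(data)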

test/test_rb.py

Lines changed: 31 additions & 16 deletions
@@ -178,18 +178,24 @@
 )
 @pytest.mark.parametrize("size", [3, 5, 100])
 class TestComposableBuffers:
-    def _get_rb(self, rb_type, size, sampler, writer, storage):
+    def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):

         if storage is not None:
-            storage = storage(size)
+            storage = storage(size, compilable=compilable)

         sampler_args = {}
         if sampler is samplers.PrioritizedSampler:
             sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}

         sampler = sampler(**sampler_args)
-        writer = writer()
-        rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
+        writer = writer(compilable=compilable)
+        rb = rb_type(
+            storage=storage,
+            sampler=sampler,
+            writer=writer,
+            batch_size=3,
+            compilable=compilable,
+        )
         return rb

     def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
     # <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
     # Our Windows CI jobs do not have "cl", so skip this test.
     @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    @pytest.mark.parametrize("avoid_max_size", [False, True])
     def test_extend_sample_recompile(
-        self, rb_type, sampler, writer, storage, size, datatype
+        self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
     ):
         if rb_type is not ReplayBuffer:
             pytest.skip(
@@ -443,28 +450,36 @@ def test_extend_sample_recompile(

         torch._dynamo.reset_code_caches()

-        storage_size = 10 * size
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = size
+
+        # These two cases are separated because when the max storage size is
+        # reached, the code execution path changes, causing necessary
+        # recompiles.
+        if avoid_max_size:
+            storage_size = (num_extend + 1) * data_size
+        else:
+            storage_size = 2 * data_size
+
         rb = self._get_rb(
             rb_type=rb_type,
             sampler=sampler,
             writer=writer,
             storage=storage,
             size=storage_size,
+            compilable=True,
         )
-        data_size = size
         data = self._get_data(datatype, size=data_size)

         @torch.compile
         def extend_and_sample(data):
             rb.extend(data)
             return rb.sample()

-        # Number of times to extend the replay buffer
-        num_extend = 30
-
-        # NOTE: The first two calls to 'extend' and 'sample' currently cause
-        # recompilations, so avoid capturing those for now.
-        num_extend_before_capture = 2
+        # NOTE: The first three calls to 'extend' and 'sample' can currently
+        # cause recompilations, so avoid capturing those.
+        num_extend_before_capture = 3

         for _ in range(num_extend_before_capture):
             extend_and_sample(data)
@@ -477,12 +492,12 @@ def extend_and_sample(data):
             for _ in range(num_extend - num_extend_before_capture):
                 extend_and_sample(data)

-            assert len(rb) == storage_size
-            assert len(records) == 0
-
         finally:
             torch._logging.set_logs()

+        assert len(rb) == min((num_extend * data_size), storage_size)
+        assert len(records) == 0
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
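The comment in the hunk explains why the two storage sizings are parametrized separately: once the storage reaches its maximum size, the write path wraps around, and a recompile is expected at that point. A small worked example of the sizing and the final-length assertion, with illustrative numbers in place of the test's `size` parameter.

# Illustrative numbers only; `size` comes from the test parametrization.
num_extend = 10
data_size = 5

# avoid_max_size=True: storage never fills, so a single code path is traced.
storage_size = (num_extend + 1) * data_size          # 55 slots for 50 written items
assert min(num_extend * data_size, storage_size) == 50

# avoid_max_size=False: storage wraps after the second extend, switching paths once.
storage_size = 2 * data_size                         # 10 slots for 50 written items
assert min(num_extend * data_size, storage_size) == 10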

torchrl/_utils.py

Lines changed: 21 additions & 3 deletions
@@ -296,6 +296,7 @@ def __init__(
         self.to_version = to_version
         self.class_method = class_method
         implement_for._setters.append(self)
+        self._is_supported = None

     @staticmethod
     def check_version(version: str, from_version: str | None, to_version: str | None):
@@ -304,6 +305,20 @@ def check_version(version: str, from_version: str | None, to_version: str | None
             to_version is None or version < parse(to_version)
         )

+    # If `implement_for` is used as a decorator, `torch.compile` adds guards
+    # around it. So instead, `implement_for` can be instantiated without
+    # decorating the function, and `implement_for.is_supported` can be called to
+    # explicitly switch between different implementation functions.
+    # TODO: Fix the decorator to avoid compiler guards.
+    @torch._dynamo.assume_constant_result
+    def is_supported(self):
+        if self._is_supported is None:
+            version = self.import_module(self.module_name)
+            self._is_supported = self.check_version(
+                version, self.from_version, self.to_version
+            )
+        return self._is_supported
+
     @staticmethod
     def get_class_that_defined_method(f):
         """Returns the class of a method, if it is defined, and None otherwise."""
@@ -399,6 +414,11 @@ def _lazy_call_fn(*args, **kwargs):

         return _lazy_call_fn

+    def unsupported(self, func_name):
+        raise ModuleNotFoundError(
+            f"Supported version of '{func_name}' has not been found."
+        )
+
     def _call(self):

         # If the module is missing replace the function with the mock.
@@ -408,9 +428,7 @@ def _call(self):

         @wraps(fn)
         def unsupported(*args, **kwargs):
-            raise ModuleNotFoundError(
-                f"Supported version of '{func_name}' has not been found."
-            )
+            self.unsupported(func_name)

         self.do_set = False
         # Return fitting implementation if it was encountered before.
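The comment added above describes the intended pattern: instead of decorating a function (which makes `torch.compile` install guards around the dispatch), an `implement_for` instance can be created once and queried via `is_supported()`, which dynamo treats as a constant. A minimal sketch of that pattern under stated assumptions; the version bounds, function names, and gate variable below are hypothetical, not taken from the diff.

from torchrl._utils import implement_for

# Hypothetical version gate, instantiated without decorating anything.
_gate = implement_for("torch", "2.4.0", None)

def _extend_fast(storage, data):
    ...  # implementation available on newer torch (placeholder body)

def _extend_fallback(storage, data):
    ...  # fallback for older torch (placeholder body)

def extend(storage, data):
    # is_supported() is wrapped in torch._dynamo.assume_constant_result, so this
    # branch does not add compiler guards when the caller is compiled.
    if _gate.is_supported():
        return _extend_fast(storage, data)
    return _extend_fallback(storage, data)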

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 17 additions & 3 deletions
@@ -132,6 +132,9 @@ class ReplayBuffer:
             .. warning:: As of now, the generator has no effect on the transforms.
         shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
             Defaults to ``False``.
+        compilable (bool, optional): whether the writer is compilable.
+            If ``True``, the writer cannot be shared between multiple processes.
+            Defaults to ``False``.

     Examples:
         >>> import torch
@@ -217,11 +220,20 @@ def __init__(
         checkpointer: "StorageCheckpointerBase" | None = None,  # noqa: F821
         generator: torch.Generator | None = None,
         shared: bool = False,
+        compilable: bool = None,
     ) -> None:
         self._storage = storage if storage is not None else ListStorage(max_size=1_000)
         self._storage.attach(self)
+        if compilable is not None:
+            self._storage._compilable = compilable
+            self._storage._len = self._storage._len
+
         self._sampler = sampler if sampler is not None else RandomSampler()
-        self._writer = writer if writer is not None else RoundRobinWriter()
+        self._writer = (
+            writer
+            if writer is not None
+            else RoundRobinWriter(compilable=bool(compilable))
+        )
         self._writer.register_storage(self._storage)

         self._get_collate_fn(collate_fn)
@@ -600,7 +612,9 @@ def _add(self, data):
         return index

     def _extend(self, data: Sequence) -> torch.Tensor:
-        with self._replay_lock, self._write_lock:
+        is_compiling = torch.compiler.is_dynamo_compiling()
+        nc = contextlib.nullcontext()
+        with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
             if self.dim_extend > 0:
                 data = self._transpose(data)
             index = self._writer.extend(data)
@@ -653,7 +667,7 @@ def update_priority(

     @pin_memory_output
     def _sample(self, batch_size: int) -> Tuple[Any, dict]:
-        with self._replay_lock:
+        with self._replay_lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
             index, info = self._sampler.sample(self._storage, batch_size)
             info["index"] = index
             data = self._storage.get(index)
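Both `_extend` and `_sample` now skip their mutexes while dynamo is tracing, since entering a `threading` lock inside a compiled region would force a graph break; eager calls still take the locks as before. A self-contained sketch of the same pattern, using illustrative names rather than the `ReplayBuffer` internals.

import contextlib
import threading

import torch

_lock = threading.RLock()
_items = []

def guarded_append(x):
    # Take the real lock in eager mode; substitute a no-op context manager while
    # torch.compile/dynamo is tracing this function.
    cm = _lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext()
    with cm:
        _items.append(x)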
