Skip to content

Commit 2a07f4c

Browse files
authored
[Performance] Improve performance of compiled ReplayBuffer (#2529)
1 parent fa64c2f commit 2a07f4c

File tree

6 files changed

+273
-108
lines changed

6 files changed

+273
-108
lines changed

Diff for: benchmarks/test_replaybuffer_benchmark.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
173173
)
174174

175175

176-
class create_tensor_rb:
177-
def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
176+
class create_compiled_tensor_rb:
177+
def __init__(
178+
self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
179+
):
178180
self.storage = storage
179181
self.rb = rb
180182
self.sampler = sampler
181-
self.size = size
183+
self.storage_size = storage_size
184+
self.data_size = data_size
182185
self.iters = iters
186+
self.compilable = compilable
183187

184188
def __call__(self):
185189
kwargs = {}
186190
if self.sampler is not None:
187191
kwargs["sampler"] = self.sampler()
188192
if self.storage is not None:
189-
kwargs["storage"] = self.storage(10 * self.size)
193+
kwargs["storage"] = self.storage(
194+
self.storage_size, compilable=self.compilable
195+
)
190196

191-
rb = self.rb(batch_size=3, **kwargs)
192-
data = torch.randn(self.size, 1)
197+
rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
198+
data = torch.randn(self.data_size, 1)
193199
return ((rb, data, self.iters), {})
194200

195201

@@ -210,21 +216,32 @@ def fn(td):
210216

211217

212218
@pytest.mark.parametrize(
213-
"rb,storage,sampler,size,iters,compiled",
219+
"rb,storage,sampler,storage_size,data_size,iters,compiled",
214220
[
215-
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
216-
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
221+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
222+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
223+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
224+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
225+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
226+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
217227
],
218228
)
219-
def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
229+
def test_rb_extend_sample(
230+
benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
231+
):
232+
if compiled:
233+
torch._dynamo.reset_code_caches()
234+
220235
benchmark.pedantic(
221236
extend_and_sample_compiled if compiled else extend_and_sample,
222-
setup=create_tensor_rb(
237+
setup=create_compiled_tensor_rb(
223238
rb=rb,
224239
storage=storage,
225240
sampler=sampler,
226-
size=size,
241+
storage_size=storage_size,
242+
data_size=data_size,
227243
iters=iters,
244+
compilable=compiled,
228245
),
229246
iterations=1,
230247
warmup_rounds=10,

Diff for: test/test_rb.py

+77-16
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,24 @@
178178
)
179179
@pytest.mark.parametrize("size", [3, 5, 100])
180180
class TestComposableBuffers:
181-
def _get_rb(self, rb_type, size, sampler, writer, storage):
181+
def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):
182182

183183
if storage is not None:
184-
storage = storage(size)
184+
storage = storage(size, compilable=compilable)
185185

186186
sampler_args = {}
187187
if sampler is samplers.PrioritizedSampler:
188188
sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}
189189

190190
sampler = sampler(**sampler_args)
191-
writer = writer()
192-
rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
191+
writer = writer(compilable=compilable)
192+
rb = rb_type(
193+
storage=storage,
194+
sampler=sampler,
195+
writer=writer,
196+
batch_size=3,
197+
compilable=compilable,
198+
)
193199
return rb
194200

195201
def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
421427
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
422428
# Our Windows CI jobs do not have "cl", so skip this test.
423429
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
430+
@pytest.mark.parametrize("avoid_max_size", [False, True])
424431
def test_extend_sample_recompile(
425-
self, rb_type, sampler, writer, storage, size, datatype
432+
self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
426433
):
427434
if rb_type is not ReplayBuffer:
428435
pytest.skip(
@@ -443,28 +450,36 @@ def test_extend_sample_recompile(
443450

444451
torch._dynamo.reset_code_caches()
445452

446-
storage_size = 10 * size
453+
# Number of times to extend the replay buffer
454+
num_extend = 10
455+
data_size = size
456+
457+
# These two cases are separated because when the max storage size is
458+
# reached, the code execution path changes, causing necessary
459+
# recompiles.
460+
if avoid_max_size:
461+
storage_size = (num_extend + 1) * data_size
462+
else:
463+
storage_size = 2 * data_size
464+
447465
rb = self._get_rb(
448466
rb_type=rb_type,
449467
sampler=sampler,
450468
writer=writer,
451469
storage=storage,
452470
size=storage_size,
471+
compilable=True,
453472
)
454-
data_size = size
455473
data = self._get_data(datatype, size=data_size)
456474

457475
@torch.compile
458476
def extend_and_sample(data):
459477
rb.extend(data)
460478
return rb.sample()
461479

462-
# Number of times to extend the replay buffer
463-
num_extend = 30
464-
465-
# NOTE: The first two calls to 'extend' and 'sample' currently cause
466-
# recompilations, so avoid capturing those for now.
467-
num_extend_before_capture = 2
480+
# NOTE: The first three calls to 'extend' and 'sample' can currently
481+
# cause recompilations, so avoid capturing those.
482+
num_extend_before_capture = 3
468483

469484
for _ in range(num_extend_before_capture):
470485
extend_and_sample(data)
@@ -477,12 +492,12 @@ def extend_and_sample(data):
477492
for _ in range(num_extend - num_extend_before_capture):
478493
extend_and_sample(data)
479494

480-
assert len(rb) == storage_size
481-
assert len(records) == 0
482-
483495
finally:
484496
torch._logging.set_logs()
485497

498+
assert len(rb) == min((num_extend * data_size), storage_size)
499+
assert len(records) == 0
500+
486501
def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
487502
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
488503
pytest.skip(
@@ -806,6 +821,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
806821
s = new_replay_buffer.sample()
807822
assert (s.exclude("index") == 1).all()
808823

824+
@pytest.mark.skipif(
825+
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
826+
)
827+
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
828+
# This test checks if the `torch._dynamo.disable` wrapper around
829+
# `TensorStorage._rand_given_ndim` is still necessary.
830+
def test__rand_given_ndim_recompile(self):
831+
torch._dynamo.reset_code_caches()
832+
833+
# Number of times to extend the replay buffer
834+
num_extend = 10
835+
data_size = 100
836+
storage_size = (num_extend + 1) * data_size
837+
sample_size = 3
838+
839+
storage = LazyTensorStorage(storage_size, compilable=True)
840+
sampler = RandomSampler()
841+
842+
# Override to avoid the `torch._dynamo.disable` wrapper
843+
storage._rand_given_ndim = storage._rand_given_ndim_impl
844+
845+
@torch.compile
846+
def extend_and_sample(data):
847+
storage.set(torch.arange(data_size) + len(storage), data)
848+
return sampler.sample(storage, sample_size)
849+
850+
data = torch.randint(100, (data_size, 1))
851+
852+
try:
853+
torch._logging.set_logs(recompiles=True)
854+
records = []
855+
capture_log_records(records, "torch._dynamo", "recompiles")
856+
857+
for _ in range(num_extend):
858+
extend_and_sample(data)
859+
860+
finally:
861+
torch._logging.set_logs()
862+
863+
assert len(storage) == num_extend * data_size
864+
assert len(records) == 8, (
865+
"If this ever decreases, that's probably good news and the "
866+
"`torch._dynamo.disable` wrapper around "
867+
"`TensorStorage._rand_given_ndim` can be removed."
868+
)
869+
809870
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
810871
def test_extend_lazystack(self, storage_type):
811872

Diff for: torchrl/_utils.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,11 @@ class implement_for:
252252
Keyword Args:
253253
class_method (bool, optional): if ``True``, the function will be written as a class method.
254254
Defaults to ``False``.
255+
compilable (bool, optional): If ``False``, the module import happens
256+
only on the first call to the wrapped function. If ``True``, the
257+
module import happens when the wrapped function is initialized. This
258+
allows the wrapped function to work well with ``torch.compile``.
259+
Defaults to ``False``.
255260
256261
Examples:
257262
>>> @implement_for("gym", "0.13", "0.14")
@@ -290,11 +295,13 @@ def __init__(
290295
to_version: str = None,
291296
*,
292297
class_method: bool = False,
298+
compilable: bool = False,
293299
):
294300
self.module_name = module_name
295301
self.from_version = from_version
296302
self.to_version = to_version
297303
self.class_method = class_method
304+
self._compilable = compilable
298305
implement_for._setters.append(self)
299306

300307
@staticmethod
@@ -386,18 +393,27 @@ def __call__(self, fn):
386393
self.fn = fn
387394
implement_for._lazy_impl[self.func_name].append(self._call)
388395

389-
@wraps(fn)
390-
def _lazy_call_fn(*args, **kwargs):
391-
# first time we call the function, we also do the replacement.
392-
# This will cause the imports to occur only during the first call to fn
396+
if self._compilable:
397+
_call_fn = self._delazify(self.func_name)
393398

394-
result = self._delazify(self.func_name)(*args, **kwargs)
395-
return result
399+
if self.class_method:
400+
return classmethod(_call_fn)
396401

397-
if self.class_method:
398-
return classmethod(_lazy_call_fn)
402+
return _call_fn
403+
else:
404+
405+
@wraps(fn)
406+
def _lazy_call_fn(*args, **kwargs):
407+
# first time we call the function, we also do the replacement.
408+
# This will cause the imports to occur only during the first call to fn
409+
410+
result = self._delazify(self.func_name)(*args, **kwargs)
411+
return result
412+
413+
if self.class_method:
414+
return classmethod(_lazy_call_fn)
399415

400-
return _lazy_call_fn
416+
return _lazy_call_fn
401417

402418
def _call(self):
403419

Diff for: torchrl/data/replay_buffers/replay_buffers.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
import torch
2121

22+
try:
23+
from torch.compiler import is_dynamo_compiling
24+
except ImportError:
25+
from torch._dynamo import is_compiling as is_dynamo_compiling
26+
2227
from tensordict import (
2328
is_tensor_collection,
2429
is_tensorclass,
@@ -132,6 +137,9 @@ class ReplayBuffer:
132137
.. warning:: As of now, the generator has no effect on the transforms.
133138
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
134139
Defaults to ``False``.
140+
compilable (bool, optional): whether the writer is compilable.
141+
If ``True``, the writer cannot be shared between multiple processes.
142+
Defaults to ``False``.
135143
136144
Examples:
137145
>>> import torch
@@ -217,11 +225,20 @@ def __init__(
217225
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
218226
generator: torch.Generator | None = None,
219227
shared: bool = False,
228+
compilable: bool = None,
220229
) -> None:
221-
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
230+
self._storage = (
231+
storage
232+
if storage is not None
233+
else ListStorage(max_size=1_000, compilable=compilable)
234+
)
222235
self._storage.attach(self)
223236
self._sampler = sampler if sampler is not None else RandomSampler()
224-
self._writer = writer if writer is not None else RoundRobinWriter()
237+
self._writer = (
238+
writer
239+
if writer is not None
240+
else RoundRobinWriter(compilable=bool(compilable))
241+
)
225242
self._writer.register_storage(self._storage)
226243

227244
self._get_collate_fn(collate_fn)
@@ -600,7 +617,9 @@ def _add(self, data):
600617
return index
601618

602619
def _extend(self, data: Sequence) -> torch.Tensor:
603-
with self._replay_lock, self._write_lock:
620+
is_compiling = is_dynamo_compiling()
621+
nc = contextlib.nullcontext()
622+
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
604623
if self.dim_extend > 0:
605624
data = self._transpose(data)
606625
index = self._writer.extend(data)
@@ -653,7 +672,7 @@ def update_priority(
653672

654673
@pin_memory_output
655674
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
656-
with self._replay_lock:
675+
with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext():
657676
index, info = self._sampler.sample(self._storage, batch_size)
658677
info["index"] = index
659678
data = self._storage.get(index)

0 commit comments

Comments
 (0)