Skip to content

Commit f4713f9

Browse files
committed
[Feature] DataLoadingPrimer handling of dataloader with batch-size > 0
ghstack-source-id: cf1942ece8dfbd6506f91939561df7443bd840ab Pull Request resolved: #2821
1 parent 40b147e commit f4713f9

File tree

3 files changed

+126
-13
lines changed

3 files changed

+126
-13
lines changed

test/test_env.py

+65-6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
TensorDictBase,
3434
)
3535
from tensordict.nn import TensorDictModuleBase
36+
from tensordict.tensorclass import NonTensorStack
3637
from tensordict.utils import _unravel_key_to_tuple
3738
from torch import nn
3839

@@ -4577,20 +4578,23 @@ def __next__(self):
45774578
],
45784579
)
45794580
@pytest.mark.parametrize("batched", [True, False])
4581+
@pytest.mark.parametrize("batch_size", [0, 4])
45804582
@pytest.mark.parametrize("device", [None, "cpu"])
4581-
def test_llm_env(self, str2str, batched, stack_method, device):
4583+
def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
45824584
env = LLMEnv(str2str=str2str, device=device)
45834585
if str2str:
45844586
primer = DataLoadingPrimer(
4585-
dataloader=self.DummyDataLoader(),
4587+
dataloader=self.DummyDataLoader(batch_size=batch_size),
45864588
data_keys=["observation"],
45874589
example_data="a string!",
45884590
)
45894591
else:
45904592
if stack_method is None:
45914593
stack_method = as_padded_tensor
45924594
primer = DataLoadingPrimer(
4593-
dataloader=self.DummyTensorDataLoader(padding=True),
4595+
dataloader=self.DummyTensorDataLoader(
4596+
batch_size=batch_size, padding=True
4597+
),
45944598
data_keys=["observation"],
45954599
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
45964600
stack_method=stack_method,
@@ -4601,6 +4605,7 @@ def test_llm_env(self, str2str, batched, stack_method, device):
46014605
if batched:
46024606
td = env.reset(TensorDict(batch_size=[3]))
46034607
env.check_env_specs(break_when_any_done="both", tensordict=td)
4608+
r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
46044609
else:
46054610
env.check_env_specs(break_when_any_done="both")
46064611

@@ -4616,18 +4621,23 @@ def test_llm_env(self, str2str, batched, stack_method, device):
46164621
)
46174622
@pytest.mark.parametrize("batched", [True, False])
46184623
@pytest.mark.parametrize("device", [None, "cpu"])
4619-
def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
4624+
@pytest.mark.parametrize("batch_size", [0, 4])
4625+
def test_llm_from_dataloader(
4626+
self, str2str, batched, stack_method, device, batch_size
4627+
):
46204628
if str2str:
46214629
kwargs = {
4622-
"dataloader": self.DummyDataLoader(),
4630+
"dataloader": self.DummyDataLoader(batch_size=batch_size),
46234631
"data_keys": ["observation"],
46244632
"example_data": "a string!",
46254633
}
46264634
else:
46274635
if stack_method is None:
46284636
stack_method = as_padded_tensor
46294637
kwargs = {
4630-
"dataloader": self.DummyTensorDataLoader(padding=True),
4638+
"dataloader": self.DummyTensorDataLoader(
4639+
padding=True, batch_size=batch_size
4640+
),
46314641
"data_keys": ["observation"],
46324642
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
46334643
"stack_method": stack_method,
@@ -4640,6 +4650,55 @@ def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
46404650
env.check_env_specs(break_when_any_done="both", tensordict=td)
46414651
else:
46424652
env.check_env_specs(break_when_any_done="both")
4653+
if batch_size > 0:
4654+
4655+
def policy(td):
4656+
if str2str:
4657+
if not td.shape:
4658+
td["action"] = "<nothing>"
4659+
else:
4660+
td["action"] = NonTensorStack(
4661+
*["<nothing>" for _ in range(td.shape[0])]
4662+
)
4663+
else:
4664+
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4665+
return td
4666+
4667+
if batched:
4668+
# Tell the env that we want 3 sub-envs
4669+
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
4670+
assert r.ndim == 2
4671+
if str2str:
4672+
assert isinstance(r[0, 0]["observation"], str)
4673+
assert isinstance(r[0, 1]["observation"], str)
4674+
assert (
4675+
r[0, 0]["observation"]
4676+
== r[0, 1]["observation"][: -len(r[0, 0]["action"])]
4677+
)
4678+
assert (
4679+
r[0, 1]["observation"]
4680+
== r[0, 2]["observation"][: -len(r[0, 1]["action"])]
4681+
)
4682+
assert (
4683+
r[-1, 0]["observation"]
4684+
== r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
4685+
)
4686+
assert (
4687+
r[-1, 1]["observation"]
4688+
== r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
4689+
)
4690+
else:
4691+
assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
4692+
assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
4693+
assert (
4694+
r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
4695+
).all()
4696+
assert (
4697+
r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
4698+
).all()
4699+
else:
4700+
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
4701+
assert r.ndim == 1
46434702

46444703

46454704
if __name__ == "__main__":

torchrl/envs/custom/llm.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,14 @@ def from_dataloader(
188188
)
189189
return env.append_transform(primer)
190190

191+
@staticmethod
192+
def _check_obs_act_and_cat(obs, action):
193+
if not isinstance(obs, str):
194+
raise TypeError(f"Observation must be a string, got {type(obs)}.")
195+
if not isinstance(action, str):
196+
raise TypeError(f"Action must be a string, got {type(action)}.")
197+
return obs + action
198+
191199
def _step(
192200
self,
193201
tensordict: TensorDictBase,
@@ -202,11 +210,14 @@ def _step(
202210
"The tensordict is batchless, yet the action and/or observations are not "
203211
f"strings but {type(action)} and {type(obs)}, respectivly."
204212
)
205-
observation = obs + action
213+
observation = self._check_obs_act_and_cat(obs, action)
206214
else:
207-
observation = [
208-
_obs + _action for (_obs, _action) in _zip_strict(obs, action)
209-
]
215+
observation = NonTensorStack(
216+
*[
217+
self._check_obs_act_and_cat(_obs, _action)
218+
for (_obs, _action) in _zip_strict(obs, action)
219+
]
220+
)
210221
else:
211222
try:
212223
obs: torch.Tensor = tensordict.get(self.observation_key)

torchrl/envs/transforms/rlhf.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
from collections import deque
78
from collections.abc import Mapping
89
from copy import copy, deepcopy
910
from typing import Any, Callable, Iterable, Literal
@@ -87,11 +88,21 @@ class DataLoadingPrimer(TensorDictPrimer):
8788
8889
Args:
8990
dataloader (Iterable[Any]): The dataloader to load data from.
91+
92+
Keyword Args:
9093
primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
9194
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
9295
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
9396
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
9497
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
98+
use_buffer (bool, optional): Whether to use a buffer to load the batches. When an environment has a batch-size
99+
that differs from the dataloader's, or when partial resets are to be expected, using a buffer to store data
100+
ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
101+
are loaded in order.
102+
Defaults to ``True`` whenever the batch-size of the dataloader is greater than 1.
103+
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
104+
tensordict returned by the transform will be automatically determined assuming that there is a single batch
105+
dimension.
95106
96107
Attributes:
97108
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -339,14 +350,25 @@ class DataLoadingPrimer(TensorDictPrimer):
339350
def __init__(
340351
self,
341352
dataloader: Iterable[Any],
353+
*,
342354
primers: Composite | None = None,
343355
data_keys: list[NestedKey] | None = None,
344356
data_specs: list[TensorSpec] | None = None,
345357
example_data: Any = None,
346358
stack_method: Callable[[Any], Any]
347359
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
360+
use_buffer: bool | None = None,
361+
auto_batch_size: bool = True,
348362
):
349363
self.dataloader = dataloader
364+
if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
365+
use_buffer = True
366+
367+
self.use_buffer = use_buffer
368+
# No auto_batch_size if we know we have a single element
369+
self.auto_batch_size = auto_batch_size and (
370+
getattr(dataloader, "dataloader", 1) > 0
371+
)
350372
self.endless_dataloader = self._endless_iter(self.dataloader)
351373
if primers is None:
352374
if data_keys is None:
@@ -381,34 +403,55 @@ def __init__(
381403
single_default_value=True,
382404
call_before_env_reset=True,
383405
)
406+
if self.use_buffer:
407+
self._queue = deque()
384408

385409
@classmethod
386410
def _endless_iter(self, obj):
387411
while True:
388412
yield from obj
389413

390414
def _load_from_dataloader(self, reset: torch.Tensor | None = None):
415+
"""Loads a single element from the dataloader, or alternatively from the buffer.
416+
417+
If `reset` is passed, the one element per reset will be loaded.
418+
"""
391419
if reset is not None:
392420
if not reset.any():
393421
raise RuntimeError("reset must have at least one True value.")
394422
if reset.ndim > 0:
395423
return self.stack_method(
396424
[self._load_from_dataloader() for i in range(reset.sum())]
397425
)
426+
if self.use_buffer and len(self._queue) > 0:
427+
return self._queue.popleft()
398428
data = next(self.endless_dataloader)
399429
# Some heuristic here:
400430
# if data is a map, assume its keys match the keys in spec
401431
# TODO: one could rename the keys too
402432
if isinstance(data, Mapping):
403-
out = TensorDict(data)
433+
out = TensorDict.from_dict(
434+
data, auto_batch_size=self.auto_batch_size, batch_dims=1
435+
)
404436
elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)):
405-
out = TensorDict({k: val for k, val in _zip_strict(self.data_keys, data)})
437+
out = TensorDict.from_dict(
438+
{k: val for k, val in _zip_strict(self.data_keys, data)},
439+
auto_batch_size=self.auto_batch_size,
440+
batch_dims=1,
441+
)
406442
elif len(self.data_keys) == 1:
407-
out = TensorDict({self.data_keys[0]: data})
443+
out = TensorDict.from_dict(
444+
{self.data_keys[0]: data},
445+
auto_batch_size=self.auto_batch_size,
446+
batch_dims=1,
447+
)
408448
else:
409449
raise ValueError(
410450
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
411451
)
452+
if self.use_buffer:
453+
self._queue.extend(out.unbind(0))
454+
return self._queue.popleft()
412455
return out
413456

414457

0 commit comments

Comments
 (0)