[Feature] DataLoadingPrimer.repeat #2822

Merged: 4 commits, Mar 11, 2025
98 changes: 98 additions & 0 deletions test/test_env.py
@@ -4763,6 +4763,104 @@ def policy(td):
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
assert r.ndim == 1

@pytest.mark.parametrize(
"str2str,stack_method",
[
[True, None],
[False, "as_padded_tensor"],
# TODO: a bit experimental, fails with check_env_specs
# [False, "as_nested_tensor"],
[False, None],
],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("repeats", [3])
def test_llm_from_dataloader_repeats(
self, str2str, batched, stack_method, device, batch_size, repeats
):
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"example_data": "a string!",
"repeats": repeats,
}
else:
if stack_method is None:
stack_method = as_padded_tensor
kwargs = {
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
"repeats": repeats,
}
kwargs.update({"str2str": str2str, "device": device})
env = LLMEnv.from_dataloader(**kwargs)
assert env.transform.repeats == repeats

max_steps = 3
env.append_transform(StepCounter(max_steps=max_steps))

def policy(td):
if str2str:
if not td.shape:
td["action"] = "<nothing>"
else:
td["action"] = NonTensorStack(
*["<nothing>" for _ in range(td.shape[0])]
)
else:
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
return td

if batched:
r = env.rollout(
100,
policy,
tensordict=TensorDict(batch_size=[3]),
break_when_any_done=False,
)
else:
r = env.rollout(100, policy, break_when_any_done=False)
        # check that the reset observation repeats `repeats` times before a new sample appears
r_reset = r[..., ::max_steps]
if not batched:
if str2str:
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
else:
assert (
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
).all()
assert (
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
).all()
assert (
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
).any()
else:
            # When batched, the `repeats` copies of a prompt are spread across the batch:
            # all rows share the same prompt at a given reset, and the next reset
            # within a row moves on to a new prompt
if str2str:
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
else:
assert (
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
).all()
assert (
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
).all()
assert (
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
).any()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
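
For reference, the cadence the unbatched assertions above rely on, sketched for repeats=3 and max_steps=3 (a timeline, not executable code):

    # step:    0  1  2  3  4  5  6  7  8  9  10 11 ...
    # prompt:  A  A  A  A  A  A  A  A  A  B  B  B  ...
    # reset:   ^        ^        ^        ^

Each prompt is served for `repeats` consecutive episodes of `max_steps` steps, so r[..., ::max_steps] picks the reset frames [A, A, A, B, ...]: the first three observations match and the fourth differs, which is exactly what the test asserts.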
5 changes: 5 additions & 0 deletions torchrl/envs/custom/llm.py
@@ -142,6 +142,7 @@ def from_dataloader(
example_data: Any = None,
stack_method: Callable[[Any], Any]
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
repeats: int | None = None,
) -> LLMEnv:
"""Creates an LLMEnv instance from a dataloader.

@@ -165,6 +166,9 @@ def from_dataloader(
example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``.
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
method to use for stacking the data. Defaults to ``None``.
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
samples (rather than an advantage module).

Returns:
LLMEnv: The created LLMEnv instance.
@@ -178,6 +182,7 @@ def from_dataloader(
data_specs=data_specs,
example_data=example_data,
stack_method=stack_method,
repeats=repeats,
)
env = LLMEnv(
str2str=str2str,
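
Taken together with the test above, a minimal usage sketch; the dataloader, its contents, and the collate_fn are illustrative stand-ins, not part of this PR:

    from torch.utils.data import DataLoader

    from torchrl.envs.custom.llm import LLMEnv

    # Each prompt drawn from the dataloader is served for `repeats`
    # consecutive resets, e.g. to collect several GRPO rollouts per prompt.
    env = LLMEnv.from_dataloader(
        dataloader=DataLoader(
            ["prompt A", "prompt B"],
            batch_size=1,
            collate_fn=lambda batch: batch[0],  # yield plain strings (illustrative)
        ),
        data_keys=["observation"],
        example_data="a string!",
        str2str=True,
        repeats=3,
    )
    assert env.transform.repeats == 3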
29 changes: 22 additions & 7 deletions torchrl/envs/transforms/llm.py
@@ -103,6 +103,9 @@ class DataLoadingPrimer(TensorDictPrimer):
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
tensordict returned by the transform will be automatically determined assuming that there is a single batch
dimension.
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
samples (rather than an advantage module).

Attributes:
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -359,15 +362,21 @@ def __init__(
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
use_buffer: bool | None = None,
auto_batch_size: bool = True,
repeats: int | None = None,
):
self.dataloader = dataloader
-        if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
+        if repeats is None:
+            repeats = 0
+        self.repeats = repeats
+        if (
+            getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None
+        ) or repeats > 0:
use_buffer = True

self.use_buffer = use_buffer
# No auto_batch_size if we know we have a single element
self.auto_batch_size = auto_batch_size and (
-            getattr(dataloader, "dataloader", 1) > 0
+            getattr(dataloader, "batch_size", 1) > 0
)
self.endless_dataloader = self._endless_iter(self.dataloader)
if primers is None:
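
The buffering decision above collapses to a small truth table; a standalone sketch of the same logic (function name hypothetical):

    def resolve_buffering(batch_size, use_buffer, repeats):
        # Mirrors the constructor: repeats defaults to 0, and any positive
        # value forces buffering so the duplicated samples can be queued.
        if repeats is None:
            repeats = 0
        if (batch_size > 1 and use_buffer is None) or repeats > 0:
            use_buffer = True
        return use_buffer, repeats

    assert resolve_buffering(4, None, None) == (True, 0)  # batched loader
    assert resolve_buffering(1, None, 3) == (True, 3)     # repeats force a buffer
    assert resolve_buffering(1, None, None) == (None, 0)  # nothing forces buffering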
@@ -420,11 +429,13 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
if not reset.any():
raise RuntimeError("reset must have at least one True value.")
if reset.ndim > 0:
-                return self.stack_method(
-                    [self._load_from_dataloader() for i in range(reset.sum())]
-                )
+                loaded = [self._load_from_dataloader() for i in range(reset.sum())]
+                return self.stack_method(loaded)

if self.use_buffer and len(self._queue) > 0:
-            return self._queue.popleft()
+            result = self._queue.popleft()
+            return result

data = next(self.endless_dataloader)
# Some heuristic here:
# if data is a map, assume its keys match the keys in spec
@@ -450,7 +461,11 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
)
if self.use_buffer:
-            self._queue.extend(out.unbind(0))
+            if not out.ndim:
+                out = out.unsqueeze(0)
+            self._queue.extend(
+                [d for d in out.unbind(0) for _ in range(max(1, self.repeats))]
+            )
return self._queue.popleft()
return out
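
The enqueueing logic above reduces to a simple expansion; a standalone sketch with hypothetical values:

    from collections import deque

    repeats = 3
    samples = ["prompt_a", "prompt_b"]  # stands in for out.unbind(0)
    # Each sample is enqueued max(1, repeats) times in a row, so with
    # repeats=0 (the default) every sample still appears exactly once.
    queue = deque(d for d in samples for _ in range(max(1, repeats)))
    assert list(queue) == ["prompt_a"] * 3 + ["prompt_b"] * 3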

6 changes: 4 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -7352,7 +7352,9 @@ def _reset(
else:
# It may be the case that reset did not provide a done state, in which case
# we fall back on the spec
-            done = self.parent.output_spec["full_done_spec", entry_name].zero()
+            done = self.parent.output_spec_unbatched[
+                "full_done_spec", entry_name
+            ].zero(tensordict_reset.shape)
reset = torch.ones_like(done)

step_count = tensordict.get(step_count_key, default=None)
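
The fallback now zeroes the unbatched done spec with the reset batch shape, so the fabricated done tensor matches tensordict_reset even for batched resets; a small sketch of that call (spec choice illustrative):

    import torch
    from torchrl.data import Unbounded

    done_spec = Unbounded(shape=(1,), dtype=torch.bool)
    done = done_spec.zero((4,))  # batch shape prepended -> shape (4, 1), all False
    assert done.shape == torch.Size([4, 1])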
@@ -7362,7 +7364,7 @@ def _reset(
step_count = step_count.to(reset.device, non_blocking=True)

# zero the step count if reset is needed
-        step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
+        step_count = torch.where(~reset, step_count.expand_as(reset), 0)
tensordict_reset.set(step_count_key, step_count)
if self.max_steps is not None:
truncated = step_count >= self.max_steps
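
The simplified counter reset relies on plain broadcasting instead of expand_as_right; a minimal sketch:

    import torch

    reset = torch.tensor([[True], [False]])
    step_count = torch.tensor([[5], [7]])
    # Where an env resets, its counter restarts at 0; elsewhere it is kept.
    step_count = torch.where(~reset, step_count.expand_as(reset), 0)
    assert step_count.tolist() == [[0], [7]]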