[Feature] batch_size, reward, done, attention_key in LLMEnv #2824

Merged: 5 commits, Mar 11, 2025
75 changes: 75 additions & 0 deletions test/test_env.py
@@ -4861,6 +4861,81 @@ def policy(td):
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
).any()

@pytest.mark.parametrize(
"str2str,stack_method",
[
[True, None],
[False, "as_padded_tensor"],
],
)
@pytest.mark.parametrize("batched", [True])
@pytest.mark.parametrize("device", [None])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("repeats", [3])
@pytest.mark.parametrize(
"assign_reward,assign_done", [[True, False], [True, True], [False, True]]
)
def test_done_and_reward(
self,
str2str,
batched,
stack_method,
device,
batch_size,
repeats,
assign_reward,
assign_done,
):
with pytest.raises(
ValueError, match="str2str"
) if str2str else contextlib.nullcontext():
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"example_data": "a string!",
"repeats": repeats,
"assign_reward": assign_reward,
"assign_done": assign_done,
}
else:
if stack_method is None:
stack_method = as_padded_tensor
kwargs = {
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
"repeats": repeats,
"assign_reward": assign_reward,
"assign_done": assign_done,
}
kwargs.update({"str2str": str2str, "device": device})
env = LLMEnv.from_dataloader(**kwargs)
# We want to make sure that transforms that rely on the done state work appropriately
env.append_transform(StepCounter(max_steps=10))

def policy(td):
td["action"] = torch.ones(
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
)
return td

if batched:
r = env.rollout(
100,
policy,
tensordict=TensorDict(batch_size=[3]),
break_when_any_done=False,
)
else:
r = env.rollout(100, policy, break_when_any_done=False)
if assign_done:
assert "terminated" in r
assert "done" in r


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
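
For orientation, the non-string branch of this test boils down to roughly the sketch below. `my_dataloader` stands in for the test's DummyTensorDataLoader, and the import paths are assumptions rather than something the diff shows.

import torch
from torchrl.data import Unbounded
from torchrl.envs import LLMEnv

env = LLMEnv.from_dataloader(
    dataloader=my_dataloader,  # placeholder: yields batches of int64 token tensors
    data_keys=["observation"],
    data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
    str2str=False,
    repeats=3,            # new in this PR: repeat each sampled prompt
    assign_reward=True,   # new in this PR: register a reward entry
    assign_done=True,     # new in this PR: register done/terminated entries
)
rollout = env.rollout(10, break_when_any_done=False)
assert "done" in rollout and "terminated" in rollout
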
1 change: 0 additions & 1 deletion torchrl/data/map/tdstorage.py
@@ -128,7 +128,6 @@ def __init__(
self.in_keys = query_module.in_keys
if out_keys is not None:
self.out_keys = out_keys
assert not self._has_lazy_out_keys()

self.query_module = query_module
self.index_key = query_module.index_key
7 changes: 4 additions & 3 deletions torchrl/data/postprocs/postprocs.py
@@ -11,8 +11,6 @@
from tensordict.utils import expand_right
from torch import nn

from torchrl.objectives.value.functional import reward2go


def _get_reward(
gamma: float,
@@ -367,13 +365,16 @@ def __init__(
time_dim: int = 2,
discount: float = 1.0,
):
from torchrl.objectives.value.functional import reward2go

super().__init__()
self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
if reward_key_out is None:
reward_key_out = reward_key
self.out_keys = [unravel_key(reward_key_out)]
self.time_dim = time_dim
self.discount = discount
self.reward2go = reward2go

def forward(self, tensordict):
# Get done
@@ -385,6 +386,6 @@ def forward(self, tensordict):
f"reward and done state are expected to have the same shape. Got reard.shape={reward.shape} "
f"and done.shape={done.shape}."
)
reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
tensordict.set(("next", self.out_keys[0]), reward)
return tensordict
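
For context, reward2go computes the discounted return-to-go along a time dimension, restarting the accumulator wherever done is set. A minimal reference version for 1-D tensors, illustrative rather than the library implementation:

import torch

def reward_to_go(reward: torch.Tensor, done: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    # reward, done: shape [T]; done is boolean.
    # R_t = r_t + gamma * R_{t+1}, with the running sum reset after a done step.
    out = torch.zeros_like(reward)
    running = torch.zeros((), dtype=reward.dtype)
    for t in reversed(range(reward.shape[0])):
        running = reward[t] + gamma * running * (~done[t]).to(reward.dtype)
        out[t] = running
    return out
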
8 changes: 6 additions & 2 deletions torchrl/envs/common.py
@@ -2788,7 +2788,11 @@ def _reset_check_done(self, tensordict, tensordict_reset):
if reset_value is not None:
for done_key in done_key_group:
done_val = tensordict_reset.get(done_key)
if done_val[reset_value].any() and not self._allow_done_after_reset:
if (
done_val.any()
and done_val[reset_value].any()
and not self._allow_done_after_reset
):
raise RuntimeError(
f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
)
@@ -3588,7 +3592,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
any_done = self.any_done(tensordict)
if any_done:
return self.reset(tensordict, select_reset_only=True)
tensordict = self.reset(tensordict, select_reset_only=True)
return tensordict

def empty_cache(self):
Expand Down
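The maybe_reset change matters for manual stepping loops: the method now always hands a tensordict back, whether or not any sub-environment is done. A rough sketch of that usage, assuming the usual step_mdp pattern (env and policy are placeholders):

from torchrl.envs.utils import step_mdp

td = env.reset()
for _ in range(100):
    td = policy(td)
    td = env.step(td)
    td = step_mdp(td)          # move to the next time step
    # maybe_reset returns the reset tensordict when something is done,
    # and now returns the input tensordict unchanged otherwise.
    td = env.maybe_reset(td)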