Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] transformers policy #2825

Merged
merged 7 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import pytest
import torch

from tensordict import TensorDict
from tensordict import NonTensorStack, TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torch import distributions as dist, nn
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
from torchrl.data.llm.dataset import _has_transformers
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
from torchrl.modules.tensordict_module.actors import (
_process_action_space_spec,
ActorValueOperator,
Expand Down Expand Up @@ -907,6 +907,55 @@ def test_lmhead_actorvalueoperator(device):
) == len(policy_params)


@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
class TestTransformerActor:
    @pytest.mark.parametrize(
        "from_text, generate, tokens, attention_mask",
        [
            (True, True, None, None),
            (True, False, None, None),
            (
                False,
                True,
                torch.randint(1024, (1, 10)),
                torch.ones(1, 10, dtype=torch.int64),
            ),
            (False, True, torch.randint(1024, (1, 10)), None),
        ],
    )
    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
        """Wrap a small randomly-initialized GPT2 with ``from_hf_transformers`` and
        check which ``LLMData`` fields are populated for each combination of
        text/token input and generate/log-prob mode."""
        from torchrl.data.llm import LLMData
        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        # Random weights are fine: we only check the I/O contract, not quality.
        model = GPT2LMHeadModel(GPT2Config())
        wrapper = from_hf_transformers(
            model, tokenizer=tokenizer, from_text=from_text, generate=generate
        )
        if from_text:
            data_in = LLMData(text=NonTensorStack("a text"), batch_size=1)
        else:
            data_in = LLMData(
                tokens=tokens, attention_mask=attention_mask, batch_size=1
            )
        data_out = wrapper(data_in)
        # The wrapper writes its outputs into the input tensorclass in place.
        assert data_out is data_in
        assert isinstance(data_out, LLMData)
        # A text response is only produced when generating from text input.
        if from_text and generate:
            assert data_out.text_response is not None
        else:
            assert data_out.text_response is None
        # Text input implies the tokenizer builds a mask; token input keeps
        # whatever mask was (or wasn't) passed in.
        if from_text or attention_mask is not None:
            assert data_out.attention_mask is not None
        else:
            assert data_out.attention_mask is None
        if not generate:
            # Log-prob mode: no sampled response, but scores must be exposed.
            assert data_out.text_response is None
            assert data_out.tokens_response is None
            assert data_out.log_probs is not None
            assert data_out.logits is not None


if __name__ == "__main__":
    # Run this test file directly, forwarding any unrecognized CLI flags to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_args = [__file__, "--capture", "no", "--exitfirst"]
    pytest.main(pytest_args + unknown)
159 changes: 113 additions & 46 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4644,11 +4644,13 @@ def __next__(self):
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("device", [None, "cpu"])
def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
env = LLMEnv(str2str=str2str, device=device)
env = LLMEnv(
str2str=str2str, device=device, has_attention=False, no_stack=False
)
if str2str:
primer = DataLoadingPrimer(
dataloader=self.DummyDataLoader(batch_size=batch_size),
data_keys=["observation"],
data_keys=[LLMEnv._DEFAULT_STR_KEY],
example_data="a string!",
)
else:
Expand All @@ -4658,7 +4660,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
dataloader=self.DummyTensorDataLoader(
batch_size=batch_size, padding=True
),
data_keys=["observation"],
data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
stack_method=stack_method,
)
Expand All @@ -4668,7 +4670,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
if batched:
td = env.reset(TensorDict(batch_size=[3]))
env.check_env_specs(break_when_any_done="both", tensordict=td)
r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
env.rollout(10, tensordict=TensorDict(batch_size=[3]))
else:
env.check_env_specs(break_when_any_done="both")

Expand All @@ -4691,7 +4693,7 @@ def test_llm_from_dataloader(
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
"example_data": "a string!",
}
else:
Expand All @@ -4701,11 +4703,18 @@ def test_llm_from_dataloader(
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
}
kwargs.update({"str2str": str2str, "device": device})
kwargs.update(
{
"str2str": str2str,
"device": device,
"has_attention": False,
"no_stack": False,
}
)
env = LLMEnv.from_dataloader(**kwargs)
assert not env.batch_locked
if batched:
Expand All @@ -4718,46 +4727,64 @@ def test_llm_from_dataloader(
def policy(td):
if str2str:
if not td.shape:
td["action"] = "<nothing>"
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
else:
td["action"] = NonTensorStack(
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
*["<nothing>" for _ in range(td.shape[0])]
)
else:
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
td.shape + (1,), dtype=torch.int64
)
return td

if batched:
# Tell the env that we want 3 sub-envs
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
assert r.ndim == 2
if str2str:
assert isinstance(r[0, 0]["observation"], str)
assert isinstance(r[0, 1]["observation"], str)
assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
assert (
r[0, 0]["observation"]
== r[0, 1]["observation"][: -len(r[0, 0]["action"])]
r[0, 0][LLMEnv._DEFAULT_STR_KEY]
== r[0, 1][LLMEnv._DEFAULT_STR_KEY][
: -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
]
)
assert (
r[0, 1]["observation"]
== r[0, 2]["observation"][: -len(r[0, 1]["action"])]
r[0, 1][LLMEnv._DEFAULT_STR_KEY]
== r[0, 2][LLMEnv._DEFAULT_STR_KEY][
: -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
]
)
assert (
r[-1, 0]["observation"]
== r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
== r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
: -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
]
)
assert (
r[-1, 1]["observation"]
== r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
== r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
: -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
]
)
else:
assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
assert (
r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
== r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
).all()
assert (
r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
== r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
).all()
assert (
r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
== r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
).all()
assert (
r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
== r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
).all()
else:
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
Expand All @@ -4783,7 +4810,7 @@ def test_llm_from_dataloader_repeats(
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
"example_data": "a string!",
"repeats": repeats,
}
Expand All @@ -4794,12 +4821,19 @@ def test_llm_from_dataloader_repeats(
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
"repeats": repeats,
}
kwargs.update({"str2str": str2str, "device": device})
kwargs.update(
{
"str2str": str2str,
"device": device,
"has_attention": False,
"no_stack": False,
}
)
env = LLMEnv.from_dataloader(**kwargs)
assert env.transform.repeats == repeats

Expand All @@ -4809,13 +4843,15 @@ def test_llm_from_dataloader_repeats(
def policy(td):
if str2str:
if not td.shape:
td["action"] = "<nothing>"
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
else:
td["action"] = NonTensorStack(
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
*["<nothing>" for _ in range(td.shape[0])]
)
else:
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
td.shape + (1,), dtype=torch.int64
)
return td

if batched:
Expand All @@ -4831,34 +4867,58 @@ def policy(td):
r_reset = r[..., ::max_steps]
if not batched:
if str2str:
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
assert (
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
== r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
)
assert (
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
== r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
)
assert (
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
!= r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
)
else:
assert (
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
== r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
).all()
assert (
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
== r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
).all()
assert (
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
!= r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
).any()
else:
# When batched, each block contains the 3 reset packs
if str2str:
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
assert (
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
== r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
)
assert (
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
== r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
)
assert (
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
!= r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
)
else:
assert (
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
== r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
).all()
assert (
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
== r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
).all()
assert (
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
!= r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
).any()

@pytest.mark.parametrize(
Expand Down Expand Up @@ -4892,7 +4952,7 @@ def test_done_and_reward(
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
"example_data": "a string!",
"repeats": repeats,
"assign_reward": assign_reward,
Expand All @@ -4905,20 +4965,27 @@ def test_done_and_reward(
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
"repeats": repeats,
"assign_reward": assign_reward,
"assign_done": assign_done,
}
kwargs.update({"str2str": str2str, "device": device})
kwargs.update(
{
"str2str": str2str,
"device": device,
"has_attention": False,
"no_stack": False,
}
)
env = LLMEnv.from_dataloader(**kwargs)
# We want to make sure that transforms that rely on the done state work appropriately
env.append_transform(StepCounter(max_steps=10))

def policy(td):
td["action"] = torch.ones(
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
)
return td
Expand Down
3 changes: 2 additions & 1 deletion torchrl/data/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,12 @@ class LLMData(TensorClass["nocast"]):

"""

tokens: torch.Tensor
tokens: torch.Tensor | None = None
tokens_response: torch.Tensor | None = None
attention_mask: torch.Tensor | None = None
token_list: list[int] | list[list[int]] | None = None
tokens_response_list: list[list[int]] | None = None
logits: torch.Tensor | None = None
log_probs: torch.Tensor | None = None
text: str | list[str] | None = None
text_response: torch.Tensor | None = None
Loading
Loading