
Commit 6b0e6c3

[Feature] Tokenizer for LLMEnv
ghstack-source-id: f89207ddd0aab7c9b88f74b1000b3205b9f70b21
Pull Request resolved: pytorch/rl#2852
1 parent 619fec6 commit 6b0e6c3
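
For reference, the new keyword the test suite exercises is tokenizer on LLMEnv.from_dataloader. A minimal sketch of the call the updated test assembles, assuming only the keyword arguments it passes in the diff below; DummyDataLoader is the test-suite helper from test/test_env.py (not defined here), and the import path for LLMEnv is an assumption.

# Sketch only: mirrors the kwargs built up in test_llm_from_dataloader below.
from transformers import AutoTokenizer
from torchrl.envs import LLMEnv  # assumed import path

kwargs = {
    "dataloader": DummyDataLoader(batch_size=4),  # test helper yielding text batches
    "str2str": True,                              # dataloader provides strings (see parametrization)
    "device": None,
    "has_attention": False,
    "no_stack": False,
    "tokenizer": AutoTokenizer.from_pretrained("bert-base-uncased"),  # new argument in this commit
}
env = LLMEnv.from_dataloader(**kwargs)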

File tree

5 files changed: +434 -167 lines changed


test/test_env.py

+74 -38
@@ -14,6 +14,7 @@
 import re
 import string
 from collections import defaultdict
+from contextlib import nullcontext
 from functools import partial
 from sys import platform
 from typing import Any, Optional
@@ -33,7 +34,7 @@
     TensorDictBase,
 )
 from tensordict.nn import TensorDictModuleBase
-from tensordict.tensorclass import NonTensorStack, TensorClass
+from tensordict.tensorclass import NonTensorData, NonTensorStack, TensorClass
 from tensordict.utils import _unravel_key_to_tuple
 from torch import nn

@@ -4630,6 +4631,7 @@ def __next__(self):
             else:
                 return tensors

+    @pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
     @pytest.mark.parametrize(
         "str2str,stack_method",
         [
@@ -4674,22 +4676,36 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         else:
             env.check_env_specs(break_when_any_done="both")

+    @pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
+    @pytest.mark.parametrize("tokenizer", [True, False])
     @pytest.mark.parametrize(
-        "str2str,stack_method",
+        "str2str,no_stack,stack_method",
         [
-            [True, None],
-            [False, "as_padded_tensor"],
-            # TODO: a bit experimental, fails with check_env_specs
-            # [False, "as_nested_tensor"],
-            [False, None],
+            [True, True, None],
+            [True, False, None],
+            [False, False, "as_padded_tensor"],
+            [False, False, None],
         ],
     )
     @pytest.mark.parametrize("batched", [True, False])
     @pytest.mark.parametrize("device", [None, "cpu"])
     @pytest.mark.parametrize("batch_size", [0, 4])
     def test_llm_from_dataloader(
-        self, str2str, batched, stack_method, device, batch_size
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        tokenizer,
+        no_stack,
     ):
+        from transformers import AutoTokenizer
+
+        if tokenizer:
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        else:
+            tokenizer = None
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
@@ -4712,7 +4728,8 @@ def test_llm_from_dataloader(
                 "str2str": str2str,
                 "device": device,
                 "has_attention": False,
-                "no_stack": False,
+                "no_stack": no_stack,
+                "tokenizer": tokenizer,
             }
         )
         env = LLMEnv.from_dataloader(**kwargs)
@@ -4725,12 +4742,17 @@
         if batch_size > 0:

             def policy(td):
-                if str2str:
+                if str2str and tokenizer is None:
                     if not td.shape:
-                        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
+                        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData(
+                            "<nothing>", device=device
+                        )
                     else:
                         td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
-                            *["<nothing>" for _ in range(td.shape[0])]
+                            *[
+                                NonTensorData("<nothing>", device=device)
+                                for _ in range(td.shape[0])
+                            ]
                         )
                 else:
                     td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
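
The policy above now wraps its string action in tensordict's NonTensorData, which lets a device be attached, and stacks per-env actions with NonTensorStack. A small standalone illustration of the two containers, independent of LLMEnv (the printed values are what I expect these APIs to return):

from tensordict.tensorclass import NonTensorData, NonTensorStack

single = NonTensorData("<nothing>", device="cpu")  # one non-tensor leaf with a device
batch = NonTensorStack(
    *[NonTensorData("<nothing>", device="cpu") for _ in range(3)]
)                                                  # a stack of three such leaves
print(single.data)     # "<nothing>"
print(batch.tolist())  # expected: ["<nothing>", "<nothing>", "<nothing>"]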
@@ -4742,34 +4764,48 @@ def policy(td):
             # Tell the env that we want 3 sub-envs
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
             assert r.ndim == 2
-            if str2str:
+            if str2str and tokenizer is None:
                 assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
                 assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
-                assert (
-                    r[0, 0][LLMEnv._DEFAULT_STR_KEY]
-                    == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-                assert (
-                    r[0, 1][LLMEnv._DEFAULT_STR_KEY]
-                    == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-                assert (
-                    r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
-                    == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-                assert (
-                    r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
-                    == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-            else:
+                should_fail = no_stack
+                if should_fail:
+                    ctx = pytest.raises(AssertionError)
+                else:
+                    ctx = nullcontext()
+                with ctx:
+                    assert (
+                        r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                        == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    ), (
+                        r[0, 0][LLMEnv._DEFAULT_STR_KEY],
+                        r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY],
+                        r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY],
+                        r[0, 1][LLMEnv._DEFAULT_STR_KEY],
+                    )
+                with ctx:
+                    assert (
+                        r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                        == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    )
+                with ctx:
+                    assert (
+                        r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                        == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    )
+                with ctx:
+                    assert (
+                        r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                        == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    )
+            elif tokenizer is None:
                 assert (
                     r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                     == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
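
One detail worth calling out in the new assertions: depending on no_stack, the same assert block is expected either to hold or to fail, so the test swaps nullcontext() for pytest.raises(AssertionError). The pattern in isolation, as a generic sketch not tied to LLMEnv:

from contextlib import nullcontext

import pytest


@pytest.mark.parametrize("should_fail", [True, False])
def test_conditional_expectation(should_fail):
    # The same block must raise when should_fail is True and pass otherwise.
    ctx = pytest.raises(AssertionError) if should_fail else nullcontext()
    with ctx:
        assert not should_fail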
