
Commit 06a8c86

Committed on Mar 20, 2025
[Feature] Tokenizer for LLMEnv
ghstack-source-id: 429b8d03baa4ce0201451312f8d17de9d024fca8
Pull Request resolved: #2852

1 parent: 619fec6 · commit: 06a8c86

File tree: 5 files changed, +438, -178 lines

test/test_env.py  (+74, -38)

@@ -14,6 +14,7 @@
 import re
 import string
 from collections import defaultdict
+from contextlib import nullcontext
 from functools import partial
 from sys import platform
 from typing import Any, Optional
@@ -33,7 +34,7 @@
     TensorDictBase,
 )
 from tensordict.nn import TensorDictModuleBase
-from tensordict.tensorclass import NonTensorStack, TensorClass
+from tensordict.tensorclass import NonTensorData, NonTensorStack, TensorClass
 from tensordict.utils import _unravel_key_to_tuple
 from torch import nn

@@ -4630,6 +4631,7 @@ def __next__(self):
             else:
                 return tensors

+    @pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
     @pytest.mark.parametrize(
         "str2str,stack_method",
         [
@@ -4674,22 +4676,36 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         else:
             env.check_env_specs(break_when_any_done="both")

+    @pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
+    @pytest.mark.parametrize("tokenizer", [True, False])
     @pytest.mark.parametrize(
-        "str2str,stack_method",
+        "str2str,no_stack,stack_method",
         [
-            [True, None],
-            [False, "as_padded_tensor"],
-            # TODO: a bit experimental, fails with check_env_specs
-            # [False, "as_nested_tensor"],
-            [False, None],
+            [True, True, None],
+            [True, False, None],
+            [False, False, "as_padded_tensor"],
+            [False, False, None],
         ],
     )
     @pytest.mark.parametrize("batched", [True, False])
     @pytest.mark.parametrize("device", [None, "cpu"])
     @pytest.mark.parametrize("batch_size", [0, 4])
     def test_llm_from_dataloader(
-        self, str2str, batched, stack_method, device, batch_size
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        tokenizer,
+        no_stack,
     ):
+        from transformers import AutoTokenizer
+
+        if tokenizer:
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        else:
+            tokenizer = None
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
@@ -4712,7 +4728,8 @@ def test_llm_from_dataloader(
                 "str2str": str2str,
                 "device": device,
                 "has_attention": False,
-                "no_stack": False,
+                "no_stack": no_stack,
+                "tokenizer": tokenizer,
             }
         )
         env = LLMEnv.from_dataloader(**kwargs)
@@ -4725,12 +4742,17 @@
         if batch_size > 0:

             def policy(td):
-                if str2str:
+                if str2str and tokenizer is None:
                     if not td.shape:
-                        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
+                        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData(
+                            "<nothing>", device=device
+                        )
                     else:
                         td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
-                            *["<nothing>" for _ in range(td.shape[0])]
+                            *[
+                                NonTensorData("<nothing>", device=device)
+                                for _ in range(td.shape[0])
+                            ]
                         )
                 else:
                     td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
@@ -4742,34 +4764,48 @@ def policy(td):
             # Tell the env that we want 3 sub-envs
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
             assert r.ndim == 2
-            if str2str:
+            if str2str and tokenizer is None:
                 assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
                 assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
-                assert (
-                    r[0, 0][LLMEnv._DEFAULT_STR_KEY]
-                    == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-                assert (
-                    r[0, 1][LLMEnv._DEFAULT_STR_KEY]
-                    == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-                assert (
-                    r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
-                    == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-                assert (
-                    r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
-                    == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
-                        : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
-                    ]
-                )
-            else:
+                should_fail = no_stack
+                if should_fail:
+                    ctx = pytest.raises(AssertionError)
+                else:
+                    ctx = nullcontext()
+                with ctx:
+                    assert (
+                        r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                        == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    ), (
+                        r[0, 0][LLMEnv._DEFAULT_STR_KEY],
+                        r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY],
+                        r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY],
+                        r[0, 1][LLMEnv._DEFAULT_STR_KEY],
+                    )
+                with ctx:
+                    assert (
+                        r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                        == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    )
+                with ctx:
+                    assert (
+                        r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                        == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    )
+                with ctx:
+                    assert (
+                        r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                        == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                            : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                        ]
+                    )
+            elif tokenizer is None:
                 assert (
                     r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                     == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
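Outside of pytest, the code path exercised by the new `tokenizer` parametrization boils down to the sketch below. It assumes a dataloader that yields batches of strings (the test's DummyDataLoader plays that role; `string_dataloader` here is a placeholder) and reuses the keyword arguments visible in the diff above, with the usual torchrl import path for LLMEnv.

    from transformers import AutoTokenizer
    from torchrl.envs import LLMEnv

    # `string_dataloader` is a placeholder for any dataloader yielding string batches.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    env = LLMEnv.from_dataloader(
        dataloader=string_dataloader,
        tokenizer=tokenizer,   # new in this commit
        str2str=True,          # rollouts then carry both text and tokens
        no_stack=False,
        has_attention=False,
    )
    env.check_env_specs(break_when_any_done="both")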

torchrl/envs/custom/llm.py  (+109, -27)

@@ -8,10 +8,18 @@

 import torch

-from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
+from tensordict import (
+    is_leaf_nontensor,
+    NestedKey,
+    TensorDict,
+    TensorDictBase,
+    unravel_key,
+)
 from tensordict.tensorclass import NonTensorData, NonTensorStack
 from tensordict.utils import _zip_strict
 from torch.utils.data import DataLoader
+
+from torchrl._utils import _replace_last
 from torchrl.data.map.hash import SipHash
 from torchrl.data.tensor_specs import (
     Bounded,
@@ -38,7 +46,8 @@ class LLMEnv(EnvBase):

     Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt.

-    Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader`
+    Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via
+    :meth:`~from_dataloader`.

     Keyword Args:
         token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`).
@@ -145,12 +154,19 @@ def __init__(
             self.full_observation_spec_unbatched = Composite(
                 {
                     self.str_key: NonTensor(
-                        example_data="a string", batched=True, shape=()
+                        example_data="a string",
+                        batched=True,
+                        shape=(),
+                        device=device,
                     )
                 }
             )
             self.full_action_spec_unbatched = Composite(
-                {action_key: NonTensor(example_data="a string", batched=True, shape=())}
+                {
+                    action_key: NonTensor(
+                        example_data="a string", batched=True, shape=(), device=device
+                    )
+                }
             )
         else:
             if vocab_size is None:
@@ -208,27 +224,28 @@ def __init__(
         if not self.assign_done:
             # Use single done
             self.full_done_spec_unbatched = Composite(
-                done=Unbounded(shape=(1,), dtype=torch.bool),
-                terminated=Unbounded(shape=(1,), dtype=torch.bool),
+                done=Unbounded(shape=(1,), dtype=torch.bool, device=device),
+                terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
             )
         elif self.str2str:
             raise STR2STR_ERR
         else:
             # Use single done
             self.full_done_spec_unbatched = Composite(
                 tokens_data=Composite(
-                    done=Unbounded(shape=(-1,), dtype=torch.bool),
-                    terminated=Unbounded(shape=(-1,), dtype=torch.bool),
+                    done=Unbounded(shape=(-1,), dtype=torch.bool, device=device),
+                    terminated=Unbounded(shape=(-1,), dtype=torch.bool, device=device),
                 ),
-                done=Unbounded(shape=(1,), dtype=torch.bool),
-                terminated=Unbounded(shape=(1,), dtype=torch.bool),
+                done=Unbounded(shape=(1,), dtype=torch.bool, device=device),
+                terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
             )

     @classmethod
     def from_dataloader(
         cls,
         dataloader: DataLoader,
         *,
+        tokenizer: transformers.PretrainedTokenizerBase | None = None,  # noqa
         token_key: NestedKey | None = None,
         str_key: NestedKey | None = None,
         attention_key: NestedKey | None = None,
@@ -257,6 +274,18 @@ def from_dataloader(

         Args:
             dataloader (DataLoader): The dataloader to load data from.
+
+        Keyword Args:
+            tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+                "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+                pre-trained tokenizer.
+
+                .. note:: Using the `tokenizer` will append a :class:`~torchrl.envs.Tokenizer` transform to the environment.
+                    If `str2str` is set to `True`, the tokenizer will be called during every iteration and the rollout
+                    will contain both tokens and text data.
+                    If `str2str` is set to `False`, the tokenizer will be called during reset only, and the only
+                    text data in the rollout will be the text sampled from the dataset.
+
             token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`).
                 Defaults to ``("tokens_in", "input_ids")``.
             str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `str2str=True`).
@@ -305,19 +334,54 @@ def from_dataloader(
         Returns:
             LLMEnv: The created LLMEnv instance.
         """
-        from torchrl.envs import DataLoadingPrimer
+        from torchrl.envs import DataLoadingPrimer, Tokenizer
+
+        if str_key is None:
+            str_key = LLMEnv._DEFAULT_STR_KEY
+        if token_key is None:
+            token_key = LLMEnv._DEFAULT_TOKEN_KEY
+        if attention_key is None:
+            attention_key = LLMEnv._DEFAULT_ATTENTION_KEY
+        elif tokenizer is not None and attention_key != _replace_last(
+            token_key, "attention_mask"
+        ):
+            raise ValueError(
+                "When using the Tokenizer, attention key must match `(*token_key[:-1], 'attention_mask')` where "
+                f"`token_key` is a tuple-typed nested key. Got attention_key={attention_key} while expecting "
+                f"{_replace_last(token_key, 'attention_mask')}."
+            )
+
+        if tokenizer is not None:
+            if str2str:
+                # In this case, the tokenizer is appended to the env after each step
+                if action_key is None:
+                    action_key = cls._DEFAULT_ACTION_STR_KEY
+                tokenizer_transform = Tokenizer(
+                    tokenizer=tokenizer,
+                    in_keys=[str_key],
+                    out_keys=[token_key],
+                    # Assume that the tokens are named according to _DEFAULT_ACTION_TOKENS_KEY
+                    in_keys_inv=[action_key],
+                    out_keys_inv=[cls._DEFAULT_ACTION_TOKENS_KEY],
+                    call_before_reset=False,
+                    # We should always see the required entries
+                    missing_tolerance=False,
+                )
+            else:
+                # In this case, the tokenizer acts before reset and that's all
+                tokenizer_transform = Tokenizer(
+                    tokenizer=tokenizer,
+                    in_keys=[str_key],
+                    out_keys=[token_key],
+                    call_before_reset=True,
+                    missing_tolerance=True,
+                )

         if data_keys is None:
             if str2str:
-                if str_key is None:
-                    data_keys = [LLMEnv._DEFAULT_STR_KEY]
-                else:
-                    data_keys = [str_key]
+                data_keys = [str_key]
             else:
-                if token_key is None:
-                    data_keys = [LLMEnv._DEFAULT_TOKEN_KEY]
-                else:
-                    data_keys = [token_key]
+                data_keys = [token_key]
             if has_attention:
                 if attention_key is None:
                     data_keys.append(LLMEnv._DEFAULT_ATTENTION_KEY)
@@ -332,6 +396,7 @@ def from_dataloader(
             example_data=example_data,
             stack_method=stack_method,
             repeats=repeats,
+            device=device,
         )
         env = LLMEnv(
             str2str=str2str,
@@ -349,15 +414,17 @@ def from_dataloader(
             has_attention=has_attention,
             as_llm_data=as_llm_data,
         )
+        if tokenizer is not None:
+            env = env.append_transform(tokenizer_transform)
         return env.append_transform(primer)

     @staticmethod
-    def _check_obs_act_and_cat(obs, action):
+    def _check_obs_act_and_cat(obs, action, *, device):
         if not isinstance(obs, str):
             raise TypeError(f"Observation must be a string, got {type(obs)}.")
         if not isinstance(action, str):
             raise TypeError(f"Action must be a string, got {type(action)}.")
-        return obs + action
+        return NonTensorData(obs + action, device=device)

     def _step(
         self,
@@ -409,10 +476,11 @@ def _make_next_obs(
         self, tensordict: TensorDictBase, nex_td: TensorDictBase
     ) -> TensorDictBase:
         if self.no_stack:
-            if self.str2str:
-                raise NotImplementedError
             action = tensordict.get(self.action_key)
-            nex_td.set(self.token_key, action)
+            if self.str2str:
+                nex_td.set(self.str_key, action)
+            else:
+                nex_td.set(self.token_key, action)
             if self.has_attention:
                 attention_mask = tensordict.get(self.attention_key)
                 n = action.shape[-1] - attention_mask.shape[-1]
@@ -438,11 +506,13 @@ def _make_next_obs(
                         "The tensordict is batchless, yet the action and/or observations are not "
                         f"strings but {type(action)} and {type(obs)}, respectivly."
                     )
-                observation = self._check_obs_act_and_cat(obs, action)
+                observation = self._check_obs_act_and_cat(
+                    obs, action, device=self.device
+                )
             else:
                 observation = NonTensorStack(
                     *[
-                        self._check_obs_act_and_cat(_obs, _action)
+                        self._check_obs_act_and_cat(_obs, _action, device=self.device)
                         for (_obs, _action) in _zip_strict(obs, action)
                     ]
                 )
@@ -463,6 +533,12 @@ def _make_next_obs(
                     )
                 else:
                     observation = torch.cat([obs, action], -1)
+                    if self.has_attention:
+                        attention_mask = tensordict.get(self.attention_key)
+                        attention_mask = torch.cat(
+                            [attention_mask, attention_mask.new_ones(action.shape)], -1
+                        )
+                        nex_td.set(self.attention_key, attention_mask)
             except TypeError:
                 raise TypeError(
                     "Failed to cat action and observation tensors. Check that str2str argument is correctly "
@@ -484,10 +560,16 @@ def check_str():

         if tensordict is None or check_token() or check_str():
             raise KeyError(
-                f"Observation key {self.token_key} is not defined. Make sure a TensorDictPrimer (eg, "
+                f"Observation key {self.token_key}/{self.str_key} is not defined in tensordict with keys "
+                f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (eg, "
                 f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
             )
         td_reset = tensordict.copy()
+        if td_reset.device != self.device:
+            if self.device is None:
+                td_reset.clear_device_()
+            else:
+                td_reset = td_reset.to(self.device)
         tensordict = self._maybe_make_done(tensordict, td_reset)
         if self.as_llm_data:
             raise NotImplementedError()
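The new attention-key check above encodes a simple convention: when a tokenizer is supplied, the attention mask must sit next to the tokens. A small worked example of `_replace_last` with an illustrative nested key:

    from torchrl._utils import _replace_last

    token_key = ("tokens_in", "input_ids")
    _replace_last(token_key, "attention_mask")  # -> ("tokens_in", "attention_mask")
    # Supplying a tokenizer together with, say, attention_key=("masks", "attention_mask")
    # would therefore trigger the ValueError added in from_dataloader.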

torchrl/envs/transforms/llm.py  (+23, -10)

@@ -10,13 +10,7 @@
 from typing import Any, Callable, Iterable, Literal

 import torch
-from tensordict import (
-    maybe_dense_stack,
-    NestedKey,
-    TensorDict,
-    TensorDictBase,
-    unravel_key,
-)
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
 from torch import nn
@@ -364,6 +358,7 @@ def __init__(
         use_buffer: bool | None = None,
         auto_batch_size: bool = True,
         repeats: int | None = None,
+        device: torch.device | None = None,
     ):
         self.dataloader = dataloader
         if repeats is None:
@@ -385,7 +380,7 @@ def __init__(
         self.endless_dataloader = self._endless_iter(self.dataloader)

         if stack_method is None:
-            stack_method = maybe_dense_stack
+            stack_method = lazy_stack
         elif stack_method == "as_nested_tensor":
             stack_method = as_nested_tensor
         elif stack_method == "as_padded_tensor":
@@ -424,6 +419,7 @@ def __init__(
             expand_specs=None,
             single_default_value=True,
             call_before_env_reset=True,
+            device=device,
         )
         self._reset_key = "_reset"

@@ -432,10 +428,14 @@ def _endless_iter(self, obj):
         while True:
             yield from obj

+    # def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
+    #     td = super()._reset_env_preprocess(tensordict)
+    #     return lazy_stack(list(td.unbind(0)))
+    #
     def _load_from_dataloader(self, reset: torch.Tensor | None = None):
         """Loads a single element from the dataloader, or alternatively from the buffer.

-        If `reset` is passed, the one element per reset will be loaded.
+        If `reset` is passed, then one element per reset will be loaded.
         """
         if reset is not None:
             if not reset.any():
@@ -444,8 +444,16 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             loaded = [self._load_from_dataloader() for i in range(reset.sum())]
             return self.stack_method(loaded)

+        primers = getattr(self, "primers", None)
+        if primers is not None:
+            device = self.primers.device
+        else:
+            device = None
+
         if self.use_buffer and len(self._queue) > 0:
             result = self._queue.popleft()
+            if result.device != device:
+                result = result.to(device)
             return result

         data = next(self.endless_dataloader)
@@ -454,7 +462,10 @@
         # TODO: one could rename the keys too
         if isinstance(data, Mapping):
             out = TensorDict.from_dict(
-                data, auto_batch_size=self.auto_batch_size, batch_dims=1
+                data,
+                auto_batch_size=self.auto_batch_size,
+                batch_dims=1,
+                device=device,
             )
         elif self.data_keys is None:
             raise RuntimeError(
@@ -467,12 +478,14 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
                 {k: val for k, val in _zip_strict(self.data_keys, data)},
                 auto_batch_size=self.auto_batch_size,
                 batch_dims=1,
+                device=device,
             )
         elif len(self.data_keys) == 1:
             out = TensorDict.from_dict(
                 {self.data_keys[0]: data},
                 auto_batch_size=self.auto_batch_size,
                 batch_dims=1,
+                device=device,
             )
         else:
             raise ValueError(
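As a usage note, the primer can now be pinned to a device at construction time, and the TensorDicts it loads (including buffered ones) are moved there. A minimal sketch with a placeholder dataloader and an illustrative data key, mirroring the call made by LLMEnv.from_dataloader:

    from torchrl.envs import DataLoadingPrimer

    primer = DataLoadingPrimer(
        dataloader=my_dataloader,                # placeholder: any dataloader
        data_keys=[("tokens_in", "input_ids")],  # illustrative key
        stack_method=None,                       # falls back to lazy_stack after this commit
        device="cpu",                            # new: loaded data is cast to this device
    )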

torchrl/envs/transforms/transforms.py  (+175, -72)

@@ -5313,8 +5313,8 @@ class Tokenizer(UnaryTransform):

     def __init__(
         self,
-        in_keys: Sequence[NestedKey],
-        out_keys: Sequence[NestedKey],
+        in_keys: Sequence[NestedKey] | None = None,
+        out_keys: Sequence[NestedKey] | None = None,
         in_keys_inv: Sequence[NestedKey] | None = None,
         out_keys_inv: Sequence[NestedKey] | None = None,
         *,
@@ -5325,6 +5325,9 @@ def __init__(
         add_special_tokens: bool = False,
         padding: bool = True,
         max_length: int | None = None,
+        return_attention_mask: bool = True,
+        missing_tolerance: bool = True,
+        call_before_reset: bool = False,
     ):
         if tokenizer is None:
             from transformers import AutoTokenizer
@@ -5340,6 +5343,8 @@ def __init__(
         self.skip_special_tokens = skip_special_tokens
         self.padding = padding
         self.max_length = max_length
+        self.return_attention_mask = return_attention_mask
+        self.call_before_reset = call_before_reset
         if additional_tokens:
             self.tokenizer.add_tokens(additional_tokens)
         super().__init__(
@@ -5351,6 +5356,7 @@ def __init__(
             inv_fn=self.call_tokenizer_inv_fn,
             use_raw_nontensor=use_raw_nontensor,
         )
+        self._missing_tolerance = missing_tolerance

     @property
     def device(self):
@@ -5363,6 +5369,68 @@ def device(self):
         self._device = device
         return device

+    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+        # Specialized for attention mask
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+            value = next_tensordict.get(in_key, default=None)
+            if value is not None:
+                observation = self._apply_transform(value)
+                if self.return_attention_mask:
+                    observation, attention_mask = observation
+                    next_tensordict.set(
+                        _replace_last(out_key, "attention_mask"),
+                        attention_mask,
+                    )
+                next_tensordict.set(
+                    out_key,
+                    observation,
+                )
+            elif (
+                self.missing_tolerance
+                and self.return_attention_mask
+                and out_key in next_tensordict.keys(True)
+            ):
+                attention_key = _replace_last(out_key, "attention_mask")
+                if attention_key not in next_tensordict:
+                    next_tensordict[attention_key] = torch.ones_like(
+                        next_tensordict.get(out_key)
+                    )
+            elif not self.missing_tolerance:
+                raise KeyError(
+                    f"{self}: '{in_key}' not found in tensordict {next_tensordict}"
+                )
+        return next_tensordict
+
+    @dispatch(source="in_keys", dest="out_keys")
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+            data = tensordict.get(in_key, None)
+            if data is not None:
+                data = self._apply_transform(data)
+                if self.return_attention_mask:
+                    data, attention_mask = data
+                    tensordict.set(
+                        _replace_last(out_key, "attention_mask"),
+                        attention_mask,
+                    )
+                tensordict.set(out_key, data)
+            elif not self.missing_tolerance:
+                raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
+        return tensordict
+
+    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if self.call_before_reset:
+            with _set_missing_tolerance(self, True):
+                tensordict = self._call(tensordict)
+        return tensordict
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        if self.call_before_reset:
+            return tensordict_reset
+        return super()._reset(tensordict, tensordict_reset)
+
     def call_tokenizer_fn(self, value: str | list[str]):
         device = self.device
         kwargs = {"add_special_tokens": self.add_special_tokens}
@@ -5372,19 +5440,25 @@ def call_tokenizer_fn(self, value: str | list[str]):
         if isinstance(value, str):
             out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
             # TODO: incorporate attention mask
-            # attention_mask = torch.ones_like(out, dtype=torch.bool)
+            if self.return_attention_mask:
+                attention_mask = torch.ones_like(out, dtype=torch.int64)
         else:
             kwargs["padding"] = (
                 self.padding if self.max_length is None else "max_length"
             )
-            # kwargs["return_attention_mask"] = False
+            kwargs["return_attention_mask"] = self.return_attention_mask
             # kwargs["return_token_type_ids"] = False
             out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
-            # attention_mask = out["attention_mask"]
+            if self.return_attention_mask:
+                attention_mask = out["attention_mask"]
             out = out["input_ids"]

         if device is not None and out.device != device:
             out = out.to(device)
+            if self.return_attention_mask:
+                attention_mask = attention_mask.to(device)
+        if self.return_attention_mask:
+            return out, attention_mask
         return out

     def call_tokenizer_inv_fn(self, value: Tensor):
@@ -5396,81 +5470,110 @@ def call_tokenizer_inv_fn(self, value: Tensor):
         out = self.tokenizer.batch_decode(
             value, skip_special_tokens=self.skip_special_tokens
         )
+        device = self._str_device
         if isinstance(out, list):
-            return NonTensorStack(*out)
-        return NonTensorData(out)
+            result = NonTensorStack(*out)
+            if device:
+                result = result.to(device)
+            return result
+        return NonTensorData(out, device=device)
+
+    @property
+    def _str_device(self):
+        parent = self.parent
+        if parent is None:
+            return None
+        if self.in_keys:
+            in_key = self.in_keys[0]
+        elif self.in_keys_inv:
+            in_key = self.in_keys_inv[0]
+        else:
+            return None
+        if in_key in parent.observation_keys:
+            return parent.full_observation_spec[in_key].device
+        if in_key in parent.action_keys:
+            return parent.full_action_spec[in_key].device
+        if in_key in parent.state_keys:
+            return parent.full_state_spec[in_key].device
+        return None

     def transform_input_spec(self, input_spec: Composite) -> Composite:
-        input_spec = super().transform_input_spec(input_spec)
         # We need to cap the spec to generate valid random strings
-        for out_key in self.out_keys_inv:
-            if out_key in input_spec["full_state_spec"].keys(True, True):
-                new_shape = input_spec["full_state_spec"][out_key].shape
-                if self.max_length is None:
-                    # Then we can't tell what the shape will be
-                    new_shape = new_shape[:-1] + torch.Size((-1,))
-                input_spec["full_state_spec"][out_key] = Bounded(
-                    0,
-                    self.tokenizer.vocab_size,
-                    shape=new_shape,
-                    device=input_spec["full_state_spec"][out_key].device,
-                    dtype=input_spec["full_state_spec"][out_key].dtype,
-                )
-            elif out_key in input_spec["full_action_spec"].keys(True, True):
-                new_shape = input_spec["full_action_spec"][out_key].shape
-                if self.max_length is None:
-                    # Then we can't tell what the shape will be
-                    new_shape = new_shape[:-1] + torch.Size((-1,))
-                input_spec["full_action_spec"][out_key] = Bounded(
-                    0,
-                    self.tokenizer.vocab_size,
-                    shape=new_shape,
-                    device=input_spec["full_action_spec"][out_key].device,
-                    dtype=input_spec["full_action_spec"][out_key].dtype,
+        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
+            if in_key in input_spec["full_state_spec"].keys(True, True):
+                spec = input_spec["full_state_spec"]
+            elif in_key in input_spec["full_action_spec"].keys(False, True):
+                spec = input_spec["full_action_spec"]
+            else:
+                raise KeyError(
+                    f"The input keys {in_key} wasn't found in the env input specs."
                 )
+            local_spec = spec.pop(in_key)
+            local_dtype = local_spec.dtype
+            if local_dtype is None or local_dtype.is_floating_point:
+                local_dtype = torch.int64
+            new_shape = spec.shape
+            if self.max_length is None:
+                # Then we can't tell what the shape will be
+                new_shape = new_shape + torch.Size((-1,))
+            else:
+                new_shape = new_shape + torch.Size((self.max_length,))
+            spec[out_key] = Bounded(
+                0,
+                self.tokenizer.vocab_size,
+                shape=new_shape,
+                device=local_spec.device,
+                dtype=local_dtype,
+            )
         return input_spec

-    def transform_output_spec(self, output_spec: Composite) -> Composite:
-        output_spec = super().transform_output_spec(output_spec)
-        # We need to cap the spec to generate valid random strings
-        for out_key in self.out_keys:
-            if out_key in output_spec["full_observation_spec"].keys(True, True):
-                new_shape = output_spec["full_observation_spec"][out_key].shape
-                if self.max_length is None:
-                    # Then we can't tell what the shape will be
-                    new_shape = new_shape[:-1] + torch.Size((-1,))
-                output_spec["full_observation_spec"][out_key] = Bounded(
-                    0,
-                    self.tokenizer.vocab_size,
-                    shape=new_shape,
-                    device=output_spec["full_observation_spec"][out_key].device,
-                    dtype=output_spec["full_observation_spec"][out_key].dtype,
-                )
-            elif out_key in output_spec["full_reward_spec"].keys(True, True):
-                new_shape = output_spec["full_reward_spec"][out_key].shape
-                if self.max_length is None:
-                    # Then we can't tell what the shape will be
-                    new_shape = new_shape[:-1] + torch.Size((-1,))
-                output_spec["full_reward_spec"][out_key] = Bounded(
-                    0,
-                    self.tokenizer.vocab_size,
-                    shape=new_shape,
-                    device=output_spec["full_reward_spec"][out_key].device,
-                    dtype=output_spec["full_reward_spec"][out_key].dtype,
-                )
-            elif out_key in output_spec["full_done_spec"].keys(True, True):
-                new_shape = output_spec["full_done_spec"][out_key].shape
-                if self.max_length is None:
-                    # Then we can't tell what the shape will be
-                    new_shape = new_shape[:-1] + torch.Size((-1,))
-                output_spec["full_done_spec"][out_key] = Bounded(
+    transform_output_spec = Transform.transform_output_spec
+    transform_reward_spec = Transform.transform_reward_spec
+    transform_done_spec = Transform.transform_done_spec
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        attention_mask_keys = set()
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+            new_shape = observation_spec.shape + torch.Size((-1,))
+            try:
+                in_spec = observation_spec[in_key]
+                obs_dtype = in_spec.dtype
+                device = in_spec.device
+            except KeyError:
+                # In some cases (eg, the tokenizer is applied during reset on data that
+                # originates from a dataloader) we don't have an in_spec
+                in_spec = None
+                obs_dtype = None
+                device = observation_spec.device
+            if obs_dtype is None or obs_dtype.is_floating_point:
+                obs_dtype = torch.int64
+            observation_spec[out_key] = Bounded(
+                0,
+                self.tokenizer.vocab_size,
+                shape=new_shape,
+                device=device,
+                dtype=obs_dtype,
+            )
+            if self.return_attention_mask:
+                attention_mask_key = _replace_last(out_key, "attention_mask")
+                if attention_mask_key in attention_mask_keys:
+                    raise KeyError(
+                        "Conflicting attention_mask keys. Make sure the token tensors are "
+                        "nested at different places in the tensordict such that `(*root, 'attention_mask')` "
+                        "entries are unique."
+                    )
+                attention_mask_keys.add(attention_mask_key)
+                attention_dtype = obs_dtype
+                if attention_dtype is None or attention_dtype.is_floating_point:
+                    attention_dtype = torch.int64
+                observation_spec[attention_mask_key] = Bounded(
                     0,
-                    self.tokenizer.vocab_size,
+                    2,
                     shape=new_shape,
-                    device=output_spec["full_done_spec"][out_key].device,
-                    dtype=output_spec["full_done_spec"][out_key].dtype,
+                    device=device,
+                    dtype=attention_dtype,
                 )
-        return output_spec
+        return observation_spec


 class Stack(Transform):
@@ -6087,7 +6190,7 @@ def __init__(
             kwargs = primers
         if not isinstance(kwargs, Composite):
             shape = kwargs.pop("shape", None)
-            device = kwargs.pop("device", None)
+            device = self.device
             if "batch_size" in kwargs.keys():
                 extra_kwargs = {"batch_size": kwargs.pop("batch_size")}
             else:
@@ -6160,7 +6263,7 @@ def reset_key(self, value):
     @property
     def device(self):
         device = self._device
-        if device is None and self.parent is not None:
+        if device is None and hasattr(self, "parent") and self.parent is not None:
             device = self.parent.device
         self._device = device
         return device
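A minimal standalone sketch of the reworked Tokenizer transform, assuming a pre-built Hugging Face tokenizer and illustrative key names. With `return_attention_mask=True`, the mask is written next to the tokens under an `attention_mask` leaf:

    from transformers import AutoTokenizer
    from torchrl.envs import Tokenizer

    hf_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenizer_t = Tokenizer(
        in_keys=["text"],                    # illustrative keys
        out_keys=[("tokens", "input_ids")],
        tokenizer=hf_tok,
        return_attention_mask=True,          # also writes ("tokens", "attention_mask")
        call_before_reset=False,             # True would tokenize reset data instead of step outputs
        missing_tolerance=True,              # skip silently when "text" is absent
    )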

torchrl/modules/llm/vllm_policy.py  (+57, -31)

@@ -10,6 +10,8 @@
 import torch
 from tensordict import (
     from_dataclass,
+    lazy_stack,
+    LazyStackedTensorDict,
     maybe_dense_stack,
     NestedKey,
     NonTensorData,
@@ -20,6 +22,7 @@
     TensorDictModule as Mod,
     TensorDictModuleBase,
     TensorDictSequential as Seq,
+    WrapModule,
 )
 from tensordict.utils import _zip_strict

@@ -61,6 +64,7 @@ def from_vllm(
     generate: bool = True,
     generate_kwargs: dict | None = None,
     tokenizer_kwargs: dict | None = None,
+    pad_output: bool = True,
 ) -> TensorDictModuleBase:
     """Creates a TensorDictModule from a vLLM model.

@@ -151,7 +155,7 @@ def from_vllm(
             out_keys=["tokens_in"],
             method_kwargs=tokenizer_kwargs,
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["encode"] = Mod(
@@ -164,7 +168,7 @@ def from_vllm(
             in_keys=[text_key, "text_response"],
             out_keys=["tokens_in", "tokens_response"],
             strict=True,
-            inplace=False,
+            inplace="empty",
         )

     def select(x, y):
@@ -196,7 +200,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
                 ("tokens_in", "attention_mask"),
             ],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["move_inputs"] = Mod(
@@ -205,7 +209,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
             out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
             # It's ok if there's no mask
             strict=False,
-            inplace=False,
+            inplace="empty",
         )

     def to_list(tokens, attention_mask):
@@ -240,11 +244,10 @@ def to_list(tokens, attention_mask):
     )

     if generate_kwargs is None:
-        generate_kwargs = {
-            "detokenize": False,
-            "prompt_logprobs": not generate,
-            "logprobs": return_log_probs,
-        }
+        generate_kwargs = {}
+    generate_kwargs.setdefault("detokenize", False)
+    generate_kwargs.setdefault("prompt_logprobs", not generate)
+    generate_kwargs.setdefault("logprobs", return_log_probs)
     if not generate:
         generate_kwargs["max_tokens"] = 1
     sampling_params = SamplingParams(**generate_kwargs)
@@ -261,13 +264,27 @@ def to_list(tokens, attention_mask):
         strict=True,
     )

-    def get_output_tokens_and_log_probs(td):
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
+    def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
+        if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
+            td = lazy_stack(list(td.unbind(0)))
         if generate:
             # When not generate, we don't want to overwrite this
-            td["tokens_response"] = td["tokens_out"].outputs.token_ids
+            tokens_response_td = td["tokens_out"].outputs._tensordict.select(
+                "token_ids", "logprobs", strict=False
+            )
+            if pad_output:
+                tokens_response_td = tokens_response_td.densify(
+                    layout=torch.strided
+                ).to_padded_tensor(padding=padding_value)
+            tokens_response_td.rename_key_("token_ids", "tokens_response")
+            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
-                td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1)
+                tokens_response_td.rename_key_("logprobs", "log_probs")
+                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+            td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
         return td
@@ -296,32 +313,41 @@ def translate_lps(tokens_response, x):
     module_dict["to_source_device"] = _maybe_set_device

     if generate:
-        module_dict["format"] = Mod(
-            lambda *x: x,
-            in_keys=[
-                "log_probs",
-                "tokens_response",
-                ("tokens_in", "input_ids"),
-                ("tokens_in", "attention_mask"),
-                "text_response",
-            ],
-            out_keys=[
-                "log_probs",
-                "tokens_response",
-                token_key,
-                attention_mask_key,
-                "text_response",
-            ],
-            strict=False,
-            inplace=False,
+        in_keys = [
+            "log_probs",
+            "tokens_response",
+            ("tokens_in", "input_ids"),
+            ("tokens_in", "attention_mask"),
+            "text_response",
+        ]
+        out_keys = [
+            "log_probs",
+            "tokens_response",
+            token_key,
+            attention_mask_key,
+            "text_response",
+        ]
+
+        def format_td(td):
+            td = td.select(*in_keys, strict=False)
+            td.rename_key_(("tokens_in", "input_ids"), token_key)
+            td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key)
+            del td["tokens_in"]
+            return td
+
+        module_dict["format"] = WrapModule(
+            format_td,
+            in_keys=in_keys,
+            out_keys=out_keys,
         )
+
     else:
         module_dict["format"] = Mod(
             lambda *x: x,
             in_keys=["log_probs", "tokens_response"],
             out_keys=["log_probs", "tokens_response"],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )

     return Seq(module_dict, inplace=True)
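One behavioral consequence of the setdefault-based construction above (assuming, since the rendered diff drops indentation, that the setdefault calls sit outside the `generate_kwargs is None` check): a user-supplied `generate_kwargs` is now merged with the defaults instead of replacing them. For example:

    generate_kwargs = {"max_tokens": 32}   # user-provided
    generate_kwargs.setdefault("detokenize", False)
    generate_kwargs.setdefault("prompt_logprobs", not generate)
    generate_kwargs.setdefault("logprobs", return_log_probs)
    # generate_kwargs now holds max_tokens=32 plus the three defaults,
    # whereas before this commit the defaults applied only when no dict was passed.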
