Commit 1da5d20 (parent: 60b7df0)

[BugFix] Better handling of batches in vllm wrapper

ghstack-source-id: 73d41d803125647fb2902f31a97a443a5c607112
Pull Request resolved: pytorch/rl#2853

File tree: 2 files changed (+59 -29 lines)


torchrl/envs/transforms/llm.py (+7 -3)
@@ -11,7 +11,7 @@
 
 import torch
 from tensordict import (
-    maybe_dense_stack,
+    lazy_stack,
     NestedKey,
     TensorDict,
     TensorDictBase,
@@ -386,7 +386,7 @@ def __init__(
         self.endless_dataloader = self._endless_iter(self.dataloader)
 
         if stack_method is None:
-            stack_method = maybe_dense_stack
+            stack_method = lazy_stack
         elif stack_method == "as_nested_tensor":
             stack_method = as_nested_tensor
         elif stack_method == "as_padded_tensor":
@@ -434,10 +434,14 @@ def _endless_iter(self, obj):
         while True:
             yield from obj
 
+    # def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
+    #     td = super()._reset_env_preprocess(tensordict)
+    #     return lazy_stack(list(td.unbind(0)))
+    #
     def _load_from_dataloader(self, reset: torch.Tensor | None = None):
         """Loads a single element from the dataloader, or alternatively from the buffer.
 
-        If `reset` is passed, the one element per reset will be loaded.
+        If `reset` is passed, then one element per reset will be loaded.
         """
         if reset is not None:
             if not reset.any():
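
Note on the change above: the default stack_method moves from maybe_dense_stack to lazy_stack, so batches whose samples have different token lengths stay lazily stacked instead of being densified. A minimal sketch of the tensordict behavior this relies on (not part of the commit; the "tokens" key is a toy example):

import torch
from tensordict import TensorDict, lazy_stack

# Two samples with different sequence lengths cannot be densely stacked,
# but lazy_stack keeps them side by side in a LazyStackedTensorDict.
a = TensorDict({"tokens": torch.tensor([1, 2, 3])}, batch_size=[])
b = TensorDict({"tokens": torch.tensor([4, 5])}, batch_size=[])
stacked = lazy_stack([a, b])  # batch_size [2], ragged "tokens" preserved
print(stacked[0]["tokens"])   # tensor([1, 2, 3])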

torchrl/modules/llm/vllm_policy.py (+52 -26)
@@ -10,6 +10,8 @@
 import torch
 from tensordict import (
     from_dataclass,
+    lazy_stack,
+    LazyStackedTensorDict,
     maybe_dense_stack,
     NestedKey,
     NonTensorData,
@@ -20,6 +22,7 @@
     TensorDictModule as Mod,
     TensorDictModuleBase,
     TensorDictSequential as Seq,
+    WrapModule,
 )
 from tensordict.utils import _zip_strict
 
@@ -61,6 +64,7 @@ def from_vllm(
     generate: bool = True,
     generate_kwargs: dict | None = None,
     tokenizer_kwargs: dict | None = None,
+    pad_output: bool = True,
 ) -> TensorDictModuleBase:
     """Creates a TensorDictModule from a vLLM model.
 
@@ -151,7 +155,7 @@ def from_vllm(
             out_keys=["tokens_in"],
             method_kwargs=tokenizer_kwargs,
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["encode"] = Mod(
@@ -164,7 +168,7 @@ def from_vllm(
             in_keys=[text_key, "text_response"],
             out_keys=["tokens_in", "tokens_response"],
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
 
     def select(x, y):
@@ -196,7 +200,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
                 ("tokens_in", "attention_mask"),
             ],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["move_inputs"] = Mod(
@@ -205,7 +209,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
             out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
             # It's ok if there's no mask
             strict=False,
-            inplace=False,
+            inplace="empty",
        )
 
     def to_list(tokens, attention_mask):
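
Note on the inplace changes above: every Mod node that previously passed inplace=False now passes inplace="empty". The sketch below only shows how such a node is built; the exact semantics of "empty" (outputs written to a fresh tensordict shaped like the input rather than into the input itself) are an assumption based on the name, not spelled out in this commit.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod

# Assumed behavior: with inplace="empty" the module returns a new
# tensordict carrying its outputs instead of mutating the input.
mod = Mod(lambda a: a + 1, in_keys=["a"], out_keys=["b"], inplace="empty")
td_in = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
td_out = mod(td_in)
print(td_out is td_in)  # expected: False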
@@ -261,13 +265,27 @@ def to_list(tokens, attention_mask):
             strict=True,
         )
 
-    def get_output_tokens_and_log_probs(td):
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
+    def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
+        if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
+            td = lazy_stack(list(td.unbind(0)))
         if generate:
             # When not generate, we don't want to overwrite this
-            td["tokens_response"] = td["tokens_out"].outputs.token_ids
+            tokens_response_td = td["tokens_out"].outputs._tensordict.select(
+                "token_ids", "logprobs", strict=False
+            )
+            if pad_output:
+                tokens_response_td = tokens_response_td.densify(
+                    layout=torch.strided
+                ).to_padded_tensor(padding=padding_value)
+            tokens_response_td.rename_key_("token_ids", "tokens_response")
+            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
-                td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1)
+                tokens_response_td.rename_key_("logprobs", "log_probs")
+                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+            td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
         return td
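
Note on the hunk above: get_output_tokens_and_log_probs now re-wraps dense batches as a LazyStackedTensorDict (lazy_stack(list(td.unbind(0)))) so ragged per-sample outputs can be written, and, when pad_output=True, densifies token_ids/logprobs into nested tensors padded with the tokenizer's pad token. A rough torch-level sketch of just the padding step (pad_id stands in for tokenizer(tokenizer.pad_token)["input_ids"][0]):

import torch

# Responses of unequal length are gathered as a strided nested tensor,
# then padded out to a rectangular tensor with the pad id.
token_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
pad_id = 0  # stand-in for the tokenizer's pad token id
nested = torch.nested.nested_tensor(token_ids, layout=torch.strided)
padded = torch.nested.to_padded_tensor(nested, padding=pad_id)
print(padded)  # tensor([[5, 6, 7], [8, 9, 0]])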
@@ -296,32 +314,40 @@ def translate_lps(tokens_response, x):
     module_dict["to_source_device"] = _maybe_set_device
 
     if generate:
-        module_dict["format"] = Mod(
-            lambda *x: x,
-            in_keys=[
-                "log_probs",
-                "tokens_response",
-                ("tokens_in", "input_ids"),
-                ("tokens_in", "attention_mask"),
-                "text_response",
-            ],
-            out_keys=[
-                "log_probs",
-                "tokens_response",
-                token_key,
-                attention_mask_key,
-                "text_response",
-            ],
-            strict=False,
-            inplace=False,
+        in_keys = [
+            "log_probs",
+            "tokens_response",
+            ("tokens_in", "input_ids"),
+            ("tokens_in", "attention_mask"),
+            "text_response",
+        ]
+        out_keys = [
+            "log_probs",
+            "tokens_response",
+            token_key,
+            attention_mask_key,
+            "text_response",
+        ]
+
+        def format_td(td):
+            td = td.select(*in_keys, strict=False)
+            td.rename_key_(("tokens_in", "input_ids"), token_key)
+            td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key)
+            return td
+
+        module_dict["format"] = WrapModule(
+            format_td,
+            in_keys=in_keys,
+            out_keys=out_keys,
         )
+
     else:
         module_dict["format"] = Mod(
             lambda *x: x,
             in_keys=["log_probs", "tokens_response"],
             out_keys=["log_probs", "tokens_response"],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     return Seq(module_dict, inplace=True)
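
Note on the final hunk: in the generate branch, the "format" step is no longer a pass-through lambda in a Mod; a WrapModule around format_td now selects the relevant keys and renames the ("tokens_in", ...) entries to token_key and attention_mask_key. A minimal sketch of the same pattern with toy keys (illustrative only, not the commit's keys):

import torch
from tensordict import TensorDict
from tensordict.nn import WrapModule

def format_td(td):
    # Keep only the keys of interest, then rename one of them.
    td = td.select("a", strict=False)
    td.rename_key_("a", "b")
    return td

node = WrapModule(format_td, in_keys=["a"], out_keys=["b"])
out = node(TensorDict({"a": torch.ones(2), "c": torch.zeros(2)}, batch_size=[2]))
print(list(out.keys()))  # expected: ["b"]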
