@@ -10,6 +10,8 @@
 import torch
 from tensordict import (
     from_dataclass,
+    lazy_stack,
+    LazyStackedTensorDict,
     maybe_dense_stack,
     NestedKey,
     NonTensorData,
@@ -20,6 +22,7 @@
     TensorDictModule as Mod,
     TensorDictModuleBase,
     TensorDictSequential as Seq,
+    WrapModule,
 )
 from tensordict.utils import _zip_strict
 
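Note: the new `lazy_stack` / `LazyStackedTensorDict` imports back the ragged-output path added further down. A standalone sketch (not part of the patch) of what lazy stacking buys: tensordicts holding sequences of different lengths cannot be stacked densely, but a lazy stack keeps them side by side without padding.

```python
# Standalone sketch: lazy stacking of ragged entries (toy data).
import torch
from tensordict import TensorDict, lazy_stack

a = TensorDict({"token_ids": torch.tensor([1, 2, 3])}, batch_size=[])
b = TensorDict({"token_ids": torch.tensor([4, 5])}, batch_size=[])

# A dense torch.stack over mismatched shapes would raise; lazy_stack
# defers stacking and returns a LazyStackedTensorDict instead.
stacked = lazy_stack([a, b])
print(stacked[0]["token_ids"], stacked[1]["token_ids"])
```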
@@ -61,6 +64,7 @@ def from_vllm(
     generate: bool = True,
     generate_kwargs: dict | None = None,
     tokenizer_kwargs: dict | None = None,
+    pad_output: bool = True,
 ) -> TensorDictModuleBase:
     """Creates a TensorDictModule from a vLLM model.
 
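Note: `pad_output` defaults to `True`, which preserves the previous behavior (dense, padded response tensors); `pad_output=False` keeps responses as lazily stacked ragged entries. A hypothetical call site (model/tokenizer setup elided; every name except `pad_output` is assumed from the surrounding signature):

```python
# Hypothetical usage of the new flag; model and tokenizer are assumed
# to be a vLLM LLM instance and its matching tokenizer.
wrapper = from_vllm(
    model,
    tokenizer=tokenizer,
    generate=True,
    pad_output=False,  # keep responses ragged instead of padding them
)
```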
@@ -151,7 +155,7 @@ def from_vllm(
             out_keys=["tokens_in"],
             method_kwargs=tokenizer_kwargs,
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["encode"] = Mod(
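Note: the recurring `inplace=False` -> `inplace="empty"` switch (here and in the hunks below) changes what each module hands back: as I read tensordict's `TensorDictModule`, `inplace=False` writes outputs into a copy of the input, while `inplace="empty"` writes them into an empty tensordict that keeps batch size and device but none of the input keys. A sketch of that reading (the exact semantics are an assumption, not verified here):

```python
# Assumed semantics of inplace="empty": the module returns a fresh
# tensordict carrying only its out_keys (batch size/device preserved).
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod

mod = Mod(lambda a: a + 1, in_keys=["a"], out_keys=["b"], inplace="empty")
td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
out = mod(td)
print(list(out.keys()))  # expected: ["b"] only, "a" left behind
```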
@@ -164,7 +168,7 @@ def from_vllm(
             in_keys=[text_key, "text_response"],
             out_keys=["tokens_in", "tokens_response"],
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
 
     def select(x, y):
@@ -196,7 +200,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
                 ("tokens_in", "attention_mask"),
             ],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["move_inputs"] = Mod(
@@ -205,7 +209,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
             out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
             # It's ok if there's no mask
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     def to_list(tokens, attention_mask):
@@ -261,13 +265,27 @@ def to_list(tokens, attention_mask):
         strict=True,
     )
 
-    def get_output_tokens_and_log_probs(td):
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
+    def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
+        if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
+            td = lazy_stack(list(td.unbind(0)))
         if generate:
             # When not generate, we don't want to overwrite this
-            td["tokens_response"] = td["tokens_out"].outputs.token_ids
+            tokens_response_td = td["tokens_out"].outputs._tensordict.select(
+                "token_ids", "logprobs", strict=False
+            )
+            if pad_output:
+                tokens_response_td = tokens_response_td.densify(
+                    layout=torch.strided
+                ).to_padded_tensor(padding=padding_value)
+            tokens_response_td.rename_key_("token_ids", "tokens_response")
+            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
-                td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1)
+                tokens_response_td.rename_key_("logprobs", "log_probs")
+                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+            td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
         return td
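Note: the padding path above densifies the lazily stacked response entries into nested tensors, then pads them with the tokenizer's pad-token id. A standalone sketch of the same two calls on toy data, assuming `densify`/`to_padded_tensor` behave as the patch uses them (`padding=0` stands in for the real `padding_value`):

```python
# Standalone sketch of the densify -> to_padded_tensor path used above.
import torch
from tensordict import TensorDict, lazy_stack

ragged = lazy_stack(
    [
        TensorDict({"token_ids": torch.tensor([1, 2, 3])}, batch_size=[]),
        TensorDict({"token_ids": torch.tensor([4, 5])}, batch_size=[]),
    ]
)
# densify(layout=torch.strided) packs the ragged leaves into nested
# tensors; to_padded_tensor then pads every row to the longest one.
dense = ragged.densify(layout=torch.strided).to_padded_tensor(padding=0)
print(dense["token_ids"])  # expected: tensor([[1, 2, 3], [4, 5, 0]])
```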
@@ -296,32 +314,40 @@ def translate_lps(tokens_response, x):
     module_dict["to_source_device"] = _maybe_set_device
 
     if generate:
-        module_dict["format"] = Mod(
-            lambda *x: x,
-            in_keys=[
-                "log_probs",
-                "tokens_response",
-                ("tokens_in", "input_ids"),
-                ("tokens_in", "attention_mask"),
-                "text_response",
-            ],
-            out_keys=[
-                "log_probs",
-                "tokens_response",
-                token_key,
-                attention_mask_key,
-                "text_response",
-            ],
-            strict=False,
-            inplace=False,
+        in_keys = [
+            "log_probs",
+            "tokens_response",
+            ("tokens_in", "input_ids"),
+            ("tokens_in", "attention_mask"),
+            "text_response",
+        ]
+        out_keys = [
+            "log_probs",
+            "tokens_response",
+            token_key,
+            attention_mask_key,
+            "text_response",
+        ]
+
+        def format_td(td):
+            td = td.select(*in_keys, strict=False)
+            td.rename_key_(("tokens_in", "input_ids"), token_key)
+            td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key)
+            return td
+
+        module_dict["format"] = WrapModule(
+            format_td,
+            in_keys=in_keys,
+            out_keys=out_keys,
         )
+
     else:
         module_dict["format"] = Mod(
             lambda *x: x,
             in_keys=["log_probs", "tokens_response"],
             out_keys=["log_probs", "tokens_response"],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     return Seq(module_dict, inplace=True)
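Note: swapping the identity `Mod(lambda *x: x, ...)` for `WrapModule(format_td, ...)` makes the select-and-rename explicit: `WrapModule` wraps a plain `td -> td` function while still advertising `in_keys`/`out_keys` to the enclosing `Seq`. A minimal toy sketch of the same pattern (keys are illustrative, not the patch's):

```python
# Toy WrapModule usage mirroring format_td above: select, then rename.
import torch
from tensordict import TensorDict
from tensordict.nn import WrapModule

def keep_and_rename(td):
    td = td.select("a", strict=False)  # drop everything but "a"
    td.rename_key_("a", "b")           # expose it under the output key
    return td

wrapped = WrapModule(keep_and_rename, in_keys=["a"], out_keys=["b"])
td = TensorDict({"a": torch.ones(2), "c": torch.zeros(2)}, batch_size=[2])
print(list(wrapped(td).keys()))  # expected: ["b"]
```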