@@ -333,14 +333,17 @@ def _fake_top_k_mask_logits(
333
333
334
334
@register_custom_op (
335
335
"flashinfer::chain_speculative_sampling" ,
336
- mutates_args = ("output_accepted_token_num" , "output_emitted_token_num" ),
336
+ mutates_args = (
337
+ "output_accepted_token_num" ,
338
+ "output_emitted_draft_token_num" ,
339
+ ),
337
340
)
338
341
def chain_speculative_sampling (
339
342
draft_probs : torch .Tensor ,
340
343
draft_token_ids : torch .Tensor ,
341
344
target_probs : torch .Tensor ,
342
345
output_accepted_token_num : torch .Tensor ,
343
- output_emitted_token_num : torch .Tensor ,
346
+ output_emitted_draft_token_num : torch .Tensor ,
344
347
deterministic : bool ,
345
348
generator : Optional [torch .Generator ],
346
349
) -> torch .Tensor :
@@ -349,7 +352,7 @@ def chain_speculative_sampling(
349
352
draft_token_ids = draft_token_ids .int ()
350
353
target_probs = target_probs .float ()
351
354
output_accepted_token_num = output_accepted_token_num .int ()
352
- output_emitted_token_num = output_emitted_token_num .int ()
355
+ output_emitted_draft_token_num = output_emitted_draft_token_num .int ()
353
356
b , n = draft_token_ids .shape
354
357
output_token_ids = torch .empty ((b , n + 1 ), dtype = torch .int32 , device = device )
355
358
module .chain_speculative_sampling .default (
@@ -358,7 +361,7 @@ def chain_speculative_sampling(
358
361
target_probs ,
359
362
output_token_ids ,
360
363
output_accepted_token_num ,
361
- output_emitted_token_num ,
364
+ output_emitted_draft_token_num ,
362
365
deterministic ,
363
366
generator ,
364
367
)
@@ -370,7 +373,7 @@ def _fake_chain_speculative_sampling(
370
373
draft_token_ids : torch .Tensor ,
371
374
target_probs : torch .Tensor ,
372
375
output_accepted_token_num : torch .Tensor ,
373
- output_emitted_token_num : torch .Tensor ,
376
+ output_emitted_draft_token_num : torch .Tensor ,
374
377
deterministic : bool ,
375
378
generator : Optional [torch .Generator ],
376
379
) -> torch .Tensor :
@@ -1130,7 +1133,7 @@ def chain_speculative_sampling(
1130
1133
draft_token_ids ,
1131
1134
target_probs ,
1132
1135
maybe_output_accepted_token_num : Optional [torch .Tensor ] = None ,
1133
- maybe_output_emitted_token_num : Optional [torch .Tensor ] = None ,
1136
+ maybe_output_emitted_draft_token_num : Optional [torch .Tensor ] = None ,
1134
1137
deterministic : bool = True ,
1135
1138
generator : Optional [torch .Generator ] = None ,
1136
1139
) -> torch .Tensor :
@@ -1158,8 +1161,10 @@ def chain_speculative_sampling(
1158
1161
It only evaluates the alignment of draft model and target model.
1159
1162
Shape: ``(batch_size)``
1160
1163
If specified, the number of accepted token number will be added to this tensor inplace. Default is ``None``.
1161
- maybe_output_emitted_token_num: Optional[torch.Tensor]
1162
- The number of tokens that are finally emitted/generated for each request.
1164
+ maybe_output_emitted_draft_token_num: Optional[torch.Tensor]
1165
+ The number of draft tokens that are finally emitted for each request. Does not include
1166
+ the bonus token. (Thus the total number of tokens sampled for a given request is
1167
+ output_emitted_draft_token_num + 1).
1163
1168
Shape: ``(batch_size)``
1164
1169
If specified, the number of emitted token number will be added to this tensor inplace. Default is ``None``.
1165
1170
deterministic: bool
@@ -1182,8 +1187,10 @@ def chain_speculative_sampling(
1182
1187
satisfy the probability requirement r < p/q.
1183
1188
It only evaluates the alignment of draft model and target model.
1184
1189
Shape: ``(batch_size)``
1185
- output_emitted_token_num: torch.Tensor
1186
- The number of tokens that are finally emitted/generated for each request.
1190
+ output_emitted_draft_token_num: torch.Tensor
1191
+ The number of draft tokens that are finally emitted for each request. Does not include
1192
+ the bonus token. (Thus the total number of tokens sampled for a given request is
1193
+ output_emitted_draft_token_num + 1).
1187
1194
Shape: ``(batch_size)``
1188
1195
1189
1196
Examples
@@ -1200,7 +1207,7 @@ def chain_speculative_sampling(
1200
1207
>>> # token 1 was sampled from draft model for the second token
1201
1208
>>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
1202
1209
>>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
1203
- >>> output_token_ids, output_accepted_token_num, output_accepted_token_num =\
1210
+ >>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\
1204
1211
... flashinfer.sampling.chain_speculative_sampling(
1205
1212
... draft_probs, draft_token_ids, target_probs)
1206
1213
>>> # the first token is accepted, the second token is rejected and sampled from the difference
@@ -1209,7 +1216,7 @@ def chain_speculative_sampling(
1209
1216
tensor([[ 2, 0, -1]], device='cuda:0', dtype=torch.int32)
1210
1217
>>> output_accepted_token_num
1211
1218
tensor([1], device='cuda:0')
1212
- >>> output_emitted_token_num
1219
+ >>> output_emitted_draft_token_num
1213
1220
tensor([1], device='cuda:0')
1214
1221
"""
1215
1222
b = draft_probs .size (0 )
@@ -1218,17 +1225,17 @@ def chain_speculative_sampling(
1218
1225
output_accepted_token_num = torch .zeros (b , dtype = torch .int32 , device = dev )
1219
1226
else :
1220
1227
output_accepted_token_num = maybe_output_accepted_token_num
1221
- if maybe_output_emitted_token_num is None :
1222
- output_emitted_token_num = torch .zeros (b , dtype = torch .int32 , device = dev )
1228
+ if maybe_output_emitted_draft_token_num is None :
1229
+ output_emitted_draft_token_num = torch .zeros (b , dtype = torch .int32 , device = dev )
1223
1230
else :
1224
- output_emitted_token_num = maybe_output_emitted_token_num
1231
+ output_emitted_draft_token_num = maybe_output_emitted_draft_token_num
1225
1232
output_token_ids = get_sampling_module ().chain_speculative_sampling (
1226
1233
draft_probs ,
1227
1234
draft_token_ids ,
1228
1235
target_probs ,
1229
1236
output_accepted_token_num ,
1230
- output_emitted_token_num ,
1237
+ output_emitted_draft_token_num ,
1231
1238
deterministic ,
1232
1239
generator ,
1233
1240
)
1234
- return output_token_ids , output_accepted_token_num , output_emitted_token_num
1241
+ return output_token_ids , output_accepted_token_num , output_emitted_draft_token_num
0 commit comments