
Commit 86da6b8

jon-chuang and yzh119 authored
misc: Rename output_emitted_token_num -> output_emitted_draft_token_num (flashinfer-ai#977)
With the old naming it was easy to make an off-by-one error; the new naming makes the meaning clear. See the special handling this unclear API required in vLLM: https://github.com/vllm-project/vllm/blob/280d074103160d042059dc60c28898fd9fb56568/vllm/model_executor/layers/rejection_sampler.py#L139

Co-authored-by: Zihao Ye <[email protected]>
1 parent bb028cc commit 86da6b8
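
The off-by-one risk the rename removes can be seen in a minimal sketch (illustrative only; the tensor values below are hypothetical and not part of this commit). The counter counts emitted draft tokens, so the total number of tokens produced per request is the counter plus one bonus token:

    # Illustrative sketch of the renamed counter's semantics (hypothetical values).
    import torch

    # For n draft tokens, output_token_ids has n + 1 slots; unused slots hold -1.
    output_token_ids = torch.tensor([7, 3, 9, -1, -1])  # n = 4 draft slots + 1 bonus
    output_emitted_draft_token_num = 2                  # draft tokens kept

    # The old name "output_emitted_token_num" read as "tokens generated",
    # but the counter excludes the bonus token. The total is counter + 1:
    total_generated = output_emitted_draft_token_num + 1  # 3
    valid_ids = output_token_ids[:total_generated]        # tensor([7, 3, 9])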

File tree

5 files changed: +34 -27 lines changed

  csrc/flashinfer_ops.cu
  csrc/flashinfer_sampling_ops.cu
  csrc/sampling.cu
  flashinfer/sampling.py
  include/flashinfer/sampling.cuh

csrc/flashinfer_ops.cu (+1 -1)

@@ -209,7 +209,7 @@ void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor target_probs, at::Tensor output_token_ids,
                                 at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
                                 std::optional<at::Generator> gen);
 
 //========== Torch Library ==========

csrc/flashinfer_sampling_ops.cu (+1 -1)

@@ -52,7 +52,7 @@ void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor target_probs, at::Tensor output_token_ids,
                                 at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
                                 std::optional<at::Generator> gen);
 
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {

csrc/sampling.cu (+4 -4)

@@ -186,7 +186,7 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor target_probs, at::Tensor output_token_ids,
                                 at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
                                 std::optional<at::Generator> gen_) {
   CHECK_INPUT(draft_probs);
   CHECK_INPUT(draft_token_ids);
@@ -205,7 +205,7 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i
   CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
   CHECK_EQ(vocab_size, target_probs.size(2));
   CHECK_EQ(batch_size, output_accepted_token_num.size(0));
-  CHECK_EQ(batch_size, output_emitted_token_num.size(0));
+  CHECK_EQ(batch_size, output_emitted_draft_token_num.size(0));
   uint64_t philox_seed, philox_offset;
   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
       gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -221,8 +221,8 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i
       static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
       static_cast<float*>(target_probs.data_ptr()), static_cast<int*>(output_token_ids.data_ptr()),
       static_cast<int*>(output_accepted_token_num.data_ptr()),
-      static_cast<int*>(output_emitted_token_num.data_ptr()), batch_size, num_speculate_tokens,
-      vocab_size, deterministic, philox_seed, philox_offset, stream);
+      static_cast<int*>(output_emitted_draft_token_num.data_ptr()), batch_size,
+      num_speculate_tokens, vocab_size, deterministic, philox_seed, philox_offset, stream);
 
   TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
                                          std::string(cudaGetErrorString(status)));
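
The CHECK_EQ calls above pin down the expected tensor shapes. A minimal sketch of conforming inputs (shapes inferred from the checks; the names batch_size, num_speculate_tokens, and vocab_size follow the C++ code, the concrete values are arbitrary):

    import torch

    batch_size, num_speculate_tokens, vocab_size = 2, 3, 128

    draft_probs = torch.rand(batch_size, num_speculate_tokens, vocab_size)
    draft_token_ids = torch.randint(
        vocab_size, (batch_size, num_speculate_tokens), dtype=torch.int32
    )
    # Target probs cover one extra position so a bonus token can be sampled:
    target_probs = torch.rand(batch_size, num_speculate_tokens + 1, vocab_size)
    # Per-request counters, one entry per batch element:
    output_accepted_token_num = torch.zeros(batch_size, dtype=torch.int32)
    output_emitted_draft_token_num = torch.zeros(batch_size, dtype=torch.int32)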

flashinfer/sampling.py (+24 -17)

@@ -333,14 +333,17 @@ def _fake_top_k_mask_logits(
 
 @register_custom_op(
     "flashinfer::chain_speculative_sampling",
-    mutates_args=("output_accepted_token_num", "output_emitted_token_num"),
+    mutates_args=(
+        "output_accepted_token_num",
+        "output_emitted_draft_token_num",
+    ),
 )
 def chain_speculative_sampling(
     draft_probs: torch.Tensor,
     draft_token_ids: torch.Tensor,
     target_probs: torch.Tensor,
     output_accepted_token_num: torch.Tensor,
-    output_emitted_token_num: torch.Tensor,
+    output_emitted_draft_token_num: torch.Tensor,
     deterministic: bool,
     generator: Optional[torch.Generator],
 ) -> torch.Tensor:
@@ -349,7 +352,7 @@ def chain_speculative_sampling(
     draft_token_ids = draft_token_ids.int()
     target_probs = target_probs.float()
     output_accepted_token_num = output_accepted_token_num.int()
-    output_emitted_token_num = output_emitted_token_num.int()
+    output_emitted_draft_token_num = output_emitted_draft_token_num.int()
     b, n = draft_token_ids.shape
     output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device)
     module.chain_speculative_sampling.default(
@@ -358,7 +361,7 @@ def chain_speculative_sampling(
         target_probs,
         output_token_ids,
         output_accepted_token_num,
-        output_emitted_token_num,
+        output_emitted_draft_token_num,
         deterministic,
         generator,
     )
@@ -370,7 +373,7 @@ def _fake_chain_speculative_sampling(
     draft_token_ids: torch.Tensor,
     target_probs: torch.Tensor,
     output_accepted_token_num: torch.Tensor,
-    output_emitted_token_num: torch.Tensor,
+    output_emitted_draft_token_num: torch.Tensor,
     deterministic: bool,
     generator: Optional[torch.Generator],
 ) -> torch.Tensor:
@@ -1130,7 +1133,7 @@ def chain_speculative_sampling(
     draft_token_ids,
     target_probs,
     maybe_output_accepted_token_num: Optional[torch.Tensor] = None,
-    maybe_output_emitted_token_num: Optional[torch.Tensor] = None,
+    maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None,
     deterministic: bool = True,
     generator: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
@@ -1158,8 +1161,10 @@ def chain_speculative_sampling(
         It only evaluates the alignment of draft model and target model.
         Shape: ``(batch_size)``
         If specified, the number of accepted token number will be added to this tensor inplace. Default is ``None``.
-    maybe_output_emitted_token_num: Optional[torch.Tensor]
-        The number of tokens that are finally emitted/generated for each request.
+    maybe_output_emitted_draft_token_num: Optional[torch.Tensor]
+        The number of draft tokens that are finally emitted for each request. Does not include
+        the bonus token. (Thus the total number of tokens sampled for a given request is
+        output_emitted_draft_token_num + 1).
         Shape: ``(batch_size)``
         If specified, the number of emitted token number will be added to this tensor inplace. Default is ``None``.
     deterministic: bool
@@ -1182,8 +1187,10 @@ def chain_speculative_sampling(
         satisfy the probability requirement r < p/q.
         It only evaluates the alignment of draft model and target model.
         Shape: ``(batch_size)``
-    output_emitted_token_num: torch.Tensor
-        The number of tokens that are finally emitted/generated for each request.
+    output_emitted_draft_token_num: torch.Tensor
+        The number of draft tokens that are finally emitted for each request. Does not include
+        the bonus token. (Thus the total number of tokens sampled for a given request is
+        output_emitted_draft_token_num + 1).
         Shape: ``(batch_size)``
 
     Examples
@@ -1200,7 +1207,7 @@ def chain_speculative_sampling(
     >>> # token 1 was sampled from draft model for the second token
     >>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
     >>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
-    >>> output_token_ids, output_accepted_token_num, output_accepted_token_num =\
+    >>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\
     ...     flashinfer.sampling.chain_speculative_sampling(
     ...         draft_probs, draft_token_ids, target_probs)
     >>> # the first token is accepted, the second token is rejected and sampled from the difference
@@ -1209,7 +1216,7 @@ def chain_speculative_sampling(
     tensor([[ 2, 0, -1]], device='cuda:0', dtype=torch.int32)
     >>> output_accepted_token_num
     tensor([1], device='cuda:0')
-    >>> output_emitted_token_num
+    >>> output_emitted_draft_token_num
    tensor([1], device='cuda:0')
     """
     b = draft_probs.size(0)
@@ -1218,17 +1225,17 @@ def chain_speculative_sampling(
         output_accepted_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
     else:
         output_accepted_token_num = maybe_output_accepted_token_num
-    if maybe_output_emitted_token_num is None:
-        output_emitted_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
+    if maybe_output_emitted_draft_token_num is None:
+        output_emitted_draft_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
     else:
-        output_emitted_token_num = maybe_output_emitted_token_num
+        output_emitted_draft_token_num = maybe_output_emitted_draft_token_num
     output_token_ids = get_sampling_module().chain_speculative_sampling(
         draft_probs,
         draft_token_ids,
         target_probs,
         output_accepted_token_num,
-        output_emitted_token_num,
+        output_emitted_draft_token_num,
         deterministic,
         generator,
     )
-    return output_token_ids, output_accepted_token_num, output_emitted_token_num
+    return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num
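
End to end, the renamed return value is consumed like this; a sketch assuming a CUDA device and a flashinfer build that includes this commit (the probability tensors are random placeholders):

    import torch
    import flashinfer

    batch_size, num_draft, vocab_size = 1, 2, 4
    dev = "cuda:0"

    draft_probs = torch.rand(batch_size, num_draft, vocab_size, device=dev)
    draft_probs /= draft_probs.sum(-1, keepdim=True)
    draft_token_ids = torch.randint(
        vocab_size, (batch_size, num_draft), dtype=torch.int32, device=dev
    )
    target_probs = torch.rand(batch_size, num_draft + 1, vocab_size, device=dev)
    target_probs /= target_probs.sum(-1, keepdim=True)

    output_token_ids, accepted_num, emitted_draft_num = (
        flashinfer.sampling.chain_speculative_sampling(
            draft_probs, draft_token_ids, target_probs
        )
    )
    # Total tokens produced per request = emitted draft tokens + 1 bonus token;
    # trailing slots of output_token_ids are padded with -1.
    num_generated = emitted_draft_num + 1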

include/flashinfer/sampling.cuh (+4 -4)

@@ -1383,7 +1383,7 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
 __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                          DType* target_probs, IdType* output_token_ids,
                                          IdType* output_accepted_token_num,
-                                         IdType* output_emitted_token_num,
+                                         IdType* output_emitted_draft_token_num,
                                          uint32_t num_speculative_tokens, uint32_t d,
                                          uint64_t philox_seed, uint64_t philox_offset) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
@@ -1427,7 +1427,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
 
   if (tx == 0) {
     output_accepted_token_num[row_idx] += accepted_token_num;
-    output_emitted_token_num[row_idx] += emitted_token_num;
+    output_emitted_draft_token_num[row_idx] += emitted_token_num;
   }
 
   // sample from relu(target_probs - draft_probs)
@@ -1517,7 +1517,7 @@ template <typename DType, typename IdType>
 cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                      DType* target_probs, IdType* output_token_ids,
                                      IdType* output_accepted_token_num,
-                                     IdType* output_emitted_token_num, uint32_t batch_size,
+                                     IdType* output_emitted_draft_token_num, uint32_t batch_size,
                                      uint32_t num_speculative_tokens, uint32_t d,
                                      bool deterministic, uint64_t philox_seed,
                                      uint64_t philox_offset, cudaStream_t stream = 0) {
@@ -1532,7 +1532,7 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids
       &target_probs,
       &output_token_ids,
       &output_accepted_token_num,
-      &output_emitted_token_num,
+      &output_emitted_draft_token_num,
       &num_speculative_tokens,
       &d,
       &philox_seed,
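
Note that the kernel accumulates with += rather than assigning, matching the docstring's "added to this tensor inplace". A short sketch (hypothetical counts, simulating the kernel's updates with plain tensor ops) of what that means for callers who reuse the counter tensors:

    import torch

    emitted_draft = torch.zeros(2, dtype=torch.int32)

    # Suppose a first call emits [1, 2] draft tokens per request:
    emitted_draft += torch.tensor([1, 2], dtype=torch.int32)  # -> [1, 2]
    # Reusing the same tensor on a second call that emits [2, 0]:
    emitted_draft += torch.tensor([2, 0], dtype=torch.int32)  # -> [3, 2]

    # Zero the tensor between calls if per-call counts are needed:
    emitted_draft.zero_()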
