 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
+                                                    FlashAttentionMetadata)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 
 logger = init_logger(__name__)
 
@@ -25,12 +26,15 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None,
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
 
+        self.runner = runner
+
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
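
Note: the new runner handle is what the deepseek_mtp branch further down uses to reach the model runner's attn_metadata_builder. A minimal sketch of the assumed call site follows; it is not part of this diff, and the attribute names on the runner side are illustrative only.

# Illustrative only -- assumed wiring in the GPU model runner, not shown in this diff.
# Passing the runner lets the proposer call self.runner.attn_metadata_builder.build(...)
# when self.method == "deepseek_mtp".
self.drafter = EagleProposer(
    vllm_config=self.vllm_config,
    device=self.device,
    runner=self,
)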
@@ -106,24 +110,46 @@ def propose(
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()
 
-        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-        attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
-            use_cascade=False,
-            common_prefix_len=0,
-            cu_prefix_query_lens=None,
-            prefix_kv_lens=None,
-            suffix_kv_lens=None,
-        )
+        if self.method in ["eagle", "eagle3"]:
+            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
+            max_seq_len = seq_lens.max().item()
+            max_num_tokens = (cu_num_tokens[1:] -
+                              cu_num_tokens[:-1]).max().item()
+            attn_metadata = FlashAttentionMetadata(
+                num_actual_tokens=num_tokens,
+                max_query_len=max_num_tokens,
+                query_start_loc=cu_num_tokens,
+                max_seq_len=max_seq_len,
+                seq_lens=seq_lens,
+                block_table=block_table,
+                slot_mapping=target_slot_mapping,
+                # TODO(woosuk): Support cascade attention.
+                use_cascade=False,
+                common_prefix_len=0,
+                cu_prefix_query_lens=None,
+                prefix_kv_lens=None,
+                suffix_kv_lens=None,
+            )
+        elif self.method == "deepseek_mtp":
+            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+            max_query_len = query_lens.max().item()
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+
+            assert self.runner is not None
+
+            # FIXME: need to consider multiple kv_cache_groups
+            attn_metadata = self.runner.attn_metadata_builder.build(
+                num_reqs=batch_size,
+                num_actual_tokens=num_tokens,
+                max_query_len=max_query_len,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise ValueError(f"Unsupported method: {self.method}")
+
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
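
For context, the deepseek_mtp branch above populates only two fields of CommonAttentionMetadata before delegating the rest to the runner's attention-metadata builder. A minimal sketch of the container as used here, assuming only the fields visible in this hunk (the actual vLLM definition may carry more):

# Assumed shape of the lightweight container used above (illustrative; only the
# two fields passed in this diff are shown, not the full vLLM definition).
from dataclasses import dataclass

import torch


@dataclass
class CommonAttentionMetadata:
    query_start_loc: torch.Tensor  # cumulative query lengths, e.g. cu_num_tokens
    seq_lens: torch.Tensor         # per-request sequence lengths (int32, as FA requires)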
@@ -136,11 +162,15 @@ def propose(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
+            ret_hidden_states = self.model(
+                self.input_ids[:num_input_tokens],
+                self.positions[:num_input_tokens],
+                self.hidden_states[:num_input_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
@@ -150,6 +180,10 @@ def propose(
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)
 
+        # TODO: Currently, MTP module released by deepseek only has
+        # one layer. Adapt this code to support multiple layers once
+        # there's a multi-layer MTP module.
+
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
@@ -215,9 +249,9 @@ def propose(
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    input_ids=self.input_ids[:input_batch_size],
-                    positions=self.positions[:input_batch_size],
-                    hidden_states=self.hidden_states[:input_batch_size],
+                    self.input_ids[:input_batch_size],
+                    self.positions[:input_batch_size],
+                    self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -268,7 +302,7 @@ def prepare_inputs(
 
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
-        prepare_input_kernel[(batch_size, )](
+        prepare_eagle_input_kernel[(batch_size, )](
             token_indices,
             cu_target_query_lens,
             cu_num_tokens,
@@ -320,9 +354,9 @@ def dummy_run(
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
-                input_ids=self.input_ids[:num_tokens],
-                positions=self.positions[:num_tokens],
-                hidden_states=self.hidden_states[:num_tokens],
+                self.input_ids[:num_tokens],
+                self.positions[:num_tokens],
+                self.hidden_states[:num_tokens],
             )
 
 
@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
         next_token_ids,
     )
     return next_token_ids, probs
-
-
-@triton.jit
-def prepare_input_kernel(
-    out_ptr,
-    cu_query_lens_ptr,
-    cu_num_tokens_ptr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    # [start_pos, end_pos)
-    start_pos = tl.load(cu_num_tokens_ptr + pid)
-    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
-    num_tokens = end_pos - start_pos
-
-    index_start = tl.load(cu_query_lens_ptr + pid)
-
-    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
-    for i in tl.range(num_blocks):
-        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        tl.store(
-            out_ptr + start_pos + offset,
-            index_start + offset,
-            mask=offset < num_tokens,
-        )
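
Note: the Triton kernel removed above is not dropped from the codebase; earlier hunks in this diff add the import below and rename the launch site in prepare_inputs accordingly, so the kernel has presumably been moved into vllm/v1/spec_decode/utils.py under the new name with the same launch grid and arguments.

# Import added at the top of this file by the same diff:
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel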