Commit 2628a69

[V1] Support Deepseek MTP (#18435)
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
1 parent 371f7e4 commit 2628a69

File tree

6 files changed, +120 -66 lines changed


vllm/config.py

Lines changed: 11 additions & 2 deletions
@@ -2255,7 +2255,7 @@ def __post_init__(self):
 
 
 SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model"]
+                            "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]
 

@@ -2519,6 +2519,15 @@ def __post_init__(self):
             elif (self.draft_model_config.hf_config.model_type ==
                   "mlp_speculator"):
                 self.method = "mlp_speculator"
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "deepseek_mtp"):
+                self.method = "deepseek_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                            "All Deepseek MTP models only have " \
+                            "one layer. Might need some code changes " \
+                            "to support multiple layers."
+                        )
             else:
                 self.method = "draft_model"
 

@@ -2738,7 +2747,7 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp")
 
     def __repr__(self) -> str:
         method = self.method
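
For orientation, the sketch below shows how the newly registered method could be exercised from the offline API. It is a minimal sketch, assuming the LLM entrypoint and the speculative_config dict argument available in recent vLLM releases; the model name, prompt, and exact keys are illustrative and not part of this commit (the method may also be auto-detected as shown in the hunk above).

from vllm import LLM, SamplingParams

# Minimal sketch (assumptions noted above). DeepSeek checkpoints ship the MTP
# head, and the released MTP module has a single layer, so one speculative
# token per step is the setting this commit warns about exceeding.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",      # illustrative model name
    speculative_config={
        "method": "deepseek_mtp",         # assumed key; may be auto-detected
        "num_speculative_tokens": 1,
    },
)
out = llm.generate(["The capital of France is"],
                   SamplingParams(max_tokens=16, temperature=0))
print(out[0].outputs[0].text)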

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1338,7 +1338,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                     is_ngram_enabled = True
                 elif speculative_method == "medusa":
                     is_medusa_enabled = True
-                elif speculative_method in ("eagle", "eagle3"):
+                elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 
 from .deepseek_v2 import (DeepseekV2DecoderLayer,
                           get_spec_layer_idx_from_weight_name)
+from .interfaces import SupportsPP
 from .utils import maybe_prefix
 
 

@@ -145,7 +146,7 @@ def compute_logits(
         return logits
 
 
-class DeepSeekMTP(nn.Module):
+class DeepSeekMTP(nn.Module, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
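
For readers unfamiliar with the model this file implements: a DeepSeek MTP layer fuses the embedding of the shifted input token with the target model's hidden state, runs one decoder block, and projects to logits. The sketch below is a conceptual, simplified rendering of that design only; names such as enorm/hnorm/eh_proj follow the released checkpoints, but the layer types and dimensions are stand-ins (LayerNorm instead of RMSNorm, a generic Transformer layer instead of DeepseekV2DecoderLayer) and this is not vLLM's implementation.

import torch
import torch.nn as nn


class MTPLayerSketch(nn.Module):
    """Conceptual sketch of a single DeepSeek MTP layer (not vLLM code)."""

    def __init__(self, hidden_size: int, vocab_size: int, nhead: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.enorm = nn.LayerNorm(hidden_size)   # normalizes the token embedding
        self.hnorm = nn.LayerNorm(hidden_size)   # normalizes the target hidden state
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size)
        # Stand-in for the single decoder block used by the MTP module.
        self.block = nn.TransformerEncoderLayer(hidden_size, nhead,
                                                batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor,
                prev_hidden_states: torch.Tensor) -> torch.Tensor:
        # Fuse the embedding of the shifted token with the target model's
        # hidden state, run one block, and project to draft logits.
        fused = self.eh_proj(
            torch.cat([self.enorm(self.embed(input_ids)),
                       self.hnorm(prev_hidden_states)], dim=-1))
        return self.head(self.block(fused))


# Toy usage: batch of 2, sequence of 4, hidden size 64, vocab 128.
sketch = MTPLayerSketch(hidden_size=64, vocab_size=128)
logits = sketch(torch.randint(0, 128, (2, 4)), torch.randn(2, 4, 64))
print(logits.shape)  # torch.Size([2, 4, 128])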

vllm/v1/spec_decode/eagle.py

Lines changed: 65 additions & 57 deletions
@@ -10,9 +10,10 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
+                                                   FlashAttentionMetadata)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 
 logger = init_logger(__name__)
 

@@ -25,12 +26,15 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None,
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
 
+        self.runner = runner
+
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size

@@ -106,24 +110,46 @@ def propose(
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()
 
-        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-        attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
-            use_cascade=False,
-            common_prefix_len=0,
-            cu_prefix_query_lens=None,
-            prefix_kv_lens=None,
-            suffix_kv_lens=None,
-        )
+        if self.method in ["eagle", "eagle3"]:
+            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
+            max_seq_len = seq_lens.max().item()
+            max_num_tokens = (cu_num_tokens[1:] -
+                              cu_num_tokens[:-1]).max().item()
+            attn_metadata = FlashAttentionMetadata(
+                num_actual_tokens=num_tokens,
+                max_query_len=max_num_tokens,
+                query_start_loc=cu_num_tokens,
+                max_seq_len=max_seq_len,
+                seq_lens=seq_lens,
+                block_table=block_table,
+                slot_mapping=target_slot_mapping,
+                # TODO(woosuk): Support cascade attention.
+                use_cascade=False,
+                common_prefix_len=0,
+                cu_prefix_query_lens=None,
+                prefix_kv_lens=None,
+                suffix_kv_lens=None,
+            )
+        elif self.method == "deepseek_mtp":
+            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+            max_query_len = query_lens.max().item()
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+
+            assert self.runner is not None
+
+            # FIXME: need to consider multiple kv_cache_groups
+            attn_metadata = self.runner.attn_metadata_builder.build(
+                num_reqs=batch_size,
+                num_actual_tokens=num_tokens,
+                max_query_len=max_query_len,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise ValueError(f"Unsupported method: {self.method}")
+
         if self.use_cuda_graph and \
             num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)

@@ -136,11 +162,15 @@ def propose(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
+            ret_hidden_states = self.model(
+                self.input_ids[:num_input_tokens],
+                self.positions[:num_input_tokens],
+                self.hidden_states[:num_input_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)

@@ -150,6 +180,10 @@ def propose(
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)
 
+        # TODO: Currently, MTP module released by deepseek only has
+        # one layer. Adapt this code to support multiple layers once
+        # there's a multi-layer MTP module.
+
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 

@@ -215,9 +249,9 @@ def propose(
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    input_ids=self.input_ids[:input_batch_size],
-                    positions=self.positions[:input_batch_size],
-                    hidden_states=self.hidden_states[:input_batch_size],
+                    self.input_ids[:input_batch_size],
+                    self.positions[:input_batch_size],
+                    self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],

@@ -268,7 +302,7 @@ def prepare_inputs(
 
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
-        prepare_input_kernel[(batch_size, )](
+        prepare_eagle_input_kernel[(batch_size, )](
             token_indices,
             cu_target_query_lens,
             cu_num_tokens,

@@ -320,9 +354,9 @@ def dummy_run(
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
-                input_ids=self.input_ids[:num_tokens],
-                positions=self.positions[:num_tokens],
-                hidden_states=self.hidden_states[:num_tokens],
+                self.input_ids[:num_tokens],
+                self.positions[:num_tokens],
+                self.hidden_states[:num_tokens],
             )
 
 

@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
         next_token_ids,
     )
     return next_token_ids, probs
-
-
-@triton.jit
-def prepare_input_kernel(
-    out_ptr,
-    cu_query_lens_ptr,
-    cu_num_tokens_ptr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    # [start_pos, end_pos)
-    start_pos = tl.load(cu_num_tokens_ptr + pid)
-    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
-    num_tokens = end_pos - start_pos
-
-    index_start = tl.load(cu_query_lens_ptr + pid)
-
-    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
-    for i in tl.range(num_blocks):
-        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        tl.store(
-            out_ptr + start_pos + offset,
-            index_start + offset,
-            mask=offset < num_tokens,
-        )
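
One behavioral detail worth calling out in propose(): the draft forward now returns a single tensor for deepseek_mtp but a (last_hidden_states, hidden_states) tuple for eagle/eagle3. The helper below is a small hypothetical refactor, not part of this commit, that expresses the same normalization outside the hot path:

from typing import Tuple, Union

import torch


def split_draft_outputs(
    ret_hidden_states: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    method: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Normalize draft-model outputs to (last_hidden_states, hidden_states).

    For deepseek_mtp the single MTP layer yields only the final hidden states;
    returning them twice is just one way to keep the call site uniform (the
    multi-token drafting loop is not exercised for this method in the commit).
    """
    if method == "deepseek_mtp":
        assert isinstance(ret_hidden_states, torch.Tensor)
        return ret_hidden_states, ret_hidden_states
    last_hidden_states, hidden_states = ret_hidden_states
    return last_hidden_states, hidden_states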

vllm/v1/spec_decode/utils.py

Lines changed: 27 additions & 0 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from vllm.triton_utils import tl, triton
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 

@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
         return False
 
     return True
+
+
+@triton.jit
+def prepare_eagle_input_kernel(
+    out_ptr,
+    cu_query_lens_ptr,
+    cu_num_tokens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    # [start_pos, end_pos)
+    start_pos = tl.load(cu_num_tokens_ptr + pid)
+    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
+    num_tokens = end_pos - start_pos
+
+    index_start = tl.load(cu_query_lens_ptr + pid)
+
+    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
+    for i in tl.range(num_blocks):
+        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        tl.store(
+            out_ptr + start_pos + offset,
+            index_start + offset,
+            mask=offset < num_tokens,
+        )
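
The relocated kernel is easier to follow alongside a plain-PyTorch statement of what it computes. The reference below is an illustrative equivalent for reading (a CPU loop, not a replacement for the Triton kernel): for request i it writes cu_query_lens[i] + [0, 1, ..., n_i - 1] into out[cu_num_tokens[i]:cu_num_tokens[i + 1]], where n_i is the number of tokens kept for that request after dropping rejected draft tokens.

import torch


def prepare_eagle_input_reference(cu_query_lens: torch.Tensor,
                                  cu_num_tokens: torch.Tensor) -> torch.Tensor:
    """Readability equivalent of prepare_eagle_input_kernel (illustrative)."""
    out = torch.empty(int(cu_num_tokens[-1]), dtype=torch.int64)
    for i in range(cu_num_tokens.numel() - 1):
        start, end = int(cu_num_tokens[i]), int(cu_num_tokens[i + 1])
        # Contiguous run of target-token indices for request i.
        out[start:end] = int(cu_query_lens[i]) + torch.arange(end - start)
    return out


# Example: two requests keeping 3 and 2 tokens, whose target queries start at
# offsets 0 and 4 -> indices [0, 1, 2, 4, 5].
print(prepare_eagle_input_reference(torch.tensor([0, 4, 7]),
                                    torch.tensor([0, 3, 5])))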

vllm/v1/worker/gpu_model_runner.py

Lines changed: 14 additions & 5 deletions
@@ -151,12 +151,16 @@ def __init__(
         self.use_aux_hidden_state_outputs = False
         if self.speculative_config:
             self.use_spec_decode = True
+
+            # NOTE(Jiayi): currently we put the entire draft model on
+            # the last PP rank. This is not ideal if there are many
+            # layers in the draft model.
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
                 elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config,
-                                                 self.device)  # type: ignore
+                    self.drafter = EagleProposer(self.vllm_config, self.device,
+                                                 self)  # type: ignore
                     if self.speculative_config.method == "eagle3":
                         self.use_aux_hidden_state_outputs = True
                 elif self.speculative_config.method == "medusa":

@@ -1361,6 +1365,12 @@ def execute_model(
                     device=self.device)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
+            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
+            if hasattr(eagle_attn_metadata, "block_table"):
+                block_table = eagle_attn_metadata.block_table
+            else:
+                block_table = None
+
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]

@@ -1406,7 +1416,7 @@ def execute_model(
                 target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
                 cu_num_tokens=cu_num_tokens,
-                block_table=eagle_attn_metadata.block_table,
+                block_table=block_table,
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()

@@ -1723,8 +1733,7 @@ def _dummy_run(
         else:
             hidden_states = outputs
 
-        if self.use_spec_decode and \
-            self.speculative_config.method in ('eagle', 'eagle3'):
+        if self.use_spec_decode and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)

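Design note on the block_table fallback above: the MLA attention metadata used by deepseek_mtp carries no block_table, hence the hasattr check before handing it to the drafter. A more compact equivalent, shown only as an illustration and not part of the commit, would be:

# Hypothetical one-liner with the same behavior as the hasattr/else branch:
block_table = getattr(eagle_attn_metadata, "block_table", None)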