
Commit 4223720

[SLM] Deepseek Multi-GPU support (#2988)
This PR adds tensor parallelism (TP) support for the Deepseek model.
Parent: 17c4e08 · Commit: 4223720

File tree: 1 file changed (+72, -11 lines)


python/mlc_llm/model/deepseek/deepseek_model.py

Lines changed: 72 additions & 11 deletions
@@ -15,6 +15,7 @@
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold

@@ -48,6 +49,7 @@ class DeepseekConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     context_window_size: int = 0
     prefill_chunk_size: int = 0
     tensor_parallel_shards: int = 1
+    head_dim: int = 0
     max_batch_size: int = 1
     num_experts_per_tok: int = 0
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -70,6 +72,9 @@ def __post_init__(self):
                     "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                     "provided in `config.json`."
                 )
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %d",
@@ -85,7 +90,6 @@ def __post_init__(self):
                 min(self.context_window_size, 2048),
             )
             self.prefill_chunk_size = min(self.context_window_size, 2048)
-        assert self.tensor_parallel_shards == 1, "Deepseek currently does not support sharding."


 # pylint: disable=invalid-name,missing-docstring
@@ -96,17 +100,18 @@ def __init__(self, config: DeepseekConfig):
         super().__init__()  # Make sure to call the parent class constructor
         self.hidden_size = config.hidden_size
         self.rope_theta = config.rope_theta
+        self.tensor_parallel_shards = config.tensor_parallel_shards
         if config.num_attention_heads % config.tensor_parallel_shards != 0:
             raise ValueError(
                 f"Cannot split {config.num_attention_heads} attention heads "
                 f"evenly to {config.tensor_parallel_shards} GPUs."
             )

         self.attention_bias = config.attention_bias
-        self.num_heads = config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
+        self.num_heads = config.num_attention_heads // self.tensor_parallel_shards
+        self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = config.head_dim
         self.max_position_embeddings = config.context_window_size

         self.wqkv_pack = nn.Linear(
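For reference, a minimal sketch of the per-shard head arithmetic this hunk introduces; the sizes are hypothetical and chosen so every division is exact:

# Hypothetical sizes; real checkpoints may differ.
num_attention_heads = 16
num_key_value_heads = 16
head_dim = 128
tensor_parallel_shards = 2

# Each GPU owns an equal slice of the query and key/value heads.
num_heads_per_shard = num_attention_heads // tensor_parallel_shards      # 8
num_kv_heads_per_shard = num_key_value_heads // tensor_parallel_shards   # 8

# Output width of the packed QKV projection on one shard (q, k, v segments).
q = num_heads_per_shard * head_dim      # 1024
k = num_kv_heads_per_shard * head_dim   # 1024
v = num_kv_heads_per_shard * head_dim   # 1024
print(q + k + v)                        # 3072 output features per shard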
@@ -150,7 +155,7 @@ def __init__(self, config: DeepseekConfig, intermediate_size=None):
         )
         self.intermediate_size = (
             config.intermediate_size if intermediate_size is None else intermediate_size
-        )
+        ) // config.tensor_parallel_shards

         self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
@@ -168,16 +173,18 @@ def __init__(self, config: DeepseekConfig):
         self.num_experts_per_tok = config.num_experts_per_tok
         self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)
         self.norm_topk_prob = config.norm_topk_prob
-        self.moe_intermediate_size = config.moe_intermediate_size
+        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
         self.moe_gate_up_proj = MixtralExperts(
             self.num_local_experts,
             in_features=config.hidden_size,
             out_features=2 * self.moe_intermediate_size,
+            tensor_parallel_shards=config.tensor_parallel_shards,
         )
         self.moe_down_proj = MixtralExperts(
             self.num_local_experts,
             in_features=self.moe_intermediate_size,
             out_features=config.hidden_size,
+            tensor_parallel_shards=config.tensor_parallel_shards,
         )
         self.dtype = "float32"

@@ -254,15 +261,64 @@ def __init__(self, config: DeepseekConfig, layer_idx: int):
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_heads * hd
+            k = self.self_attn.num_key_value_heads * hd
+            v = self.self_attn.num_key_value_heads * hd
+
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            ):
+                i = self.mlp.moe_intermediate_size
+            else:
+                i = self.mlp.intermediate_size
+            _set(
+                self.self_attn.wqkv_pack.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            ):
+                _set(
+                    self.mlp.moe_gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=1),
+                )
+                _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=2))
+
+            else:
+                _set(
+                    self.mlp.gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0),
+                )
+                _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
         out = self.self_attn(out, paged_kv_cache, layer_id)
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(hidden_states, residual=out)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(hidden_states, residual=out)
         return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

 class DeepseekModel(nn.Module):
     def __init__(self, config: DeepseekConfig):
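A minimal NumPy sketch of the sharding pattern these hints describe, with made-up shapes: the up projection is split along its output dimension (column parallel), the down projection along its input dimension (row parallel), and the per-shard partial outputs are summed, which is the role op.ccl_allreduce plays in _apply_residual. The gating nonlinearity is omitted for brevity; it is elementwise over the intermediate dimension, so it stays shard-local. This illustrates the technique only and does not use the mlc_llm runtime.

import numpy as np

rng = np.random.default_rng(0)
shards, hidden, inter = 2, 8, 16           # made-up sizes; inter is divisible by shards

x = rng.normal(size=hidden)                # one token's hidden state, replicated on every GPU
residual = x.copy()
w_up = rng.normal(size=(inter, hidden))    # up projection weight
w_down = rng.normal(size=(hidden, inter))  # down projection weight

# Reference: unsharded MLP output plus residual.
ref = residual + w_down @ (w_up @ x)

# Sharded: w_up is split along dim 0, w_down along dim 1 (matching the shard hints above).
partials = []
for s in range(shards):
    lo, hi = s * inter // shards, (s + 1) * inter // shards
    act = w_up[lo:hi] @ x                    # this shard's slice of the intermediate activation
    partials.append(w_down[:, lo:hi] @ act)  # partial contribution to the layer output

# "Allreduce": every shard ends up with the sum of all partials, then adds the residual.
out = residual + sum(partials)
print(np.allclose(ref, out))               # True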
@@ -293,7 +349,8 @@ def __init__(self, config: DeepseekConfig):
         self.hidden_size = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        self.head_dim = config.head_dim
         self.vocab_size = config.vocab_size
         self.rope_theta = config.rope_theta
         self.dtype = "float32"
@@ -320,6 +377,8 @@ def batch_forward(
         return logits

     def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
         return self.model.embed_tokens(input_ids)

     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
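The broadcasts added in this and the next hunk suggest that only worker 0 receives the raw inputs, which then have to be copied to the other workers before the (replicated) embedding lookup. A tiny pure-Python stand-in for that data flow, not the actual op.ccl_broadcast_from_worker0 primitive:

# Hypothetical two-worker setup: only rank 0 starts with the real token ids.
workers = [{"rank": r, "input_ids": None} for r in range(2)]
workers[0]["input_ids"] = [1, 2077, 318]

# "Broadcast from worker 0": every rank ends up holding the same ids.
for w in workers[1:]:
    w["input_ids"] = list(workers[0]["input_ids"])

assert all(w["input_ids"] == workers[0]["input_ids"] for w in workers)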
@@ -349,6 +408,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
     def batch_prefill(
         self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
     ):
+        if self.tensor_parallel_shards > 1:
+            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
         logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
         return logits, paged_kv_cache

@@ -375,8 +436,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             page_size=page_size,
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
+            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
             head_dim=self.head_dim,
             rope_mode=RopeMode.NORMAL,
             rope_scale=1,
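With the per-shard head counts above, each GPU's paged KV cache only stores the key/value heads it owns. A rough sizing sketch with hypothetical numbers (fp16 cache, ignoring paging overhead):

# Hypothetical sizes; real models and runtimes will differ.
num_hidden_layers = 28
num_key_value_heads = 16
head_dim = 128
context_tokens = 4096
bytes_per_elem = 2            # fp16
tensor_parallel_shards = 2

kv_heads_per_gpu = num_key_value_heads // tensor_parallel_shards   # 8
bytes_per_gpu = (
    2  # one key tensor and one value tensor per layer
    * num_hidden_layers
    * kv_heads_per_gpu
    * head_dim
    * context_tokens
    * bytes_per_elem
)
print(f"{bytes_per_gpu / 2**30:.2f} GiB of KV cache per GPU")  # half the single-GPU figure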
