@@ -15,6 +15,7 @@
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.nn.expert import MixtralExperts
from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold
@@ -48,6 +49,7 @@ class DeepseekConfig(ConfigBase):  # pylint: disable=too-many-instance-attribute
    context_window_size: int = 0
    prefill_chunk_size: int = 0
    tensor_parallel_shards: int = 1
+    head_dim: int = 0
    max_batch_size: int = 1
    num_experts_per_tok: int = 0
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -70,6 +72,9 @@ def __post_init__(self):
                    "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                    "provided in `config.json`."
                )
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
        if self.prefill_chunk_size == 0:
            logger.info(
                "%s defaults to %d",
@@ -85,7 +90,6 @@ def __post_init__(self):
                min(self.context_window_size, 2048),
            )
            self.prefill_chunk_size = min(self.context_window_size, 2048)
-        assert self.tensor_parallel_shards == 1, "Deepseek currently does not support sharding."


# pylint: disable=invalid-name,missing-docstring
@@ -96,17 +100,18 @@ def __init__(self, config: DeepseekConfig):
        super().__init__()  # Make sure to call the parent class constructor
        self.hidden_size = config.hidden_size
        self.rope_theta = config.rope_theta
+        self.tensor_parallel_shards = config.tensor_parallel_shards
        if config.num_attention_heads % config.tensor_parallel_shards != 0:
            raise ValueError(
                f"Cannot split {config.num_attention_heads} attention heads "
                f"evenly to {config.tensor_parallel_shards} GPUs."
            )

        self.attention_bias = config.attention_bias
-        self.num_heads = config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
+        self.num_heads = config.num_attention_heads // self.tensor_parallel_shards
+        self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = config.head_dim
        self.max_position_embeddings = config.context_window_size

        self.wqkv_pack = nn.Linear(
@@ -150,7 +155,7 @@ def __init__(self, config: DeepseekConfig, intermediate_size=None):
        )
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
-        )
+        ) // config.tensor_parallel_shards

        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
@@ -168,16 +173,18 @@ def __init__(self, config: DeepseekConfig):
        self.num_experts_per_tok = config.num_experts_per_tok
        self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)
        self.norm_topk_prob = config.norm_topk_prob
-        self.moe_intermediate_size = config.moe_intermediate_size
+        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
        self.moe_gate_up_proj = MixtralExperts(
            self.num_local_experts,
            in_features=config.hidden_size,
            out_features=2 * self.moe_intermediate_size,
+            tensor_parallel_shards=config.tensor_parallel_shards,
        )
        self.moe_down_proj = MixtralExperts(
            self.num_local_experts,
            in_features=self.moe_intermediate_size,
            out_features=config.hidden_size,
+            tensor_parallel_shards=config.tensor_parallel_shards,
        )
        self.dtype = "float32"

@@ -254,15 +261,64 @@ def __init__(self, config: DeepseekConfig, layer_idx: int):
            config.hidden_size, -1, config.rms_norm_eps, bias=False
        )

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_heads * hd
+            k = self.self_attn.num_key_value_heads * hd
+            v = self.self_attn.num_key_value_heads * hd
+
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            ):
+                i = self.mlp.moe_intermediate_size
+            else:
+                i = self.mlp.intermediate_size
+            _set(
+                self.self_attn.wqkv_pack.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            ):
+                _set(
+                    self.mlp.moe_gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=1),
+                )
+                _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=2))
+
+            else:
+                _set(
+                    self.mlp.gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0),
+                )
+                _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
        out = self.input_layernorm(hidden_states)
        out = self.self_attn(out, paged_kv_cache, layer_id)
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(hidden_states, residual=out)
        out = self.post_attention_layernorm(hidden_states)
        out = self.mlp(out)  # type: ignore[operator]
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(hidden_states, residual=out)
        return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

class DeepseekModel(nn.Module):
    def __init__(self, config: DeepseekConfig):
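The `_apply_residual` helper above follows the usual row-parallel pattern: `o_proj` and `down_proj` are sharded along their reduction dimension, so each worker produces only a partial sum and the shards must be all-reduced before the residual is added back. A minimal NumPy sketch of that identity (illustrative only, not mlc_llm code; the helper name, shapes, and shard count are made up, and the weight is kept in math layout (in_features, out_features)):

import numpy as np

def row_parallel_matmul(x, w, shards):
    """Split the reduction dimension across `shards` workers and sum the
    per-worker partial products, mimicking an all-reduce over the outputs."""
    x_parts = np.split(x, shards, axis=-1)   # each worker sees a slice of the activations
    w_parts = np.split(w, shards, axis=0)    # and the matching rows of the weight
    partial = [xp @ wp for xp, wp in zip(x_parts, w_parts)]
    return sum(partial)                      # allreduce(sum) recovers the full product

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64))             # hidden states
w = rng.standard_normal((64, 32))            # a down_proj-style weight, split on its reduction dim
residual = rng.standard_normal((4, 32))

# Reducing the partial outputs before adding the residual matches the unsharded layer.
assert np.allclose(x @ w + residual, row_parallel_matmul(x, w, shards=4) + residual)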
@@ -293,7 +349,8 @@ def __init__(self, config: DeepseekConfig):
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.rope_theta = config.rope_theta
        self.dtype = "float32"
@@ -320,6 +377,8 @@ def batch_forward(
        return logits

    def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
        return self.model.embed_tokens(input_ids)

    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
@@ -349,6 +408,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
    def batch_prefill(
        self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
    ):
+        if self.tensor_parallel_shards > 1:
+            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
        return logits, paged_kv_cache

@@ -375,8 +436,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
            page_size=page_size,
            support_sliding_window=support_sliding_window,
            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
+            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
            head_dim=self.head_dim,
            rope_mode=RopeMode.NORMAL,
            rope_scale=1,
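For reference, the `segs=[q, k, v]` hint on the fused QKV weight splits dimension 0 into the query, key, and value blocks and shards each block separately, which is why the per-shard head counts (`num_attention_heads // tensor_parallel_shards`, `num_key_value_heads // tensor_parallel_shards`) are what get passed to `create_paged_kv_cache`. A rough NumPy sketch of that segmented split (an illustration of the idea only, not the actual `tp.ShardSingleDim` implementation; the function name and sizes below are hypothetical):

import numpy as np

def shard_segmented_dim0(weight, segs, num_shards, shard_id):
    """Split dim 0 into the given segments, take this worker's slice of
    each segment, and concatenate the slices back together."""
    blocks = np.split(weight, np.cumsum(segs)[:-1], axis=0)   # [q_block, k_block, v_block]
    pieces = [np.split(block, num_shards, axis=0)[shard_id] for block in blocks]
    return np.concatenate(pieces, axis=0)

hidden, heads, kv_heads, head_dim, shards = 1024, 16, 4, 64, 2
q = heads * head_dim
k = v = kv_heads * head_dim
wqkv = np.random.default_rng(0).standard_normal((q + k + v, hidden))

local = shard_segmented_dim0(wqkv, [q, k, v], shards, shard_id=0)
# Each worker ends up with heads // shards query heads and kv_heads // shards KV heads.
assert local.shape == ((q + k + v) // shards, hidden)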