[SLM] Deepseek Multi-GPU support (#2988)
This PR adds tensor parallelism (TP) support for the Deepseek model.
tlopex authored Oct 22, 2024
1 parent 17c4e08 commit 4223720
Showing 1 changed file with 72 additions and 11 deletions.
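For orientation before the diff: the change splits attention heads and MLP intermediate dimensions evenly across tensor_parallel_shards and shards the packed QKV weight into [q, k, v] segments. The minimal sketch below mirrors that per-shard arithmetic; it is not part of the commit, and the numeric values are made up for illustration.

# Illustrative sketch only (not from the commit): per-shard shape arithmetic
# mirroring the logic added in the diff below. Values are hypothetical.
tensor_parallel_shards = 2
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 32
intermediate_size = 11008

# Mirrors the divisibility check added to the attention module.
if num_attention_heads % tensor_parallel_shards != 0:
    raise ValueError(
        f"Cannot split {num_attention_heads} attention heads "
        f"evenly to {tensor_parallel_shards} GPUs."
    )

head_dim = hidden_size // num_attention_heads  # default when head_dim == 0 in config
num_heads = num_attention_heads // tensor_parallel_shards
num_kv_heads = num_key_value_heads // tensor_parallel_shards

# Per-shard segment sizes of the packed QKV projection, matching
# ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]).
q = num_heads * head_dim
k = num_kv_heads * head_dim
v = num_kv_heads * head_dim

# The MLP intermediate dimension is split across shards the same way.
shard_intermediate_size = intermediate_size // tensor_parallel_shards

print("per-shard heads:", num_heads, "kv heads:", num_kv_heads, "head_dim:", head_dim)
print("qkv segments per shard:", [q, k, v], "packed rows:", q + k + v)
print("per-shard MLP intermediate size:", shard_intermediate_size)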
python/mlc_llm/model/deepseek/deepseek_model.py (72 additions, 11 deletions)
@@ -15,6 +15,7 @@
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.nn.expert import MixtralExperts
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold

@@ -48,6 +49,7 @@ class DeepseekConfig(ConfigBase): # pylint: disable=too-many-instance-attribute
context_window_size: int = 0
prefill_chunk_size: int = 0
tensor_parallel_shards: int = 1
head_dim: int = 0
max_batch_size: int = 1
num_experts_per_tok: int = 0
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -70,6 +72,9 @@ def __post_init__(self):
"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
"provided in `config.json`."
)
if self.head_dim == 0:
self.head_dim = self.hidden_size // self.num_attention_heads
assert self.head_dim * self.num_attention_heads == self.hidden_size
if self.prefill_chunk_size == 0:
logger.info(
"%s defaults to %d",
@@ -85,7 +90,6 @@ def __post_init__(self):
min(self.context_window_size, 2048),
)
self.prefill_chunk_size = min(self.context_window_size, 2048)
-assert self.tensor_parallel_shards == 1, "Deepseek currently does not support sharding."


# pylint: disable=invalid-name,missing-docstring
@@ -96,17 +100,18 @@ def __init__(self, config: DeepseekConfig):
super().__init__() # Make sure to call the parent class constructor
self.hidden_size = config.hidden_size
self.rope_theta = config.rope_theta
self.tensor_parallel_shards = config.tensor_parallel_shards
if config.num_attention_heads % config.tensor_parallel_shards != 0:
raise ValueError(
f"Cannot split {config.num_attention_heads} attention heads "
f"evenly to {config.tensor_parallel_shards} GPUs."
)

self.attention_bias = config.attention_bias
-self.num_heads = config.num_attention_heads
-self.num_key_value_heads = config.num_key_value_heads
+self.num_heads = config.num_attention_heads // self.tensor_parallel_shards
+self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-self.head_dim = self.hidden_size // self.num_heads
+self.head_dim = config.head_dim
self.max_position_embeddings = config.context_window_size

self.wqkv_pack = nn.Linear(
@@ -150,7 +155,7 @@ def __init__(self, config: DeepseekConfig, intermediate_size=None):
)
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
-)
+) // config.tensor_parallel_shards

self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
@@ -168,16 +173,18 @@ def __init__(self, config: DeepseekConfig):
self.num_experts_per_tok = config.num_experts_per_tok
self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)
self.norm_topk_prob = config.norm_topk_prob
-self.moe_intermediate_size = config.moe_intermediate_size
+self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
self.moe_gate_up_proj = MixtralExperts(
self.num_local_experts,
in_features=config.hidden_size,
out_features=2 * self.moe_intermediate_size,
tensor_parallel_shards=config.tensor_parallel_shards,
)
self.moe_down_proj = MixtralExperts(
self.num_local_experts,
in_features=self.moe_intermediate_size,
out_features=config.hidden_size,
tensor_parallel_shards=config.tensor_parallel_shards,
)
self.dtype = "float32"

@@ -254,15 +261,64 @@ def __init__(self, config: DeepseekConfig, layer_idx: int):
config.hidden_size, -1, config.rms_norm_eps, bias=False
)

def _set_tp():
def _set(layer, hint):
layer.attrs["shard_strategy"] = hint

hd = config.head_dim
q = self.self_attn.num_heads * hd
k = self.self_attn.num_key_value_heads * hd
v = self.self_attn.num_key_value_heads * hd

if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
):
i = self.mlp.moe_intermediate_size
else:
i = self.mlp.intermediate_size
_set(
self.self_attn.wqkv_pack.weight,
tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
)
_set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))

if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
):
_set(
self.mlp.moe_gate_up_proj.weight,
tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=1),
)
_set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=2))

else:
_set(
self.mlp.gate_up_proj.weight,
tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0),
)
_set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=1))

self.tensor_parallel_shards = config.tensor_parallel_shards
_set_tp()

def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
out = self.input_layernorm(hidden_states)
out = self.self_attn(out, paged_kv_cache, layer_id)
-hidden_states = hidden_states + out
+hidden_states = self._apply_residual(hidden_states, residual=out)
out = self.post_attention_layernorm(hidden_states)
out = self.mlp(out) # type: ignore[operator]
-hidden_states = hidden_states + out
+hidden_states = self._apply_residual(hidden_states, residual=out)
return hidden_states

def _apply_residual(self, out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out, "sum") + residual
return out + residual


class DeepseekModel(nn.Module):
def __init__(self, config: DeepseekConfig):
@@ -293,7 +349,8 @@ def __init__(self, config: DeepseekConfig):
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
-self.head_dim = config.hidden_size // config.num_attention_heads
+self.tensor_parallel_shards = config.tensor_parallel_shards
+self.head_dim = config.head_dim
self.vocab_size = config.vocab_size
self.rope_theta = config.rope_theta
self.dtype = "float32"
@@ -320,6 +377,8 @@ def batch_forward(
return logits

def embed(self, input_ids: Tensor):
if self.tensor_parallel_shards > 1:
input_ids = op.ccl_broadcast_from_worker0(input_ids)
return self.model.embed_tokens(input_ids)

def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
@@ -349,6 +408,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
def batch_prefill(
self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
):
if self.tensor_parallel_shards > 1:
logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
return logits, paged_kv_cache

@@ -375,8 +436,8 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
page_size=page_size,
support_sliding_window=support_sliding_window,
num_hidden_layers=self.num_hidden_layers,
-num_attention_heads=self.num_attention_heads,
-num_key_value_heads=self.num_key_value_heads,
+num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
+num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
head_dim=self.head_dim,
rope_mode=RopeMode.NORMAL,
rope_scale=1,

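A closing note on the new _apply_residual helper: because the o_proj and down_proj weights are sharded along their input dimension (ShardSingleDim(..., dim=1)), each GPU computes only a partial projection, and the partial outputs have to be summed across workers before the residual is added, which is what op.ccl_allreduce(out, "sum") + residual does. The NumPy sketch below is illustrative only and not part of the commit; it just shows that summing per-shard partial matmuls over the split reduction axis reproduces the unsharded result.

# Illustrative sketch only (not from the commit): splitting a matmul along its
# reduction axis and summing the per-shard partial results matches the full
# matmul, which is why the sharded layer ends with an all-reduce ("sum").
import numpy as np

rng = np.random.default_rng(0)
shards = 2
x = rng.standard_normal((1, 8))   # activation with 8 features
w = rng.standard_normal((8, 4))   # projection weight, laid out (in, out) for this demo

full = x @ w                      # unsharded projection

# Split both the activation and the weight along the reduction (input-feature)
# axis, one slice per shard; each slice plays the role of one GPU's local data.
x_parts = np.split(x, shards, axis=1)
w_parts = np.split(w, shards, axis=0)
partials = [xp @ wp for xp, wp in zip(x_parts, w_parts)]

# Summing the partial outputs is the all-reduce step; only after that is the
# residual added, as in _apply_residual.
reduced = sum(partials)
assert np.allclose(full, reduced)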