
DeepSeek V2/V3 MLA implementation #12801

Merged 18 commits on Apr 15, 2025
Commits (18)
8c02442
Merged using squash to remove all noise commit messages
jukofyork Apr 7, 2025
ddab5e4
Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large
jukofyork Apr 7, 2025
7612566
Merge branch 'ggml-org:master' into mla--ready-for-review
jukofyork Apr 9, 2025
fed6600
Merge branch 'ggml-org:master' into mla--ready-for-review
jukofyork Apr 11, 2025
c449488
Merge branch 'ggml-org:master' into mla--ready-for-review
jukofyork Apr 12, 2025
2a4e1b2
Removed 3 conts (2x RoPE and 1x RMS-norm)
jukofyork Apr 12, 2025
77fe59b
Changed to use `<cmath>` instead of `<math.h>`
jukofyork Apr 12, 2025
e215323
Reverted removal of the 3 conts
jukofyork Apr 12, 2025
815f4f9
Used `reshape` in `llm_graph_context::build_attn_mha()`
jukofyork Apr 12, 2025
5778861
Use `k_pe = ggml_reshape`
jukofyork Apr 12, 2025
77ad5e4
Merge branch 'mla--ready-for-review' of https://github.com/jukofyork/…
jukofyork Apr 12, 2025
5d037ae
Removed the 3 conts again
jukofyork Apr 12, 2025
638b092
Removed the 3D views of `wk_b` and `wv_b`, and just save as 3D in GGUF
jukofyork Apr 12, 2025
a5df71e
Removed MQA optimisation from `build_attn_mha()` as no gains now
jukofyork Apr 13, 2025
925af99
Simplified `is_mla` branch in `llm_build_deepseek2()`
jukofyork Apr 13, 2025
36ce235
Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls
jukofyork Apr 13, 2025
a574278
Fixed call to `build_attn` in `llm_build_t5_enc`
jukofyork Apr 13, 2025
ed00f1e
Merge branch 'ggml-org:master' into mla--ready-for-review
jukofyork Apr 13, 2025
33 changes: 31 additions & 2 deletions convert_hf_to_gguf.py
@@ -4422,6 +4422,10 @@ def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):

# note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1

super().set_gguf_parameters()
hparams = self.hparams

@@ -4430,8 +4434,13 @@ def set_gguf_parameters(self):
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["v_head_dim"])

# note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
@@ -4500,6 +4509,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
else:
return []

# note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
if name.endswith("kv_b_proj.weight"):
name_kb = name.replace("kv_b_proj", "k_b_proj")
name_vb = name.replace("kv_b_proj", "v_b_proj")

n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)

return [
(self.map_tensor_name(name_kb), k_b),
(self.map_tensor_name(name_vb), v_b)
]

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
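To make the split above concrete, here is a minimal standalone sketch of the same reshape/split/transpose, using assumed DeepSeek2-style dimensions (illustrative only; the real values come from the model's config):

```python
import torch

# assumed example hyper-parameters (not taken from any specific checkpoint)
n_head_kv        = 16    # KV heads before the MQA conversion
qk_nope_head_dim = 128
v_head_dim       = 128
kv_lora_rank     = 512

# kv_b_proj decompresses the low-rank KV latent into per-head K_nope and V
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# same logic as modify_tensors() above
kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)   # transposed so it can be applied to the cached latent directly

print(k_b.shape)  # torch.Size([16, 512, 128])  -> stored as attn_k_b
print(v_b.shape)  # torch.Size([16, 128, 512])  -> stored as attn_v_b
```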
8 changes: 8 additions & 0 deletions gguf-py/gguf/constants.py
@@ -139,6 +139,8 @@ class Attention:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -382,6 +384,8 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_K_B = auto()
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -590,6 +594,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1517,6 +1523,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -689,6 +689,12 @@ def add_key_length(self, length: int) -> None:
def add_value_length(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)

def add_key_length_mla(self, length: int) -> None:
self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)

def add_value_length_mla(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)

def add_max_alibi_bias(self, bias: float) -> None:
self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
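A minimal usage sketch of the two new writer methods (all values assumed for illustration, mirroring what the converter above writes for a DeepSeek2-style model):

```python
from gguf import GGUFWriter  # gguf-py

# assumed example values, not taken from any specific checkpoint
kv_lora_rank, qk_rope_head_dim = 512, 64
qk_nope_head_dim, v_head_dim   = 128, 128

writer = GGUFWriter("deepseek2-mla.gguf", "deepseek2")

# compressed (MQA) head sizes, used for the KV cache
writer.add_key_length(kv_lora_rank + qk_rope_head_dim)
writer.add_value_length(kv_lora_rank)

# original (MLA) head sizes, used after decompression back to MHA
writer.add_key_length_mla(qk_nope_head_dim + qk_rope_head_dim)
writer.add_value_length_mla(v_head_dim)
```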

8 changes: 8 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -677,6 +677,14 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B: (
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_V_B: (
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
23 changes: 6 additions & 17 deletions src/llama-arch.cpp
@@ -140,6 +140,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -1103,6 +1105,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1563,23 +1567,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
4 changes: 4 additions & 0 deletions src/llama-arch.h
@@ -144,6 +144,8 @@ enum llm_kv {
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -306,6 +308,8 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
15 changes: 12 additions & 3 deletions src/llama-context.cpp
@@ -10,6 +10,7 @@
#include <cstring>
#include <stdexcept>
#include <cinttypes>
#include <cmath>

//
// llama_context
@@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift(
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;

@@ -482,6 +482,10 @@
const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type;

// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

Comment on lines +485 to +488 (Member):
Can this be absorbed in the cparams.yarn_attn_factor and deduplicated from here and llm_build_deepseek2?
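As an aside, the expression above is just the reciprocal of YaRN's magnitude-scaling ("mscale") term. Written out, with s = 1/freq_scale denoting the context-extension factor (notation assumed here for clarity, not taken from the code):

$$
\text{mscale}(s) = 1 + 0.1\,\ln s,
\qquad
\text{yarn\_attn\_factor\_scaled} = \frac{1}{\text{mscale}(s)} = \frac{1}{1 + 0.1\,\ln\!\left(1/\text{freq\_scale}\right)}
$$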

ggml_tensor * tmp;

if (ggml_is_quantized(cur->type)) {
Expand All @@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(

tmp = ggml_rope_ext_inplace(ctx0, tmp,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);

tmp = ggml_cpy(ctx0, tmp, cur);
} else {
// we rotate only the first n_rot dimensions
tmp = ggml_rope_ext_inplace(ctx0, cur,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
}

return tmp;
@@ -2274,6 +2278,11 @@ llama_context * llama_init_from_model(
params.flash_attn = false;
}

if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
params.flash_attn = false;
}

Comment on lines +2281 to +2285 (Member):
There is a way to enable the FA path to be compatible with MLA. I will fix this in a follow-up PR https://github.com/ggml-org/llama.cpp/tree/gg/mla. For now, keep it like this.

if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
19 changes: 14 additions & 5 deletions src/llama-graph.cpp
@@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla,
bool v_trans,
float kq_scale) const {
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1199,7 +1200,8 @@
//const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;

const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
// note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
@@ -1267,6 +1269,11 @@

ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);

// for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
if (v_mla) {
kqv = ggml_mul_mat(ctx0, v_mla, kqv);
}

Comment on lines +1272 to +1276 (Member):
This decompression can be performed after the build_attn_mha. This way, the FA path will not be coupled to MLA and we can extend it in the future.

I have prepared a fix in 6dfbed0, so for now let's merge it like this and I will resolve this afterwards.
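To build intuition for the `v_mla` branch, here is a minimal standalone sketch (all shapes assumed purely for illustration, not taken from the PR): attention runs once over the shared compressed latent, and the per-head `v_mla` matrices then "decompress" the result from MQA back to MHA:

```python
import torch

# illustrative sizes (assumed, not taken from the PR)
kv_lora_rank = 512   # compressed latent size used as V inside attention
v_head_dim   = 128   # real per-head value size after decompression
n_head, n_tokens, n_kv = 16, 8, 32

kq    = torch.rand(n_head, n_kv, n_tokens)            # softmaxed attention weights per head
v_lat = torch.rand(kv_lora_rank, n_kv)                # shared latent V (the MQA KV cache)
v_mla = torch.rand(n_head, v_head_dim, kv_lora_rank)  # per-head decompression (the split-out v_b_proj)

# attend in latent space: every head reads the same compressed V
kqv_latent = torch.einsum('ck,hkt->hct', v_lat, kq)    # (n_head, kv_lora_rank, n_tokens)

# "decompress" from MQA back to MHA with the per-head v_mla matrices
kqv = torch.einsum('hvc,hct->hvt', v_mla, kqv_latent)  # (n_head, v_head_dim, n_tokens)

print(kqv.shape)  # torch.Size([16, 128, 8])
```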

ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
@@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
GGML_UNUSED(n_tokens);
@@ -1325,7 +1333,7 @@
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

cb(cur, "kqv_out", il);

@@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
@@ -1464,7 +1473,7 @@
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
@@ -1523,7 +1533,7 @@
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

cb(cur, "kqv_out", il);

@@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling(

ggml_build_forward_expand(gf, cur);
}

10 changes: 7 additions & 3 deletions src/llama-graph.h
@@ -505,11 +505,12 @@ struct llm_graph_context {

ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
bool v_trans,
float kq_scale) const;

@@ -524,6 +525,7 @@
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;

@@ -538,6 +540,7 @@
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;

@@ -552,6 +555,7 @@
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;

4 changes: 4 additions & 0 deletions src/llama-hparams.h
@@ -43,6 +43,10 @@ struct llama_hparams {
uint32_t n_expert_used = 0;
uint32_t n_rel_attn_bkts = 0;

// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla = 0;
uint32_t n_embd_head_v_mla = 0;

// for WavTokenizer
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
2 changes: 1 addition & 1 deletion src/llama-kv-cache.cpp
@@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init(

recurrent = llama_model_is_recurrent(&model);
v_trans = !recurrent && !cparams.flash_attn;
can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
can_shift = !recurrent;

LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);