DeepSeek V2/V3 MLA implementation #12801
Changes from all commits
8c02442
ddab5e4
7612566
fed6600
c449488
2a4e1b2
77fe59b
e215323
815f4f9
5778861
77ad5e4
5d037ae
638b092
a5df71e
925af99
36ce235
a574278
ed00f1e
@@ -10,6 +10,7 @@
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
+#include <cmath>

 //
 // llama_context
@@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
-    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;

@@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type;

+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+
     ggml_tensor * tmp;

     if (ggml_is_quantized(cur->type)) {
@@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(

         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);

         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
     }

     return tmp;
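
For reference, a minimal standalone sketch (not part of this PR) of the `yarn_attn_factor_scaled` arithmetic introduced above: the `DEEPSEEK2` branch divides out the YaRN magnitude scale `1 + 0.1*ln(1/freq_scale)` instead of using `cparams.yarn_attn_factor` directly. The helper name and the `freq_scale` value of 1/40 below are illustrative assumptions, not values taken from this diff.

```cpp
#include <cmath>
#include <cstdio>

// Sketch only: mirrors the ternary expression from build_rope_shift() above.
static float yarn_attn_factor_scaled(bool is_deepseek2, float freq_scale, float cparams_attn_factor) {
    return is_deepseek2
        ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))  // divide out the YaRN magnitude scale
        : cparams_attn_factor;                            // other archs: use the value from cparams unchanged
}

int main() {
    const float freq_scale = 1.0f / 40.0f; // hypothetical YaRN context-extension factor of 40

    printf("DEEPSEEK2 branch: %.4f\n", yarn_attn_factor_scaled(true,  freq_scale, 1.0f)); // ~0.7305
    printf("default branch  : %.4f\n", yarn_attn_factor_scaled(false, freq_scale, 1.0f)); // 1.0000
    return 0;
}
```
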
@@ -2274,6 +2278,11 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }

+    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+

Comment on lines +2281 to +2285: There is a way to enable the FA path to be compatible with MLA. I will fix this in a follow-up PR: https://github.com/ggml-org/llama.cpp/tree/gg/mla. For now, keep it like this.

     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
@@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1199,7 +1200,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
+    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];

     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1267,6 +1269,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);

+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+

Comment on lines +1272 to +1276: This decompression can be performed after the […]. I have prepared a fix in 6dfbed0, so for now let's merge it like this and I will resolve this afterwards.

         ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
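
For reference, a sketch of the identity behind the absorption optimization mentioned in the comments above. The notation is illustrative and not taken from this diff: $c_t$ is the shared compressed KV latent, $W^{UV}_h$ the per-head value up-projection presumed to be carried in `v_mla`, and $a_{h,t}$ the attention weights of head $h$:

$$
o_h \;=\; \sum_t a_{h,t}\,\bigl(W^{UV}_h c_t\bigr) \;=\; W^{UV}_h \Bigl(\sum_t a_{h,t}\, c_t\Bigr)
$$

Attention can therefore be evaluated MQA-style over the shared latent, with the per-head up-projection applied once to the result — consistent with `n_embd_head_v` being taken from `v_mla->ne[1]` and with the extra `ggml_mul_mat` above "decompressing" the output back to the full per-head value dimension.
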
@@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);
@@ -1325,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1464,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
@@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1523,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling(

     ggml_build_forward_expand(gf, cur);
 }
-

Review comment: Can this be absorbed in the cparams.yarn_attn_factor and deduplicated from here and llm_build_deepseek2?