DeepSeek V2/V3 MLA implementation #12801


Merged · 18 commits · Apr 15, 2025

Conversation

jukofyork
Copy link
Collaborator

@jukofyork jukofyork commented Apr 7, 2025

This should hopefully be my final PR for this.

What it does:

  • (Backward Compatibility) Legacy non-MLA GGUF files can still be loaded and used as normal, but they won't benefit from the reduced KV-cache size MLA gives.
  • Adds context shifting ability (for both legacy non-MLA GGUF files and new MLA GGUF files).
  • It requires new GGUF files to be created, and even though the two split tensor names are the same, it will NOT work with files created for @fairydreaming's original MLA PR, as those will be missing the new n_embd_head_k_mla/n_embd_head_v_mla metadata (see below).
  • It adds an optimised path for all MQA models (i.e. GQA with a single group) inside of llm_graph_context::build_attn_mha(), which avoids the extra overhead of 3D batched matrix multiplication and converts it into a normal 2D matrix multiplication (a rough sketch follows below). This is my only real contribution to this - the rest is all @fairydreaming's work!
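
To make that MQA fast-path concrete, here is a minimal sketch of the kind of collapse meant (illustrative only - not the exact build_attn_mha() code - and it assumes the usual [n_embd_head, n_tokens, n_head] layout for q and that q is contiguous at this point):

// with a single KV head, fold q's head dimension into its token dimension so that
// ggml_mul_mat() sees an ordinary 2D GEMM instead of a 3D batched one
if (k->ne[2] == 1 && v->ne[2] == 1) {
    // q: [n_embd_head, n_tokens, n_head] -> [n_embd_head, n_tokens*n_head]
    q = ggml_reshape_2d(ctx0, q, q->ne[0], q->ne[1]*q->ne[2]);
    // kq = ggml_mul_mat(ctx0, k, q) now produces [n_kv, n_tokens*n_head] in one 2D multiply
}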

How it works:

1. Inside of convert_hf_to_gguf.py we alter the metadata to make the new MLA GGUF files appear to be MQA

def set_gguf_parameters(self):
    self.hparams["num_key_value_heads"] = 1
    super().set_gguf_parameters()
    .
    .
    .
    self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
    self.gguf_writer.add_value_length(hparams["kv_lora_rank"])

so for all intents and purposes, the rest of llama.cpp will now see these new GGUF files as being MQA:

print_info: n_head           = 16
print_info: n_head_kv        = 1
print_info: n_rot            = 64
print_info: n_embd_head_k    = 576
print_info: n_embd_head_v    = 512

We also add two new bits of metadata that we need to be able to "decompress" MQA back into MHA at the end:

    self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
    self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

and these hold the original k_head_dim and v_head_dim of the model:

print_info: n_embd_head_k_mla    = 192
print_info: n_embd_head_v_mla    = 128
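
As a quick sanity check of this mapping (a sketch only, using DeepSeek-V2/V3-style hyperparameters consistent with the printouts above):

# illustrative check of the metadata mapping, with DeepSeek-style values
hparams = {
    "kv_lora_rank":     512,
    "qk_rope_head_dim":  64,
    "qk_nope_head_dim": 128,
    "v_head_dim":       128,
}

n_embd_head_k     = hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]      # 576 -> key_length
n_embd_head_v     = hparams["kv_lora_rank"]                                    # 512 -> value_length
n_embd_head_k_mla = hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]  # 192 -> key_length_mla
n_embd_head_v_mla = hparams["v_head_dim"]                                      # 128 -> value_length_mla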

2. We add an extra tensor called v_mla to the parameters of llm_graph_context::build_attn_mha()

This gets used to "decompress" MQA back into MHA like so:

 // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
.
.
.
// for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
if (v_mla) {
    kqv = ggml_mul_mat(ctx0, v_mla, kqv);
}

This, and the function signature of llm_graph_context::build_attn() used to call this, are the only real changes needed to llama.cpp's code outside of the deepseek2-specific stuff in llama-model.cpp. I think this is the cleanest / most-maintainable way to add this MLA support.
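
For intuition on the shapes involved, here is a rough sketch assuming DeepSeek-style sizes (ne[] given in ggml order, i.e. ne[0] is the innermost dimension):

// v_mla is the per-layer wv_b tensor:
//   v_mla : ne = [kv_lora_rank = 512, n_embd_head_v_mla = 128, n_head]
//   kqv   : ne = [kv_lora_rank = 512, n_tokens,                n_head]
// ggml_mul_mat(ctx0, v_mla, kqv) then does one matrix multiply per head and yields
//   kqv   : ne = [n_embd_head_v_mla = 128, n_tokens, n_head]
// i.e. the MQA output is "decompressed" back to the model's real per-head value size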

3. When loading the tensors we will load only the legacy wkv_b tensor or the new split wk_b and wv_b tensors

const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);

// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
.
.
.
// note: only old legacy GGUF files will have the unsplit wkv_b tensor in
if (is_mla) {
    layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
    layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v_mla}, 0);
} else {
    layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0);
}
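
For concreteness, plugging in DeepSeek-R1-sized values (n_head = 128, kv_lora_rank = 512, n_embd_head_qk_nope = 128, n_embd_head_v_mla = 128 - illustrative numbers, not taken from this code) gives:

// wk_b  : {n_embd_head_qk_nope, n_head * kv_lora_rank}                       = {128, 65536}
// wv_b  : {kv_lora_rank, n_head * n_embd_head_v_mla}                         = {512, 16384}
// wkv_b : {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)} = {512, 32768}  (legacy files only)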

4. Inside of llm_build_deepseek2() we can treat the attention as MHA (as the old code did) or as MQA (as the new code does)

const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);

// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
.
.
.
if (is_mla) {
.
.
.

5. To get context-shifting to work

A. Ensure the RoPE part goes first and the NoPE part goes second:

ggml_tensor * q_states = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);

and:

ggml_tensor * q_states = ggml_concat(ctx0, q_pe, q_nope, 0);

B. Apply the same scaling to yarn_attn_factor inside of llama_context::build_rope_shift():

const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

as is applied inside of llm_build_deepseek2():

// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
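
As a worked example of these formulas (assuming DeepSeek-style YaRN settings of freq_scale = 1/40, rope_yarn_log_mul = 0.1, attn_factor = 1.0 and n_embd_head_k = 192 - illustrative values, not read from this PR):

const float mscale             = 1.0f * (1.0f + 0.1f * logf(40.0f));  // ~1.369
const float kq_scale           = mscale * mscale / sqrtf(192.0f);     // ~0.135 (vs ~0.072 unscaled)
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(40.0f));  // ~0.731
// note: mscale * attn_factor_scaled == attn_factor here, which is why the explicit kq_scale
// pre-scaling and the reduced attn_factor passed to RoPE go together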

NOTES

  • Unless you override the split attn_k_b and attn_v_b tensors to use BF16 in llama_tensor_get_type() you'll likely get pretty horrible CUDA performance - it seems that CUDA just does not like the 3D matrix multiplies these need when they are quantised (see the sketch after these notes for the kind of override meant).
  • If you use BF16 for these then it is also worth patching ggml/src/ggml-cuda/ggml-cuda.cu to use <= MMV_MAX_ROWS instead of < MMV_MAX_ROWS, or just setting MMV_MAX_ROWS = 513. This gains me quite a bit by using the optimised MMV branch instead of the general cuBLAS branch.
  • I did try to add an -mla option like ik_llama.cpp uses, but it just ended up a real mess (see previous PR).
  • I did try to keep the original wkv_b tensor and just slice it to get wv_b, but it wasn't aligned properly and would have needed a copy/cont every time we accessed it, making keeping it pointless...
  • All credit to @fairydreaming for doing the work - all I've done here is tidy up the great work he did in his original PRs!
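
A minimal sketch of the kind of llama_tensor_get_type() override meant in the first note (a hypothetical local patch, not part of this PR - a fuller version appears in the quantisation script posted further down this thread):

// in src/llama-quant.cpp, llama_tensor_get_type(): force the split MLA tensors to BF16
const std::string name = ggml_get_name(tensor);
if (name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
    return GGML_TYPE_BF16;
}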

@ngxson Please don't merge this yet as I have left a placeholder function that needs to be removed first:

    // ****************************************************************************************************************
    // *** THIS WILL BE REMOVED AFTER CODE REVIEW IS ACCEPTED AND READY TO MERGE - IT'S JUST A COPY OF build_attn() ***
    // ****************************************************************************************************************
    ggml_tensor * build_attn_mla(
            llm_graph_input_attn_kv_unified * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, 1,        n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, 1,        n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

I don't want to make the changes to all the calls to build_attn() in llama-model.cpp until we're happy with the code or else it could end up a nightmare having to rebase it if others add new models whilst waiting for this to be reviewed!

I've tested this so far on:

  • deepseek-v2-lite: both legacy GGUF files and new GGUF files quantized and in BF16.
  • deepseek-r1: only the new GGUF files quantized and in BF16.

@github-actions github-actions bot added the python python script changes label Apr 7, 2025
@jukofyork
Copy link
Collaborator Author

Also, in case anybody wants to use this with the -ot option, there is currently a bug stopping this:

#12798

but you can work around it by compiling with -DGGML_CUDA_GRAPHS=OFF.

@jukofyork
Copy link
Collaborator Author

The failures don't seem to be anything to do with me:

 E: Failed to fetch http://security.ubuntu.com/ubuntu/dists/noble-security/main/binary-amd64/Packages.xz  File has unexpected size (738752 != 738764). Mirror sync in progress? [IP: 91.189.91.83 80]

@bartowski1182
Copy link
Contributor

2 questions:

Will it automatically use the new MLA during conversion going forward, or do I need to enable a specific option?

Is this ready to start doing conversion/imatrix calculation or would I be wasting my time?

@jukofyork
Copy link
Collaborator Author

jukofyork commented Apr 7, 2025

2 questions:

Will it automatically use the new MLA during conversion going forward, or do I need to enable a specific option?

It will just use MLA all the time now - the backward compatibility is only for old files (and the only way to use non-MLA now would be to run a version of convert_hf_to_gguf.py from before this PR).

@jukofyork
Copy link
Collaborator Author

jukofyork commented Apr 7, 2025

Is this ready to start doing conversion/imatrix calculation or would I be wasting my time?

It depends on whether it's accepted, but the way of converting it to MQA inside of convert_hf_to_gguf.py seems quite clean to me and way less of a mess than my 2 previous attempts at this; so I'm quietly hopeful it will get accepted before another refactorathon breaks the PR! :D

@jukofyork
Copy link
Collaborator Author

jukofyork commented Apr 7, 2025

It's not quite finished fine-tuning, but I should have some exciting news on the tiny draft models soon too:

[image: eval top-1 comparison plot of the pruned draft models]

The magenta line is the 0.6B model you get when you keep all 14 heads of qwen-2.5-instruct:0.5b (the extra 0.1B comes from the untied lm_head we create), and it clearly shows there isn't much gain compared to the cyan line (a 12-headed 0.5B version of qwen-2.5-instruct:0.5b). The ability to quantise the 12-headed version using K-quants will likely mean it performs better than a Q4_0 of the original 14-headed 0.6B model!

The gain of ~8% eval top-1 between the grey line (8-headed 0.33B version) and the 12-headed version does seem quite significant though, so I think this may actually be the optimal pruned size to use for speculative decoding (for all target models, not just deepseek-r1!).

@jukofyork
Copy link
Collaborator Author

jukofyork commented Apr 7, 2025

I forgot to turn flash attention off in the PR, so I've added that now:

    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
        params.flash_attn = false;
    }

I don't think llama.cpp lets you use it when the K and V heads have different dimensions anyway, and even if it did, the heads are far too large for it to work with the 576/512 dimensions of the MQA version we turn MLA into in this PR (see #12227 for my failed attempt at this).

By forcing it off like this, you should still be able to use -fa for the draft models (this worked in the last -mla option PR, but I'm re-testing it now).

@ggerganov ggerganov self-requested a review April 8, 2025 08:25
@jukofyork
Copy link
Collaborator Author

https://huggingface.co/jukofyork/DeepSeek-R1-DRAFT-0.5B-v1.0

https://huggingface.co/jukofyork/DeepSeek-R1-DRAFT-0.5B-v1.0-GGUF

I'm still waiting for somebody to show me a printout of the token IDs for the Unsloth quants as apparently they changed the <PAD> token ID for some reason, and these don't work because of that.

@jukofyork
Copy link
Collaborator Author

@ggerganov I've noticed that even with this added:

ddab5e4

    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
        params.flash_attn = false;
    }

when using a draft model and -fa, the performance tanks.

I think something in llama.cpp is reading the global flash_attn setting instead of the model-specific setting, which in turn is causing the non-draft model to read this setting and have the v-cache transposed/untransposed, etc.

This isn't a problem with my PR, but probably should be looked at in the future (along with any other unwanted global parameters the non-draft model might be picking up).

@zts9989
Copy link

zts9989 commented Apr 9, 2025

when using a draft model and -fa, the performance tanks.

I found that the FA code update in b4759 (CUDA implementation) caused unexpected drastic performance changes in my test cases. This might be related.
#12816

@jukofyork
Copy link
Collaborator Author

jukofyork commented Apr 12, 2025

I've got the draft models trained for DeepSeek-V3-0324 now:

https://huggingface.co/jukofyork/DeepSeek-V3-0324-DRAFT-0.5B-v1.0

https://huggingface.co/jukofyork/DeepSeek-V3-0324-DRAFT-0.5B-v1.0-GGUF

I used the same mix of data as for DeepSeek-R1 - which includes around 2.5B tokens of data generated by R1... The only difference is that for this part of the dataset I removed the <think> tags. Even so, it seems to work really well, e.g.:

#!/bin/bash

host_address=192.168.1.2
port_number=8080

# Store the original directory
ORIGINAL_DIR=$(pwd)

# Change to the target directory
cd ~/llama.cpp_MLA/llama.cpp/build/bin

# Turn off NUMA balancing
echo 0 | sudo tee /proc/sys/kernel/numa_balancing > /dev/null

# Ask for permission to drop caches
read -p "Do you want to drop caches? (y/n) " -n 1 -r
echo    # Move to a new line
if [[ $REPLY =~ ^[Yy]$ ]]
then
    echo "Dropping caches..."
    echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null
fi

# Run the main command
./llama-server \
        --host "$host_address" \
        --port "$port_number" \
        --model ~/models/gguf/deepseek-v3-0324-mla-Q4_K_L+BF16.gguf \
        --alias "deepseek-v3-0324--Q4_K" \
        --chat-template deepseek3 \
        --n-gpu-layers 99 \
        --numa distribute \
        --override-tensor exps=CPU \
        --override-kv "deepseek2.expert_used_count=int:6" \
        --override-kv "deepseek2.expert_weights_scale=float:2.3" \
        --ctx_size 32768 \
        --batch-size 1024 \
        --ubatch-size 256 \
        --model-draft ~/models/gguf/draft_models/DeepSeek-V3-0324-DRAFT-0.5B-Q4_0.gguf \
        --top-k 1 \
        --samplers "top_k" \
        --gpu-layers-draft 99 \
        --draft-min 3 \
        --draft-max 32 \
        --draft-p-min 0.667

# Return to the original directory
cd "$ORIGINAL_DIR"

with the custom Q4_K_L+BF16 quant:

function safe_sed_function() {
    local file=$1
    local function_signature=$2
    local replacement=$3

    # Create backup
    cp "$file" "$file.bak"

    # Perform the replacement using address range and c command
    sed -i "${function_signature}/,/^}/c\\${replacement}" "$file"

    # Clean up
    rm "$file.bak"

    echo "Successfully replaced function in $file"
    echo "-------------------"
}

safe_sed_function "src/llama-quant.cpp" \
  "/^static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor \\* tensor, llama_ftype ftype) {" \
  "static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {\n\
    const std::string name = ggml_get_name(tensor);\n\
    if (name.find(\"_exps\") != std::string::npos) {\n\
        return GGML_TYPE_Q4_K;\n\
    } else if (name.find(\"attn_k_b\") != std::string::npos || name.find(\"attn_v_b\") != std::string::npos) {\n\
        return GGML_TYPE_BF16;\n\
    }\n\
    return GGML_TYPE_Q6_K;\n\
}"

I can generate over 11 tokens per second for refactoring tasks now on a machine with:

  • RTX 5000 ADA
  • Dual Xeon Gold 6248
  • 1.5TB 2666MT/s DDR4 (all 12 channels for each CPU populated using 64GB LRDIMMs).

and around 35-40 tokens per second prompt processing.

See #11446 (comment) for an explanation of why I'm using --override-kv "deepseek2.expert_weights_scale=float:2.3" at the same time as --override-kv "deepseek2.expert_used_count=int:6".

I've yet to really test how much quality is lost using 6 experts, the adjusted scale factor and Q4_K for non-shared expert tensors, but the raw PPL difference compared to full Q8_0 (with attn_k_b / attn_v_b as BF16) is less than 2%.

@jukofyork
Copy link
Collaborator Author

@ngxson @ggerganov @slaren This is ready for review!

I can keep it alive for now because of the copy of llm_graph_context::build_attn_mla() I made (which I will remove as soon as I get the OK), but sooner or later something is going to break it and I'm going to be away for the Easter holidays and won't be able to do much until the start of May to fix it (unless it's a super easy fix and can be done via a tablet) :/

Comment on lines 10128 to 10129
// TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
q_pe = ggml_cont(ctx0, q_pe);
Copy link
Member


I think this was already investigated and the rope should work correctly now, so the ggml_cont() is no longer needed (see #12457 (comment)).

Copy link
Member

@ggerganov ggerganov left a comment


This version does not require changes to the KV cache implementation. Are there plans to update it in the future, or is this no longer needed?

Comment on lines +485 to +488
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

Copy link
Member


Can this be absorbed in the cparams.yarn_attn_factor and deduplicated from here and llm_build_deepseek2?

@ggerganov ggerganov requested a review from fairydreaming April 12, 2025 17:23
@jukofyork
Copy link
Collaborator Author

This version does not require changes to the KV cache implementation. Are there plans to update it in the future, or is this no longer needed?

No, it turns out if you convert the MLA stuff to being MQA right at the start, then all of the existing code for the KV-cache works without any changes.

When I didn't convert to MQA and tried to keep the ability to run as MHA, it ended up a real mess and did need lots of changes.

@jukofyork
Copy link
Collaborator Author

On my workstation the token generation performance with above views and permutes to make q contiguous that you suggested is basically the same as in your original PR:

I think part of my performance regression may be something to do with the changes made to fix the -ot option:

#12798 (comment)

as I now get slightly better generation speed using -DGGML_CUDA_GRAPHS=OFF than I did before.

But I agree with @ggerganov that instead of adding views this probably should become an internal optimization of matrix multiplication in the CPU backend.

Yeah, the back-ends could likely detect even more cases like this where you can "collapse" batches and/or "fill-in" inner dimensions with size 1, etc.

@jukofyork
Copy link
Collaborator Author

@fairydreaming Could you test if there is any difference between having those permutes but not doing the views in build_attn_mha()? Does this perform any better or worse than the current master?


Also, can you try changing q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); to q_nope = ggml_cont(ctx0, ggml_permute(ctx0, q_nope, 0, 2, 1, 3)); to see if that improves things? I suspect the poor CUDA quantised wk_b performance comes from this permute and the effect on the following ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope).

It would be interesting to see if this helps your CPU setup as the data layout for that multiplication should be way more cache-friendly.

@fairydreaming
Copy link
Collaborator

fairydreaming commented Apr 16, 2025

@fairydreaming Could you test if there is any difference between having those permutes but not doing the views in build_attn_mha()? Does this perform any better or worse than the current master?

@jukofyork Generation is slightly better with permutes alone but only with short contexts. Initially it's about 0.7 t/s more for empty context but the difference quickly goes down to about 0.1 t/s at 8k. The difference in prompt processing performance is negligible.

Also, can you try changing q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); to q_nope = ggml_cont(ctx0, ggml_permute(ctx0, q_nope, 0, 2, 1, 3)); to see if that improves things? I suspect the poor CUDA quantised wk_b performance comes from this permute and the effect on the following ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope).

This change very slightly reduced the generation performance (about 0.1 t/s for short context sizes).

@jukofyork
Copy link
Collaborator Author

@fairydreaming Could you test if there is any difference between having those permutes but not doing the views in build_attn_mha()? Does this perform any better or worse than the current master?

@jukofyork Generation is slightly better with permutes alone but only with short contexts. Initially it's about 0.7 t/s more for empty context but the difference quickly goes down to about 0.1 t/s at 8k. The difference in prompt processing performance is negligible.

Thanks! If the backends do want to add the ability to squash dimensions, then we'll need to change the permutation order to be like this, so it's good that it doesn't have any negative effects when not using the view optimisations.

Also, can you try changing q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); to q_nope = ggml_cont(ctx0, ggml_permute(ctx0, q_nope, 0, 2, 1, 3)); to see if that improves things? I suspect the poor CUDA quantised wk_b performance comes from this permute and the effect on the following ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope).

This change very slightly reduced the generation performance (about 0.1 t/s for short context sizes).

Are you using a model with the wk_b tensor quantised too?

@fairydreaming
Copy link
Collaborator

Also, can you try changing q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); to q_nope = ggml_cont(ctx0, ggml_permute(ctx0, q_nope, 0, 2, 1, 3)); to see if that improves things? I suspect the poor CUDA quantised wk_b performance comes from this permute and the effect on the following ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope).

This change very slightly reduced the generation performance (about 0.1 t/s for short context sizes).

Are you using a model with the wk_b tensor quantised too?

@jukofyork it was a freshly converted one with this PR

@jukofyork
Copy link
Collaborator Author

Also, can you try changing q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); to q_nope = ggml_cont(ctx0, ggml_permute(ctx0, q_nope, 0, 2, 1, 3)); to see if that improves things? I suspect the poor CUDA quantised wk_b performance comes from this permute and the effect on the following ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope).

This change very slightly reduced the generation performance (about 0.1 t/s for short context sizes).

Are you using a model with the wk_b tensor quantised too?

@jukofyork it was a freshly converted one with this PR

Ah, does your CPU support BF16 natively? If so then it might be worth overriding the wk_b and wv_b tensors to use BF16 (there is a new PR for doing this or you can just hack llama_tensor_get_type() like this).

For CUDA, the ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope) is much worse when using quantised types, and BF16 gives quite a large boost for me (I'm unsure if this permute will help fix the quantised types yet though, as I've not tested it).

@JohannesGaessler
Copy link
Collaborator

One issue with CUDA is that currently the support for non-contiguous tensors is much worse than with the CPU backend, particularly for quantized src0. I'm currently working on improving that since my impression with Deepseek V2 Lite is that the biggest performance bottlenecks have to do with MoE rather than the attention mechanism. I will make a PR that at least for batch size 1 brings the support for quantized data to the level of FP16/BF16. That will then also make it possible to handle GGML_OP_MUL_MAT_ID in a single kernel launch instead of having to do one per expert.

For batch sizes > 1 I'm not yet sure how to handle it; I could maybe adapt the MMQ code to handle the GGML_OP_MUL_MAT_ID in a single batch but that would double the amount of kernels that would need to be compiled and not be applicable to floating-point data.

@zts9989
Copy link

zts9989 commented Apr 18, 2025

Feedback for reference:

Deepseek R1 Q8_0

Configuration 1 (with MLA optimization):
CUDA_VISIBLE_DEVICES=0 /data/llama.cpp/llama.cpp-b5145/build/bin/llama-server -m /data/llm/deepR1-00001-of-00020.gguf -fa --temp 0.6 --top-p 0.95 -s 3047 --no-warmup -ngl 160 -c 14336 --host 0.0.0.0 -t 48 -ot exps=CPU -tb 96

...
init: kv_size = 14336, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 61, can_shift = 1
init: CUDA0 KV buffer size = 1814.75 MiB
llama_context: KV self size = 1814.75 MiB, K (f16): 960.75 MiB, V (f16): 854.00 MiB
...
slot print_timing: id 0 | task 0 |
prompt eval time = 616.97 ms / 6 tokens ( 102.83 ms per token, 9.72 tokens per second)
eval time = 5504.20 ms / 30 tokens ( 183.47 ms per token, 5.45 tokens per second)
total time = 6121.17 ms / 36 tokens
srv update_slots: all slots are idle

Configuration 2 (Original "PR" version with -ot ):
CUDA_VISIBLE_DEVICES=0,1 /data/fastll/llama.cpp2/build/bin/llama-server -m /data/llm/DeepSeek-R1.Q8_0-00001-of-00015.gguf -fa --temp 0.6 --top-p 0.95 -s 3047 --no-warmup -ngl 160 -c 14336 --host 0.0.0.0 -t 48 -ot exps=CPU -ts 30,32 -tb 96
...
llama_kv_cache_init: kv_size = 14336, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 61, can_shift = 0
llama_kv_cache_init: CUDA0 KV buffer size = 33600.00 MiB
llama_kv_cache_init: CUDA1 KV buffer size = 34720.00 MiB
llama_init_from_model: KV self size = 68320.00 MiB, K (f16): 40992.00 MiB, V (f16): 27328.00 MiB
...
slot print_timing: id 0 | task 0 |
prompt eval time = 404.80 ms / 6 tokens ( 67.47 ms per token, 14.82 tokens per second)
eval time = 2206.80 ms / 30 tokens ( 73.56 ms per token, 13.59 tokens per second)
total time = 2611.59 ms / 36 tokens
srv update_slots: all slots are idle

@jukofyork
Copy link
Collaborator Author

when using a draft model and -fa, the performance tanks.

I found that the FA code update in b4759 (CUDA implementation) caused unexpected drastic performance changes in my test cases. This might be related. #12816

I can confirm that using -fa for both the draft and target model seems to be broken - even just using it for qwen-coder-32b and qwen-coder-0.5b causes NaNs to be predicted by the draft model for certain prompts (but, weirdly, not all).

I haven't time currently to find the exact details though.

colout pushed a commit to colout/llama.cpp that referenced this pull request Apr 21, 2025
* Merged using squash to remove all noise commit messages

* Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large

* Removed 3 conts (2x RoPE and 1x RMS-norm)

* Changed to use `<cmath>` instead of `<math.h>`

* Reverted removal of the 3 conts

* Used `reshape` in `llm_graph_context::build_attn_mha()`

* Use `k_pe = ggml_reshape`

* Removed the 3 conts again

* Removed the 3D views of `wk_b` and `wv_b`, and just save them as 3D in the GGUF

* Removed MQA optimisation from `build_attn_mha()` as no gains now

* Simplified `is_mla` branch in `llm_build_deepseek2()`

* Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls

* Fixed call to `build_attn` in `llm_build_t5_enc`
@bartowski1182
Copy link
Contributor

Having a weird issue with Microsoft's R1 tune...

I converted and quantized with this patch applied, then calculated the imatrix from the quantization, but now I can't apply the imatrix:

 [   4/1086]                blk.0.attn_k_b.weight - [  128,   512,   128,     1], type =   bf16, 

llama_tensor_get_type : tensor cols 128 x 512 are not divisible by 256, required for q6_K - using fallback quantization q8_0

====== llama_model_quantize_impl: imatrix size 128 is different from tensor size 16384 for blk.0.attn_k_b.weight
llama_model_quantize: failed to quantize: imatrix size 128 is different from tensor size 16384 for blk.0.attn_k_b.weight
main: failed to quantize model from '/models/MAI-DS-R1-GGUF/microsoft_MAI-DS-R1-bf16.gguf'

(ignore the "tensor cols not divisible" part; the problem is that I get "imatrix size is different from tensor size")

any idea what could be causing this? I assume it's from this change, but obviously I should probably roll back and verify if you think there's a chance it's unrelated

@saood06
Copy link

saood06 commented Apr 24, 2025

any idea what could be causing this?

The info in this seems relevant: ikawrakow/ik_llama.cpp#250

@bartowski1182
Copy link
Contributor

So it looks like there are 2 specific parts that could affect it

1 is in imatrix.cpp, changing:

for (int row = 0; row < (int)src1->ne[1]; ++row) {

to:

for (int row = 0; row < (int)(src1->ne[1]*src1->ne[2]); ++row) {

No idea what src1->ne[2] could possibly be..

The other is in llama.cpp (a section that's now in llama-quant.cpp) marked with "MLA hack", where we check specifically whether the imatrix file is coming from an older conversion with standard attention - but that shouldn't be applicable to my case, since I converted it myself with the new attention and then calculated the imatrix.

I suppose it's possible that the first change is the fix required, but I have no idea what it is :')

@bartowski1182
Copy link
Contributor

The particularly odd part is that I'm able to run inference on it without issue - only imatrix is upset.

@MB7979
Copy link

MB7979 commented Apr 24, 2025

I'm having severe degradation at long context since this commit. Using the newly quantized UD Unsloth quants (UD-IQ2_M and UD-Q2_K_XL). Quality is OK with short inputs, but when I work with my standard 6-7K prompt I get garbled Chinese. This was not the case with the old pre-MLA UD-Q2_K_XL. Oddly, running with a Q8_0 K cache seems to resolve it, although quality is still not as good as pre-MLA.

@Panchovix
Copy link

I get the same issue as @MB7979, seems to be reproducible on different setups.

I have a Ryzen 7 7800X3D + 192GB RAM + 128GB VRAM (5090 + 4090x2 + A6000), and I normally get gibberish at longer ctx, and for me -ctk q8_0 doesn't solve it.

https://huggingface.co/unsloth/DeepSeek-V3-0324-GGUF-UD/discussions/2

@danielhanchen
Copy link
Contributor

danielhanchen commented Apr 28, 2025

@jukofyork Hey! Just wanted to ask if you know why CPU + GPU offloading with the MLA commit gives gibberish? See https://huggingface.co/unsloth/DeepSeek-V3-0324-GGUF-UD/discussions/2#680f679eb63b7f85d975d8c5 and many other questions specifically on the MLA commit - I converted both R1 and V3 with the new MLA commit.

  1. If we offload everything to GPU and use Q4_0 Q and Q4_0 V cache via Flash Attention, there is no gibberish.
  2. If we offload non MoE or certain layers, then gibberish occurs.
  3. Some people have said using Q8_0 cache partially solves the issue.

I already:

  1. Re-uploaded all quants and re-did them over the weekend to see if it was my fault - V3 and R1 are re-uploaded, yet the problem persists. Full GPU offloading again works fine ie 8x H100s.
  2. Low precision Q and V cache works fine in my tests - it's just CPU offloading which is causing issues.

Thanks and really appreciate if you could investigate this especially since R2 might be around the corner!

See https://huggingface.co/unsloth/DeepSeek-R1-GGUF-UD, https://huggingface.co/unsloth/DeepSeek-V3-0324-GGUF-UD and https://huggingface.co/unsloth/MAI-DS-R1-GGUF

@jukofyork
Copy link
Collaborator Author

@jukofyork Hey! Just wanted to ask if you know why CPU + GPU offloading with the MLA commit gives gibberish? See https://huggingface.co/unsloth/DeepSeek-V3-0324-GGUF-UD/discussions/2#680f679eb63b7f85d975d8c5 and many other questions specifically on the MLA commit - I converted both R1 and V3 with the new MLA commit.

1. If we offload everything to GPU and use Q4_0 Q and Q4_0 V cache via Flash Attention, there is no gibberish.

2. If we offload non MoE or certain layers, then gibberish occurs.

3. Some people have said using Q8_0 cache partially solves the issue.

I already:

1. Re-uploaded all quants and re-did them over the weekend to see if it was my fault - V3 and R1 are re-uploaded, yet the problem persists. Full GPU offloading again works fine ie 8x H100s.

2. Low precision Q and V cache works fine in my tests - it's just CPU offloading which is causing issues.

Thanks and really appreciate if you could investigate this especially since R2 might be around the corner!

See https://huggingface.co/unsloth/DeepSeek-R1-GGUF-UD, https://huggingface.co/unsloth/DeepSeek-V3-0324-GGUF-UD and https://huggingface.co/unsloth/MAI-DS-R1-GGUF

I'm away for another week so can't easily check this atm, but have you tried since this PR got merged earlier today:

#13137

This might fix it.

@JohannesGaessler
Copy link
Collaborator

I think I misdiagnosed the problem in that PR; I'll make a PR for a better fix soon. But only FP16 models should be affected in the first place.

@MB7979
Copy link

MB7979 commented Apr 28, 2025

Yes, that commit has not resolved the issue when using partial offload with the DeepSeek quants of R1 and V3 I’m using, unfortunately.

@Panchovix
Copy link

No luck either here on latest commit.

@Panchovix
Copy link

It seems the gibberish issue when offloading non-experts has been fixed as of a few days ago, in case anyone was wondering:

https://huggingface.co/unsloth/DeepSeek-V3-0324-GGUF-UD/discussions/2#68192917c3d212ad5b33964d

@MB7979 may confirm

@MB7979
Copy link

MB7979 commented May 9, 2025

Yes, can confirm. Gibberish seems to be completely resolved now. Thanks to @JohannesGaessler for sorting that.
