
DeepSeek V2/V3 with -mla option #12725


Closed


jukofyork
Collaborator

@jukofyork jukofyork commented Apr 2, 2025

This PR adds the -mla option (long name --mla-attn) and is a continuation of @fairydreaming's #11446 PR.

The quants created for @fairydreaming's PR should all still work fine, but if you have a GGUF without the attn_k_b and attn_v_b tensors that this PR (or @fairydreaming's) adds, you won't be able to use it without requantising.


I've set these two to use F32 for MLA and non-MLA respectively:

ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope);
ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32);
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
ggml_mul_mat_set_prec(kv, GGML_PREC_F32);

This has been tested and fixes the weirdness I was getting for the non-MLA version, which has wkv_b stored as Q8_0; my MLA quant has wk_b stored as BF16, so I can't test that case yet.

These may cause some regression in performance as a result, but this is consistent with the way that llm_graph_context::build_attn_mha() handles this unconditionally:

        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

        // note: this op tends to require high floating point range
        //       while for some models F16 is enough, for others it is not, so we default to F32 here
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
  • You can think of q_nope_absorbed as the kq calculation transferred to this operation: the actual kq now happens in the "compressed" space instead (see the identity sketched just after this list).
  • I'm not sure how to interpret the kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr) operation in terms of kq, but without this fix it can be shown to produce gibberish by just asking "Tell me a joke about pandas": the output has no newlines and eventually just stops abruptly...
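
To spell out the identity behind the first point (per head, treating q_nope and kv_cmpr as column vectors, with wk_b being the key up-projection slice of wkv_b; this mapping of the tensor names is my own reading, so take it as a sketch):

    q_nope^T · k_nope = q_nope^T · (wk_b · kv_cmpr) = (wk_b^T · q_nope)^T · kv_cmpr

q_nope_absorbed corresponds to the (wk_b^T · q_nope) factor, and the subsequent kq against the cached kv_cmpr completes the right-hand side, so the same dot products that kq would normally do are just reordered into the compressed space - which is why the wk_b multiply inherits the same high floating-point range requirement.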

I've done my best to cleanly add the new code to the newly refactored llama-graph.cpp code, but sadly it's not possible to integrate with build_attn_mha() as:

  • There's no way to decompress using wv_b as build_attn_mha() applies the kqv_merged and cont inside.
  • We lose the performance gain from using 2D views when applying MQA (ie: GQA with 1 group).

To fix this I would have to start moving the wo stuff out of build_attn_mha() and/or create a new build_attn_mqa() function. (fixed now - see below)


This still looks a bit ugly in llama_kv_cache_unified::init():

    for (int i = 0; i < n_layer; i++) {
        int64_t n_embd_k;
        int64_t n_embd_v;

        // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group)
        if (cparams.mla_attn) {
            n_embd_k = hparams.n_lora_kv + hparams.n_rot;
            n_embd_v = hparams.n_lora_kv;
        } else {
            n_embd_k = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
            n_embd_v = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
        }

but I can't see a better way to do it currently.


I'll leave it as a draft for now and would welcome feedback, as I can't easily test this for other back-ends or without the newly added "offload tensor" stuff.

@jukofyork
Collaborator Author

jukofyork commented Apr 3, 2025

I've tidied things up to use build_attn_mha() now, and made it so that build_attn_mha() uses the optimised 2D / non-batched version for all models when it detects MQA (not sure if any modern models use MQA though):

        // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
        if (n_head_kv == 1) {
            q = ggml_view_2d(ctx0, q,
                    n_embd, n_tokens*n_head,
                    ggml_row_size(q->type, n_embd),
                    0);
        }

        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

        // note: this op tends to require high floating point range
        //       while for some models F16 is enough, for others it is not, so we default to F32 here
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

        if (n_head_kv == 1) {
            kq = ggml_view_3d(ctx0, kq,
                    n_kv, n_tokens, n_head,
                    ggml_row_size(kq->type, n_kv),
                    ggml_row_size(kq->type, n_kv)*n_tokens,
                    0);
        }
  • Calling build_attn_mha() from build_attn_mla() now means we can use --no-kv-offload with the --mla-attn option set.
  • I've tested using --cache-type-k q8_0 and --cache-type-k q4_0 at the same time as --mla-attn, and both seem to be working (although not tested to see how badly it degrades PPL, etc). It doesn't seem to have any real effect on the token generation speed (at low context at least).

I don't think I can really improve things much more, and the ugliness in llama_kv_cache_unified::init() and passing v_mha_proj to build_attn_mla() can't really be avoided, so I'll open this up for testing/review now (but please don't merge until we get some more feedback from people using other back-ends, etc).

@jukofyork jukofyork marked this pull request as ready for review April 3, 2025 03:23
@jukofyork jukofyork requested a review from ngxson as a code owner April 3, 2025 03:23
@CISC
Collaborator

CISC commented Apr 3, 2025

The quants created for @fairydreaming's PR should all still work fine, but if you have a GGUF without the attn_k_b and attn_v_b tensors that this PR (or @fairydreaming's) adds, you won't be able to use it without requantising.

Can't we just derive these from attn_kv_b at load? Sure, a little tedious if it's quantized, but seems wasteful to essentially duplicate the tensors...

@ngxson
Collaborator

ngxson commented Apr 3, 2025

The quants created for @fairydreaming's PR should all still work fine, but if you have a GGUF without the attn_k_b and attn_v_b tensors that this PR (or @fairydreaming's) adds, you won't be able to use it without requantising.

From a user perspective, this can be a bad experience if the user updates llama.cpp to the latest version and the model suddenly no longer works.

I think what we can do is add a new gguf metadata key like deepseek2.attention.type = "mla" and only activate the "MLA" code via a condition. This will also eliminate the need for the -mla switch, as most users don't even care what it is, let alone know it is a feature.
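
A rough sketch of what the load-time check could look like (the key name is just the proposal above, nothing like this exists in the tree yet, and the exact gguf calls may need adjusting):

    // sketch: probe a GGUF for the proposed metadata key and decide which graph to build
    #include "gguf.h"

    #include <cstdint>
    #include <cstring>

    static bool gguf_uses_mla(const char * fname) {
        struct gguf_init_params params = {
            /*.no_alloc =*/ true,   // we only need the metadata, not the tensor data
            /*.ctx      =*/ nullptr,
        };
        struct gguf_context * ctx = gguf_init_from_file(fname, params);
        if (!ctx) {
            return false;
        }
        bool mla = false;
        const int64_t kid = gguf_find_key(ctx, "deepseek2.attention.type");
        if (kid >= 0 && gguf_get_kv_type(ctx, kid) == GGUF_TYPE_STRING) {
            mla = strcmp(gguf_get_val_str(ctx, kid), "mla") == 0;
        }
        gguf_free(ctx);
        return mla;   // true -> build the MLA cgraph, false -> keep the current one
    }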

@jukofyork
Collaborator Author

jukofyork commented Apr 3, 2025

I've found a subtle bug which only happens when you use a speculative decoding model:

    for (int i = 0; i < n_layer; i++) {
        int64_t n_embd_k;
        int64_t n_embd_v;

        // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group)
        if (cparams.mla_attn) {
            n_embd_k = hparams.n_lora_kv + hparams.n_rot;
            n_embd_v = hparams.n_lora_kv;
        } else {
            n_embd_k = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
            n_embd_v = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
        }

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft;
        if (offload) {
            auto * dev = model.dev_layer(i);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        } else {
            buft = ggml_backend_cpu_buffer_type();
        }

        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %" PRId64 ", n_embd_v = %" PRId64 ", dev = %s\n", __func__,
                i, n_embd_k, n_embd_v, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
            return false;
        }

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size);
        ggml_format_name(k, "cache_k_l%d", i);
        ggml_format_name(v, "cache_v_l%d", i);
        k_l.push_back(k);
        v_l.push_back(v);
    }

The problem is that if (cparams.mla_attn) seems to be checking the global options rather than the model's own options.

This logic in llama-context.cpp seems to be able to set some model-specific options:

    if (params.mla_attn && model->arch != LLM_ARCH_DEEPSEEK2) {
        LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__);
        params.mla_attn = false;
    }

    if (params.flash_attn && params.mla_attn) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__);
        params.flash_attn = false;
    }

(at least from what is printed)

The solution for this case is just to check for cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2:

        if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
            n_embd_k = hparams.n_lora_kv + hparams.n_rot;
            n_embd_v = hparams.n_lora_kv;
        } else {
            n_embd_k = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
            n_embd_v = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
        }

But I also noticed that the cparams.flash_attn option is used twice like this inside llama-kv-cache.cpp:

    v_trans   = !recurrent && !cparams.flash_attn;
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;
}   

and I wonder if the same subtle bug is happening when a draft model that allows flash attention is used together with a main model that doesn't, like LLM_ARCH_GROK?

I think the llama_init_from_model() code, which turns these settings off on a per-model basis when we have both a draft and a target model loaded:

    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
        params.flash_attn = false;
    }

might need investigating, to check that setting -fa for a draft model that allows it isn't also turning it on for the main/target model unexpectedly, possibly causing the main/target model to start using a transposed V-cache without realising.

It was an extremely subtle bug to track down (I thought it was my speculative model that was broken until I ran it on its own and used llama-perplexity), and unless you suspect something is wrong and use the --verbose option to see what the speculative model is doing, you'll have no idea it's just predicting garbage like this.

@jukofyork
Collaborator Author

jukofyork commented Apr 3, 2025

I tested this ages ago and IIRC none of the weights even had a magnitude larger than ±256, so they are well within float16's ±65504 range. The problem lies with the activation magnitudes, and if you think about it, it makes sense why at least wk_b causes this: you are essentially doing the KQ part of the attention calculation when you multiply Q by wk_b!
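
To make the range argument concrete, here is a toy standalone example (the numbers are made up and only chosen to illustrate the accumulation, they're not taken from the model):

    // every individual value fits comfortably in float16 (max ~65504),
    // but a single 512-element dot product of such values does not
    #include <cstdio>

    int main() {
        const int n = 512;        // e.g. something on the order of kv_lora_rank
        float q[512], k[512];
        for (int i = 0; i < n; i++) {
            q[i] = 200.0f;        // well inside the float16 range
            k[i] = 200.0f;
        }
        float acc = 0.0f;
        for (int i = 0; i < n; i++) {
            acc += q[i] * k[i];
        }
        // 512 * 200 * 200 = 20,480,000 -> would overflow a float16 accumulator
        printf("dot product = %.0f (float16 max is 65504)\n", acc);
        return 0;
    }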

Right, so wouldn't that be a +1 for doing it at load?

Not sure what you mean? :)

We can't have it in the file as float16 or bfloat16 as default because a lot of the back-ends won't support these, and the rest of llama.cpp just defaults to using float32 in these cases for these reasons.

I don't really want to have some complex code that tries to guess the allowed operations for a given back-end to decide on the slicing type either as I think this will be way too brittle to be effective.

Whatever the final solution of the PR is, I can say for sure I will be patching it to use bf16 myself for these two tensors, as it's clearly the best choice if they are running on CUDA in terms of speed and less weirdness in the outputs of the model! :)

@CISC
Collaborator

CISC commented Apr 3, 2025

Not sure what you mean? :)
We can't have it in the file as float16 or bfloat16 as default because a lot of the back-ends won't support these, and the rest of llama.cpp just defaults to using float32 in these cases for these reasons.

I meant you can most likely quantize attn_kv_b to Q8_0 without (much) issue and then enforce full precision on wk_b.

Whatever the final solution of the PR is, I can say for sure I will be patching it to use bf16 myself for these two tensors, as it's clearly the best choice if they are running on CUDA in terms of speed and less weirdness in the outputs of the model! :)

There are also some incoming updates on that for both CUDA and Vulkan. :)

@ngxson
Collaborator

ngxson commented Apr 3, 2025

2. Slice these up and quantise them like this PR, leaving a quantised attn_kv_b, attn_k_b and attn_v_b in the file and do what @ngxson suggests and have some metadata added to detect if the attn_k_b and attn_v_b are present.

I agree that no one will want to use the non-MLA version, but the problem is that some users don't know what MLA is, and all they care about is that updating to a newer version of llama.cpp should not break their existing model.

Indeed, I'm thinking about another approach that is a bit stupidly simple: the current arch name is deepseek2, so why don't we simply add a new arch called deepseek3-mla with an MLA cgraph? Since the MLA and non-MLA graphs are very different, I think it's ok to separate them into 2 different archs.

Re. having attn_kv_b / attn_k_b and attn_v_b / both: there is an option to load a tensor "optionally"; if the tensor is not found in the file, its ggml_tensor * pointer will be set to NULL.

Many archs use this trick to detect whether certain bias tensors are present or not, so we can just do the same!
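
For reference, this is roughly what that optional-tensor pattern looks like for the bias weights in llama-model.cpp today (paraphrased from memory, so treat the exact names as approximate); the same null check could gate the MLA path when attn_k_b / attn_v_b are absent:

    // load stage: the tensor is optional, so a missing entry just leaves the pointer null
    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd},
            llama_model_loader::TENSOR_NOT_REQUIRED);

    // graph build stage: branch on whether the optional tensor was present
    if (model.layers[il].bq) {
        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
    }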

@jukofyork
Collaborator Author

  1. Slice these up and quantise them like this PR, leaving a quantised attn_kv_b, attn_k_b and attn_v_b in the file and do what @ngxson suggests and have some metadata added to detect if the attn_k_b and attn_v_b are present.

I agree that no one will want to use the non-MLA version, but the problem is that some users don't know what MLA is, and all they care about is that updating to a newer version of llama.cpp should not break their existing model.

Indeed, I'm thinking about another approach that is a bit stupidly simple: the current arch name is deepseek2, so why don't we simply add a new arch called deepseek3-mla with an MLA cgraph? Since the MLA and non-MLA graphs are very different, I think it's ok to separate them into 2 different archs.

Re. having attn_kv_b / attn_k_b and attn_v_b / both: there is an option to load a tensor "optionally"; if the tensor is not found in the file, its ggml_tensor * pointer will be set to NULL.

Many archs use this trick to detect whether certain bias tensors are present or not, so we can just do the same!

@fairydreaming suggested this when he first put forward his MLA PR, but @ggerganov wasn't keen on the idea:

#11446 (comment)

but I agree it is an appealing solution that saves a lot of hassle?

@jukofyork
Collaborator Author

jukofyork commented Apr 3, 2025

Not sure what you mean? :)
We can't have it in the file as float16 or bfloat16 as default because a lot of the back-ends won't support these, and the rest of llama.cpp just defaults to using float32 in these cases for these reasons.

I meant you can most likely quantize attn_kv_b to Q8_0 without (much) issue and then enforce full precision on wk_b.

Yeah, this was exactly my first idea and I originally added this:

gguf.MODEL_TENSOR.ATTN_K_B,

but reverted it when I found that even Q8_0 for attn_kv_b in the non-MLA branch wasn't working properly without getting bumped to use F32 for the matrix multiply.

Whatever the final solution of the PR is, I can say for sure I will be patching it to use bf16 myself for these two tensors, as it's clearly the best choice if they are running on CUDA in terms of speed and less weirdness in the outputs of the model! :)

There are also some incoming updates on that for both CUDA and Vulkan. :)

Interesting!? I think there is a lot of potential for using BF16, but from my brief look around the CUDA code, I can see why @JohannesGaessler has been reluctant to add more of it, due to all the complexity it would add for all the different cases, etc.

I have an RTX PRO 6000 Blackwell on order, so hope to retry some of the stuff with CUDA and BF16 when I get it.

@jukofyork
Collaborator Author

jukofyork commented Apr 3, 2025

I'm not gonna rush to make any changes to the sliced tensors yet anyway and hopefully we can get more feedback on what is the best direction to go (just praying we don't have another refactoring like last time as the -mla option was a complete nightmare to rebase!).

@ngxson
Collaborator

ngxson commented Apr 3, 2025

@fairydreaming suggested this when he first put forward his MLA PR, but @ggerganov wasn't keen on the idea:

#11446 (comment)

but I agree it is an appealing solution that saves a lot of hassle?

Ok sorry I was missing the context, thanks for pointing me to the correct discussion.

Both @fairydreaming and Georgi (not pinging you here to reduce a bit of noise) make good points regarding the fact that MLA is not something "for free", but can affect the performance and memory usage in ways that we don't yet know.

So adding the -mla option seems to be a reasonable thing at the moment. This would allow experimenting with MLA / non-MLA without having to store 2x the model weights, which is very useful for anyone who wants to run benchmarks.

But I still don't want to break existing quants. As deepseek v3 is huge, not everyone can requant it. I imagine even @bartowski1182 won't want this to happen, so let's try not to have a breaking change I guess?

@bartowski1182
Contributor

It would certainly be nice if we could avoid breaking changes, but I'm of the opinion that if progress necessitates breakage, let's break it

Koboldcpp does work to maintain (too much) backwards compatibility for those that need it, and people can run an older llama.cpp while waiting for an update as well

By all means I would vastly prefer we keep it compatible, but I won't cry if they have to be remade so that we can avoid future maintenance headaches

Comment on lines -1474 to -1490
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
Collaborator


I think these were deleted inadvertently? For example, ffn_*_shexp is still used by Qwen MoE.

Collaborator Author


I think these were all accidentally duplicated in the main branch so I removed the duplicates when inserting the new ones.

@jukofyork
Collaborator Author

But I still don't want to break existing quants. As deepseek v3 is huge, not everyone can requant it. I imagine even @bartowski1182 won't want this to happen, so let's try not to have a breaking change I guess?

By all means I would vastly prefer we keep it compatible, but I won't cry if they have to be remade so that we can avoid future maintenance headaches

I'll try and explore some options over the weekend to allow for backwards compatibility.

If we're keeping the -mla option then the easiest would probably be the null-pointer method mentioned earlier for bias weights (I think deepseek already uses this method for the q_proj matrices, as the smaller "lite" versions don't use low-rank decomposed versions for these!). Then if someone tries to load an old quant using the -mla option, just output a warning that the tensors are missing and turn it off.

We can leave the actual final decision on what to do about the splitting, duplicates, f32, etc until after this gets fixed.

Comment on lines 4432 to 4436
return [
(self.map_tensor_name(name), data_torch),
(self.map_tensor_name(name_kb), k_b),
(self.map_tensor_name(name_vb), v_b)
]
Collaborator

@ngxson ngxson Apr 3, 2025


Please tell me if I missed something regarding the discussion about whether or not to duplicate these tensor (slices).

On the subject of not duplicating them, I'm thinking about an idea that could allow slicing kv_b_proj at load without using too much memory, which is to do something like this:

model.wk_b = nullptr; // at earlier load stage, we don't have this

// then during warmup
model.wk_b = ggml_view_2d(wkv_b, qk_nope_head_dim,...);
model.wv_b = ggml_view_2d(wkv_b, v_head_dim,...);

// transpose wk_b then copy back to initial memory
ggml_tensor * wk_b_T = ggml_cont(ggml_transpose(model.wk_b));
model.wk_b = ggml_cpy(wk_b_T, model.wk_b); // not even sure if this would work

The ggml_cpy and ggml_view won't allocate new memory on the device buffer; only the ggml_cont needs to allocate memory.

Collaborator Author

@jukofyork jukofyork Apr 3, 2025


Yeah, I think we could easily do this (or something similar) so long as we keep kv_proj_b as a float32, but this has different problems:

  • We're now forcing those not using the -mla option to have kv_proj_b stored as float32 when they don't really need it and can just use the ggml_mul_mat_set_prec(xxx, GGML_PREC_F32) call instead.
  • This won't work for existing quantised models as they are stored row-major in memory and we'd have to dequantise, requantise and hope the alignment is right (and it will have a row length of 128 so also won't work on any of the non-legacy quants).

Collaborator Author


This raises a good point though: I don't think we really need to save wv_b at all and can just use the upper slice of wkv_b (I think - will have to check tomorrow).

@jukofyork
Collaborator Author

Setting this back to draft whilst I make the changes.

@jukofyork jukofyork marked this pull request as draft April 4, 2025 19:24
@jukofyork
Collaborator Author

jukofyork commented Apr 4, 2025

@ngxson (and others)

What do you think of this now that I've removed the extra attn_v_b completely and get it from the old attn_kv_b:

                    // {n_embd_head_v, n_head, n_tokens}
                    ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wkv_b,
                            kv_lora_rank, n_embd_head_v, n_head,
                            ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank),
                            ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * (n_embd_head_qk_nope + n_embd_head_v),
                            ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * n_embd_head_qk_nope);
                    cb(wv_b, "wv_b", il);

Then we only need to add an attn_k_b_trans tensor.

I've tested this on v2-lite for the old GGUF files and those newly created from this, and all seems to be working fine.

You get this message if you try to use -mla on the old GGUF files:

mla_attn requires a gguf with the new 'attn_k_b_trans' tensor - forcing off

I'm not sure what the performance impact of using the wv_b view is yet, nor if we really need all the ggml_mul_mat_set_prec(XXX, GGML_PREC_F32) calls (at least for the "lite" models, neither needs this - see below).

I'm quantising the full r1 overnight and will test this tomorrow, and if all OK open the PR back up for review.

@Panchovix

Panchovix commented Apr 4, 2025

@jukofyork sorry to bother, but any chance to try to do a small quant of V3? Like Q2_K_XL size or similar. I want to test CPU (192GB RAM) + 4 CUDA GPUs (128GB VRAM), since when I use -mla from ik_llamacpp fork, I get gibberish output, ikawrakow/ik_llama.cpp#305 (and tested with a model quanted with ik_llamacpp, DeepSeek-V3-0324-IQ2_K_R4).

When not using the flag, works fine.

@jukofyork
Collaborator Author

jukofyork commented Apr 5, 2025

DeepSeek-V2-Lite-Chat

These are all after I removed the ggml_mul_mat_set_prec() call for both the MLA and non-MLA cases:

with -mla

BF16: PPL = 8.8450 +/- 0.06679
Q8_0: PPL = 8.8603 +/- 0.06692
Q4_K_S: PPL = 9.2355 +/- 0.06992
Q2_K: PPL = 52.7166 +/- 0.53008

without -mla

BF16: PPL = 8.8436 +/- 0.06678
Q8_0: PPL = 8.8614 +/- 0.06695
Q4_K_S: PPL = 9.2304 +/- 0.06984
Q2_K: PPL = 58.3260 +/- 0.59799

so at least for the "lite" version there is no justification for ggml_mul_mat_set_prec (I also tested --batch-size 1 for a few chunks to be sure and it too was fine with no crazy PPL numbers).


It will be tomorrow before the new r1 quants are ready to test.

@jukofyork
Collaborator Author

@jukofyork sorry to bother, but any chance to try to do a small quant of V3? Like Q2_K_XL size or similar. I want to test CPU (192GB RAM) + 4 CUDA GPUs (128GB VRAM), since when I use -mla from ik_llamacpp fork, I get gibberish output, ikawrakow/ik_llama.cpp#305 (and tested with a model quanted with ik_llamacpp, DeepSeek-V3-0324-IQ2_K_R4).

When not using the flag, works fine.

I'll try and test it tomorrow, but at least for the "lite" model above, Q2_K was equally bad.

@jukofyork
Collaborator Author

jukofyork commented Apr 5, 2025

This works, but gives truly horrible performance on the CUDA back-end:

with -mla and attn_kv_b.weight and attn_k_b_trans.weight quantised to Q8_0:

1.80 tokens per second

It seems that the code in this branch in ggml-cuda.cu's ggml_cuda_mul_mat() just isn't optimised to handle batches shaped like this:

    } else if (use_mul_mat_vec_q) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
    }

with -mla and attn_kv_b.weight and attn_k_b_trans.weight quantised to BF16:

4.50 tokens per second

with < MMV_MAX_ROWS changed to <= MMV_MAX_ROWS in ggml_cuda_mul_mat():

    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
        // the custom F16 vector kernel can be used over batched cuBLAS GEMM
        // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
    }

For comparison:

without -mla and using --no-kv-offload

3.12 tokens per second

So the problem isn't with the actual quantised attn_kv_b.weight when used as it is in the non-MLA branch; it's purely when it gets used as a batched matrix-vector multiply in the MLA branch.
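
For anyone who wants to reproduce the shape in isolation, this is roughly the multiply that lands in that branch during single-token generation (dimensions hard-coded to the full model's values, so treat this as illustrative only):

    // build (without computing) the per-head "absorbed" multiply to see its shape:
    // 128 independent mat-vec products (one per head), ie: a batched matrix-vector
    // multiply of the {xxx, yyy, 128} form that the quantised CUDA path handles badly
    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,   // shapes only, no data
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head_qk_nope = 128;
        const int64_t kv_lora_rank        = 512;
        const int64_t n_head              = 128;
        const int64_t n_tokens            = 1;   // token generation

        ggml_tensor * wk_b_trans = ggml_new_tensor_3d(ctx, GGML_TYPE_Q8_0,
                n_embd_head_qk_nope, kv_lora_rank, n_head);
        ggml_tensor * q_nope = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
                n_embd_head_qk_nope, n_tokens, n_head);

        ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx, wk_b_trans, q_nope);

        printf("q_nope_absorbed: [%lld, %lld, %lld]\n",   // expect [512, 1, 128]
                (long long) q_nope_absorbed->ne[0],
                (long long) q_nope_absorbed->ne[1],
                (long long) q_nope_absorbed->ne[2]);

        ggml_free(ctx);
        return 0;
    }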

After seeing this, I also wonder if @fairydreaming's original findings of MLA having slightly worse performance were just because he was using quantised attn_k_b.weight and attn_v_b.weight?


We can't use F16 even if we wanted to, as this overflows:

    ggml_tensor * wk_b_trans = ggml_view_3d(ctx0, model.layers[il].wk_b_trans,
        n_embd_head_qk_nope, kv_lora_rank, n_head,
        ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope),
        ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope) * kv_lora_rank,
        0);

I've tried all sorts of things to see if I can get this to work with F16 too (absorbing the scales into the layer_norm, etc), but had no luck. You would actually get quite a boost to the CUDA back-end for prompt processing if this were possible, as it then drops into this case:

    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
        // general KQ + KQV multi-batch without FlashAttention
        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
    }

but even if I could get this to work, we can't really use F16 or BF16 for this as not all back-ends will support it...


So our only two options are:

  • Leave it like this and have horrible -mla performance and potential numerical problems that don't really show up with llama-perplexity, but do definitely seem to make r1's replies shorter and dumber.
  • Changing it all back to use separate attn_k_b.weight and attn_v_b.weight stored as F32 and have slightly less horrible performance and guaranteed no numerical problems.

For me this makes no difference, as I will be bumping both these to use BF16 when I make my own quants, but I can just hear the groans of "llama.cpp's MLA implementation is broken" on Reddit already if we leave it like this :/

@jukofyork
Collaborator Author

I'll make a branch with this alternative so others can compare:

  • Changing it all back to use separate attn_k_b.weight and attn_v_b.weight stored as F32 and have slightly less horrible performance and guaranteed no numerical problems.

@jukofyork
Collaborator Author

jukofyork commented Apr 5, 2025

I'll make a branch with this alternative so others can compare:

  • Changing it all back to use separate attn_k_b.weight and attn_v_b.weight stored as F32 and have slightly less horrible performance and guaranteed no numerical problems.

It's here: https://github.com/jukofyork/llama.cpp/tree/mainline-llama-cpp-master--mla--f32

but this adds around 4.1GB (512 × 32768 × 4 × 61) to the model so I don't think this is acceptable either...

I say we just live with the horrible CUDA performance for now and go with @JohannesGaessler's suggestion:

MMV_MAX_ROWS is only a rough estimate and it's fine to adjust it.

More generally, my advice for this PR would be not to focus too much on CUDA support out of the gate and to prioritize getting the CPU code in order. My biggest limitation is that I'm chronically short on time since I'm currently doing a PhD in physics completely unrelated to language models. I don't want to invest too much time into model-specific features when the more fundamental features in llama.cpp still have major issues; if you present me with more or less finalized CPU support for Deepseek however, I would be willing to invest more time into the equivalent CUDA support.

and when he gets time he can maybe figure out what is causing this for quantised batched matrix multiplies like {xxx, yyy, 128}.

I also can't prove the ggml_mul_mat_set_prec calls are needed, and even if they are, they can easily be added in the future and won't affect any existing or newly created MLA quants with the extra attn_k_b_trans.weight tensors.

The newly created MLA quants will also have the extra attn_k_b_trans.weight tensors "auto-bumped" in quant quality anyway (their row size of 128 is too small for the 256-element super-blocks that K-quants require), but if this is not enough then maybe @bartowski1182 can look into this in his #12727 PR.

Opening this PR back up for review, as I don't think I can really do any better than this.

@jukofyork jukofyork marked this pull request as ready for review April 5, 2025 10:56
@jukofyork
Collaborator Author

jukofyork commented Apr 5, 2025

As a workaround for the CUDA performance, you should be able to adapt this script for your own personal quants:

#!/bin/bash

function safe_sed() {
    local file=$1
    local pattern=$2
    local replacement=$3

    # Check if pattern exists
    if ! sed -n "s/${pattern}/${replacement}/p" "$file" | grep -q .; then
        echo "Error: Pattern not found in $file: $pattern"
        return 1
    fi

    # Create backup
    cp "$file" "$file.bak"

    # Perform the replacement
    sed -i "s/${pattern}/${replacement}/g" "$file"

    # Show diff
    echo "Changes in $file:"
    diff "$file.bak" "$file"

    # Clean up
    rm "$file.bak"

    echo "Successfully replaced in $file"
    echo "-------------------"
}

function safe_sed_function() {
    local file=$1
    local function_signature=$2
    local replacement=$3

    # Create backup
    cp "$file" "$file.bak"

    # Perform the replacement using address range and c command
    sed -i "${function_signature}/,/^}/c\\${replacement}" "$file"

    # Clean up
    rm "$file.bak"

    echo "Successfully replaced function in $file"
    echo "-------------------"
}

rm -rf llama.cpp

git clone https://github.com/jukofyork/llama.cpp --branch mainline-llama-cpp-master--mla
cd llama.cpp

# For attn_v_b to use fast mmv call.
safe_sed "ggml/src/ggml-cuda/ggml-cuda.cu" "< MMV_MAX_ROWS" "<= MMV_MAX_ROWS"

# Don't offload these huge tensors to GPU as PCI-E transfer is slower than just using CPU.
safe_sed "ggml/src/ggml-cuda/ggml-cuda.cu" "const int min_batch_size = 32" "const int min_batch_size = 9999999"

# Hack llama_tensor_get_type() to use our custom quant.
safe_sed_function "src/llama-quant.cpp" \
  "/^static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor \\* tensor, llama_ftype ftype) {" \
  "static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {\n\
    const std::string name = ggml_get_name(tensor);\n\
    if (name.find(\"attn_kv_b\") != std::string::npos || name.find(\"attn_k_b_trans\") != std::string::npos) {\n\
        return GGML_TYPE_BF16;\n\
    }\n\
    return GGML_TYPE_Q8_0;\n\
}"

cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON
cmake --build build --config Release -- -j 44

You may not want to patch the min_batch_size value, or may want to try setting it to something better than 9999999 (I only use this because my GPUs are on PCI-E 3.0 x16 and it's slower to pull these ~2GB tensors through the PCI-E bus than to just run on CPU unless the batch size is huge). Unless you are using the new -ot exps=CPU option, this should be removed completely...


I'm just checking now what difference (if any) this runtime slicing of wkv_b makes after the BF16 patches in the script above are applied:

// {n_embd_head_v, n_head, n_tokens}
ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wkv_b,
    kv_lora_rank, n_embd_head_v, n_head,
    ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank),
    ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * (n_embd_head_qk_nope + n_embd_head_v),
    ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * n_embd_head_qk_nope);

@jukofyork
Collaborator Author

So it turns out you can't do this slice anyway:

// {n_embd_head_v, n_head, n_tokens}
ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wkv_b,
    kv_lora_rank, n_embd_head_v, n_head,
    ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank),
    ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * (n_embd_head_qk_nope + n_embd_head_v),
    ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * n_embd_head_qk_nope);
llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:75: CUDA error
CUDA error: misaligned address
  current device: 0, in function ggml_backend_cuda_synchronize at llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2437
  cudaStreamSynchronize(cuda_ctx->stream())

and if you try to cont the wv_b view then you get this:

CUDA error: an illegal memory access was encountered

Sorry guys, but I can't waste any more time on this as each of these changes is taking several hours to re-quant all the models to test with, so I'm just gonna go back to storing a copy of attn_v_b (as @fairydreaming's PR did) and open a new final PR with this in.

@jukofyork
Collaborator Author

Continued in #12772
