Improvements in attention_forward functions #36218
Conversation
#35926 adds DeepSeek-V3. Once this is in, I'll contribute an improvement of inference for this model. The changes here are part of this, but they also stand on their own.
It really is only …
The fix in …
Hey @mseeger 👋 I hope you don't mind, I've edited the PR header. It contained part of the template, and was pinging most of our team 🤗
Uh, sorry.
The test which fails seems flaky; it passes for me locally. I also do not see how it could have anything to do with my changes.
Please do not pay too much attention to the number of files changed. This is mostly due to copy&paste of the new … And the same for …
@ArthurZucker: Now that #35926 has been merged, this is the first of two PRs which improve inference for the DeepSeek model, in the sense that MLA is really exploited, so that the KV cache takes much less space.
This PR also cleans things up a bit, in that the most common form of …
I don't get this failure. Locally, this works for me. In fact, it complains about these files in …
Hey @mseeger! Nice PR! Happy to improve DeepSeek's MLA!!
Though we need to keep the eager attention in each modeling file! This is just our philosophy: simple basic usage should always be possible without having to import anything + should be explicit!
That's why we kept eager attn in all files!
Now let's isolate the changes to only update DeepSeek eager attn + update the integration of flex, sdpa, etc. to use GQA.
One thing: in terms of memory, the repeat of kv should be a no-op, as it's just using the same memory pointers, no?
Thanks! 🤗
Hello, I can revert it, but I note that the other attention implementations are not copied to every file, just the eager one …
About the repeat of KV: I am not sure. If you look into it: …
After …
Finally, the change really affects how SDPA from Torch works if you set …
Maybe this is why the DeepSeek folks only publish low-level CUDA code for these things. I do hope the PyTorch creators take this seriously and make sure their SDPA properly supports …
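For context, here is a minimal check (illustrative shapes, not taken from the PR) of why the repeat is not free: the repeat_kv helper in transformers' llama modeling code expands with a view, but the following reshape has to materialize a full copy, so the repeated K/V really do take n_rep times the memory.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the repeat_kv helper in transformers' llama modeling code.
    batch, num_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

k = torch.randn(1, 4, 2048, 128)          # (batch, kv_heads, seq_len, head_dim)
k_rep = repeat_kv(k, 32 // 4)             # pretend 32 query heads vs 4 KV heads

print(k.untyped_storage().nbytes())       # 4 MiB of fp32 data
print(k_rep.untyped_storage().nbytes())   # 8x larger: the reshape copied the data
```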
@ArthurZucker: I can replace the imports of … These other models support GQA as well, and I don't want to be responsible for yet another copy of …
I suppose in terms of other variants of attention being in … I have to say I am not a fan of this (it makes contributions quite difficult), but I can understand it. Maybe some of your researcher customers are still using …
OK, the copy and paste is restored.
@ArthurZucker, here is some code which shows that …
This should prove that my PR here makes sense. Using MLA, or GQA, without it (and using …)
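As an illustration of the kind of comparison being described here (not the original snippet; the shapes are invented, and it needs a CUDA GPU plus PyTorch >= 2.5 for enable_gqa), one can compare peak memory for SDPA with enable_gqa=True against SDPA after repeating K/V to the full head count:

```python
import torch
import torch.nn.functional as F

def peak_mem(fn):
    # Peak GPU memory (MiB) while running fn; the selected SDPA backend can affect this.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

b, n_q, n_kv, s, d = 1, 128, 1, 4096, 128   # illustrative GQA/MLA-like shapes
q = torch.randn(b, n_q, s, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, n_kv, s, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, n_kv, s, d, device="cuda", dtype=torch.bfloat16)

with_gqa = peak_mem(lambda: F.scaled_dot_product_attention(q, k, v, enable_gqa=True))
with_repeat = peak_mem(
    lambda: F.scaled_dot_product_attention(
        q,
        k.repeat_interleave(n_q // n_kv, dim=1),   # materializes 128x the K bytes
        v.repeat_interleave(n_q // n_kv, dim=1),   # and 128x the V bytes
    )
)
print(f"enable_gqa: {with_gqa:.0f} MiB peak, repeated K/V: {with_repeat:.0f} MiB peak")
```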
Ping?
@mseeger As you mentioned, GQA has been critical for a lot of LLM models. May I know whether it is possible to split this PR to enable the different pieces of functionality separately? For example, one PR just for GQA in the sdpa_attention?
Hello @liangan1, catering for GQA is what this PR is doing, nothing more really.
Yes, thanks for your info, I know your point. What I want to say is that this PR seems too large, which blocks progress towards merging it.
@liangan1, I can rebase this PR, pulling in recent changes. But in general, all non-flaky tests passed on this PR when I last worked on it. Sure, if these PRs stay open for months, things stop working.
What does this PR do?
Improves eager_attention_forward and sdpa_attention_forward in the case when query.shape[1] > key.shape[1]. This happens in GQA (grouped query attention), and in particular in multi-head latent attention (MLA), as used in the recent DeepSeek models. The current implementations blow up key and value so that query.shape[1] == key.shape[1], which wastes memory.

Just to give the context: the recent DeepSeek models use 128 heads, while with multi-head latent attention (MLA), the key and value tensors do not have a head dimension at all. During inference, the largest tensors needed are these two. If this is not tackled (also for KV caching), you may need 128x as much GPU memory as would otherwise be needed. As discussed, I'll also provide the new MLA inference code once the DeepSeek models are there, but that needs the changes here as well.
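To make the 128x figure concrete, here is a back-of-envelope calculation (sequence length, latent size, and dtype are assumptions for illustration, not numbers from the PR):

```python
# With 128 query heads and K/V tensors that effectively have a single latent "head",
# repeating K/V to match the query heads multiplies their size by 128.
num_heads, kv_heads = 128, 1
seq_len, head_dim, bytes_per_elem = 4096, 128, 2     # bf16 (assumed)
kv_bytes = 2 * kv_heads * seq_len * head_dim * bytes_per_elem   # K and V together
repeated_bytes = kv_bytes * (num_heads // kv_heads)
print(f"{kv_bytes / 2**20:.0f} MiB vs {repeated_bytes / 2**20:.0f} MiB after repeating")
# -> 2 MiB vs 256 MiB, per layer and per sequence
```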
A small caveat is that the docs of torch.nn.functional.scaled_dot_product_attention remain unclear on whether their enable_gqa=True feature is widely supported. This is what is used here for SDPA.

The changes affect:
- eager_attention_forward in models/llama, and in all models this has just been copied to …
- sdpa_attention_forward, to use enable_gqa=True if needed, instead of using repeat_kv
- eager_attention_forward …
- the position_ids argument in apply_rotary_pos_emb (again copied in many places)
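For illustration, here is a simplified sketch of the SDPA-side idea (assumed names, signatures, and fallback check; not the actual diff in this PR): pass the grouped heads to SDPA via enable_gqa where PyTorch supports it, and only fall back to repeating K/V otherwise.

```python
import torch
import torch.nn.functional as F

def sdpa_gqa_attention(query, key, value, attn_mask=None, dropout_p=0.0, scale=None):
    # query: (batch, num_heads, q_len, head_dim)
    # key/value: (batch, num_kv_heads, kv_len, head_dim), num_kv_heads divides num_heads
    n_rep = query.shape[1] // key.shape[1]
    try:
        # PyTorch >= 2.5: let SDPA handle grouped heads without materializing repeats.
        return F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attn_mask, dropout_p=dropout_p, scale=scale,
            enable_gqa=(n_rep > 1),
        )
    except TypeError:
        # Older PyTorch without enable_gqa: fall back to repeating K/V
        # (equivalent to what repeat_kv materializes).
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, scale=scale
        )
```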
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker