Improvements in attention_forward functions #36218


Open: wants to merge 1 commit into main from improve_attention

Conversation

mseeger commented Feb 15, 2025

What does this PR do?

Improves eager_attention_forward and sdpa_attention_forward in the case when query.shape[1] > key.shape[1]. This happens in GQA (grouped query attention), and in particular in multi-head latent attention, as used in the recent DeepSeek models. The current implementations blow up key and value so that query.shape[1] == key.shape[1], which wastes memory.

Just to give some context: the recent DeepSeek models use 128 heads, while with multi-head latent attention (MLA) the key and value tensors have no head dimension at all. During inference, these two are the largest tensors needed. If this is not tackled (also for KV caching), you may need up to 128x as much GPU memory as actually necessary. As discussed, I'll also provide the new MLA inference code once the DeepSeek models are in, but that code needs the changes here as well.

A small caveat: the docs of torch.nn.functional.scaled_dot_product_attention remain unclear on whether their enable_gqa=True feature is widely supported. This is what is used here for SDPA; a minimal sketch is shown after the list of changes below.

  • Rewrite of eager_attention_forward in models/llama, and in all models this has just been copied to
  • Rewrite of sdpa_attention_forward to use enable_gqa=True if needed, instead of using repeat_kv
  • New test, comparing new against old implementation of eager_attention_forward
  • Also, removed position_ids argument in apply_rotary_pos_emb (again copied in many places)
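
To make the intended SDPA path concrete, here is a minimal, hedged sketch (toy shapes, not the PR's actual code) of calling SDPA with enable_gqa=True instead of repeating the key/value heads first; it assumes a PyTorch version whose scaled_dot_product_attention accepts this argument:

import torch
import torch.nn.functional as F

# 32 query heads share 4 key/value heads (GQA); with enable_gqa=True, SDPA
# broadcasts key/value over the query-head groups instead of requiring
# repeated copies to be materialized beforehand.
batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 32, 4, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 32, 16, 64])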


Related: #35926

Who can review?

@ArthurZucker

mseeger (Author) commented Feb 15, 2025

#35926 adds DeepSeek-V3. Once this is in, I'll contribute an improvement of inference for this model. The changes here are part of this, but they also stand on their own.

mseeger (Author) commented Feb 15, 2025

The real changes are only in src/transformers/integrations/sdpa_attention.py and src/transformers/models/llama/modeling_llama.py. All other changes just copy these changes around.

mseeger (Author) commented Feb 16, 2025

The fix in tests/test_configuration_common.py and src/transformers/models/llava/configuration_llava.py is for a test which also fails on main, which is really odd.

mseeger force-pushed the improve_attention branch 6 times, most recently from d2a1ed2 to 1f4e47d on February 17, 2025 at 11:27
gante (Member) commented Feb 17, 2025

Hey @mseeger 👋

I hope you don't mind, I've edited the PR header. It contained part of the template, and was pinging most of our team 🤗

mseeger (Author) commented Feb 17, 2025

Uh sorry

mseeger (Author) commented Feb 17, 2025

The failing test seems flaky; it passes for me locally. I also do not see how it could be related to my changes.

mseeger (Author) commented Mar 14, 2025

Please do not be put off by the number of files changed.

This is mostly due to copy&paste of the new eager_attention_forward code. Note there is a test confirming that the new code does the same as the old, without blowing up tensors explicitly. If matmul handles the broadcasting well, this should also run faster.

The same holds for scaled_dot_product_attention: it should run faster with enable_gqa=True, at least if that implementation is done properly. A small sketch of the broadcasting idea follows.
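
For illustration, here is a small, hedged sketch (toy shapes and names are mine, not the PR's code) of the broadcasting idea for the eager score computation: the query heads are viewed in groups per key/value head, and matmul broadcasts over the group dimension instead of operating on an explicitly repeated key tensor.

import torch

batch, num_kv_heads, n_rep, q_len, k_len, head_dim = 2, 4, 8, 5, 7, 64
query = torch.randn(batch, num_kv_heads * n_rep, q_len, head_dim)
key = torch.randn(batch, num_kv_heads, k_len, head_dim)

# Old path: repeat_kv-style expansion materializes key repeated across all query heads.
key_rep = key[:, :, None].expand(-1, -1, n_rep, -1, -1).reshape(batch, -1, k_len, head_dim)
scores_old = query @ key_rep.transpose(-2, -1)

# New path: group the query heads and let matmul broadcast key over the group
# dimension; this code materializes no repeated key tensor (what the matmul
# backend does internally is up to the implementation).
q_grouped = query.view(batch, num_kv_heads, n_rep, q_len, head_dim)
scores_new = (q_grouped @ key[:, :, None].transpose(-2, -1)).reshape(batch, -1, q_len, k_len)

torch.testing.assert_close(scores_old, scores_new)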

mseeger mentioned this pull request on Mar 14, 2025
mseeger force-pushed the improve_attention branch from 1f4e47d to 509a3f4 on March 30, 2025 at 20:05
mseeger (Author) commented Mar 30, 2025

@ArthurZucker: now that #35926 has been merged, this is the first of two PRs which improve inference for the DeepSeek model, in the sense that MLA is really exploited, so that the KV cache takes much less space.

mseeger (Author) commented Mar 30, 2025

This PR also cleans up things a bit, in that the most common form of eager_attention_forward (which this PR improves) is imported from one place (like all the other attention implementations) instead of being copied around.

mseeger (Author) commented Mar 31, 2025

Would reformat: src/transformers/models/led/modeling_led.py
Would reformat: src/transformers/models/led/modeling_tf_led.py

I don't understand this failure. Locally this works for me; in fact, the formatter complains about these files on main as well and reformats them.

ArthurZucker (Collaborator) left a comment

Hey @mseeger ! Nice PR! Happy to improve deepseek's MLA!!

Though we need to keep the eager attention in each modeling file! This is because it's just our philosophy: simple basic usage should always be possible without having to import anything + should be explicit!
That's why we kept eager attn in all files!

Now let's isolate the changes to only update the deepseek eager attn + update the integration of flex, sdpa, etc. to use GQA.

One thing: in terms of memory, the repeat of kv should be a no-op, as it's just using the same memory pointers, no?

Thanks! 🤗

mseeger (Author) commented Mar 31, 2025

Hello, I can revert it, but I note that the other attention implementations are not copied to every file, just the eager one …

mseeger (Author) commented Mar 31, 2025

About the repeat of KV: I am not sure. If you look into it:

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

After the expand, I agree you just have a 0-stride view. But the final reshape creates a copy (see the comment below). And the copy carries no information about the redundancy, so matmul will not be able to use it. On the other hand, if the inputs to matmul require broadcasting, matmul sees the special structure and can use it.

Finally, the change really affects how SDPA from Torch works if you set enable_gqa=True. At least, it should. The docstring is oddly worded, as if this were not fully implemented everywhere. I hope it is, because it is really important for MLA to work out. A defensive sketch follows.
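
For illustration only, here is a hedged sketch of how the enable_gqa path could be guarded so that PyTorch builds whose SDPA does not accept the argument still work; the helper name and the fallback are assumptions, not the PR's code:

import torch.nn.functional as F

def sdpa_with_gqa(query, key, value, **kwargs):
    """Try SDPA with enable_gqa=True and fall back to repeating K/V heads."""
    try:
        return F.scaled_dot_product_attention(query, key, value, enable_gqa=True, **kwargs)
    except TypeError:
        # Older torch: the keyword does not exist, so expand the K/V heads
        # explicitly (equivalent to repeat_kv, see the function quoted above).
        n_rep = query.shape[1] // key.shape[1]
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
        return F.scaled_dot_product_attention(query, key, value, **kwargs)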

mseeger (Author) commented Mar 31, 2025

Maybe this is why the DeepSeek folks only publish low-level CUDA code for these things. I do hope the PyTorch developers take this seriously and make sure their SDPA properly supports enable_gqa=True everywhere.

mseeger (Author) commented Apr 1, 2025

@ArthurZucker: I can replace the imports of eager_attention_forward and copy the code to all models. But given the test I added (which confirms old and new do the same without GQA), I'd prefer to make this change in all models that copied the llama version of eager_attention_forward (some models use different versions; I am not touching those), instead of just for DeepSeek.

These other models support GQA as well, and I don't want to be responsible for yet another copy of eager_attention_forward that does mostly the same as the variants before, but is simpler and maybe faster. OK?

mseeger (Author) commented Apr 1, 2025

As for the other variants of attention living in integrations: I suppose you'd like to have the most basic one in each model file, but you need to draw the line somewhere.

I have to say I am not a fan of this (it makes contributions quite difficult), but I can understand it. Maybe some of your researcher customers are still using vim for writing their scripts rather than an IDE.

mseeger force-pushed the improve_attention branch from 509a3f4 to 4e7f364 on April 1, 2025 at 18:06
mseeger (Author) commented Apr 1, 2025

OK, copy and paste is restored.

mseeger force-pushed the improve_attention branch 3 times, most recently from d56942e to 833c4bf on April 6, 2025 at 07:36
mseeger force-pushed the improve_attention branch from 833c4bf to 5015fe7 on April 6, 2025 at 07:52
mseeger (Author) commented Apr 6, 2025

@ArthurZucker, here is some code which shows that repeat_kv creates a new tensor:

>>> import torch
>>> a = torch.tensor([1, 2, 3, 4, 5])
>>> b = a.view(1, -1).expand(5, -1)
>>> b.shape
torch.Size([5, 5])
>>> b.stride()
(0, 1)
>>> c = b.view(-1)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
>>> c = b.reshape(-1)  # Copy happens here
>>> c.stride()
(1,)

expand works by using a stride of 0. 0-strides are very cool; I only learned about them recently. But when a 0-stride dimension is fused with a normal dimension, the result can no longer be expressed with strides, and a copy is made. Not only that, but the redundancy in the tensor is no longer visible, so matmul will not be able to exploit it.

This should show that my PR makes sense. Using MLA or GQA without it (i.e., with repeat_kv) really makes no sense; you would not save anything.

mseeger (Author) commented Apr 16, 2025

Ping?

liangan1 commented

@mseeger As you mentioned, GQA has become critical for many LLM models. Would it be possible to split this PR into smaller ones, each enabling a different piece of functionality, for example one PR just for GQA in sdpa_attention?

mseeger (Author) commented Jun 12, 2025

Hello @liangan1, catering for GQA is what this PR is doing. Nothing more, really.

liangan1 commented

Hello @liangan1, catering for GQA is what this PR is doing. Nothing more, really.

Yes, thanks for the info, I know your point. What I want to say is that this PR seems too large, which is blocking the progress of merging it.

mseeger (Author) commented Jun 12, 2025

@liangan1, I can rebase this PR, pulling in recent changes. But in general, all non-flaky tests passed on this PR when I last worked on it. Of course, if such PRs stay open for months, things stop working.
