fix FlashAttentionKwargs RoPE #35941
Conversation
cc @Abhishek-TAMU @ArthurZucker because this is an update to #33932
@Abhishek-TAMU @ArthurZucker I removed the
cc @Cyrilvallez as well actually, since I think this relates to RoPE code you touched recently
Hey @garrett361! Very nice catch, this is indeed quite important! Here are a few thoughts/guidelines: First, as the function is a helper directly related to FA2, it should be moved to
Thanks!
Makes sense to me.
I hit this in the process of testing #35861, which makes related changes for padding-free training. See there for some related discussion, as well.
Agreed, it would be best if all of the needed data (both `position_ids` and the cu seqlens info) were provided. What do you suggest for next steps?
Please let me know if anything more is needed from my end!
Hey! Sorry, I agree with @Cyrilvallez and I think we should rather update/fix our data collator to make sure it passes the position ids and cu seqlens. We really don't want to add code that is specific to one integration path!
🤗 |
Ok cool @ArthurZucker, so: close this PR and adjust `DataCollatorWithFlattening` so that it returns the needed `position_ids`/cu seqlens?
Closing this: the intended padding-free code path is the one driven by `position_ids`. I plan to open a separate PR which sanity checks this and raises an informative error otherwise.
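For context, a minimal sketch of the kind of batch the collator-based path needs to produce: flattened `input_ids` plus `position_ids` that restart at each example boundary (the exact fields and defaults of `DataCollatorWithFlattening` may differ from this illustration):

```python
import torch

# Two examples packed into one padding-free "batch" of shape (1, total_len).
examples = [[101, 7, 8, 102], [101, 9, 102]]

input_ids = torch.tensor([tok for ex in examples for tok in ex]).unsqueeze(0)
# Positions restart at 0 for every example so RoPE sees per-sequence positions.
position_ids = torch.cat([torch.arange(len(ex)) for ex in examples]).unsqueeze(0)

print(input_ids)     # tensor([[101, 7, 8, 102, 101, 9, 102]])
print(position_ids)  # tensor([[0, 1, 2, 3, 0, 1, 2]])
```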
What does this PR do?
#33932 introduced `FlashAttentionKwargs` as an alternative to using `position_ids` for padding-free training. The RoPE positional embeddings are not currently applied correctly in the `FlashAttentionKwargs` code path. This PR ensures that RoPE is applied properly for this path.

Code Notes
The Issue
The issue is that if `position_ids` are not provided, then they are internally generated here:

transformers/src/transformers/models/llama/modeling_llama.py, lines 561 to 562 in ec7afad

and these are used to generate the RoPE embeddings here:

transformers/src/transformers/models/llama/modeling_llama.py, lines 570 to 571 in ec7afad

These RoPE embeddings are effectively computed from `torch.arange`-style positions, whereas they should be non-trivially generated from the values in `FlashAttentionKwargs`.
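To make the mismatch concrete, here is a minimal sketch contrasting the positions the fallback produces with the positions RoPE should see for packed sequences (the `cu_seq_lens_q` name is taken from `FlashAttentionKwargs`; the values are illustrative):

```python
import torch

# Cumulative sequence lengths describing two packed sequences of lengths 3 and 2.
cu_seq_lens_q = torch.tensor([0, 3, 5])
total_len = int(cu_seq_lens_q[-1])

# What the model falls back to when position_ids are absent: a plain arange.
fallback_positions = torch.arange(total_len)  # tensor([0, 1, 2, 3, 4])

# What RoPE should see: positions restarting at each boundary in cu_seq_lens_q.
seq_lens = cu_seq_lens_q.diff()               # tensor([3, 2])
packed_positions = torch.cat([torch.arange(int(n)) for n in seq_lens])
# tensor([0, 1, 2, 0, 1])
```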
The Fix
Introduce a `get_position_ids_from_cu_seq_lens` helper which converts from `FlashAttentionKwargs -> position_ids`, when provided.

Because many other models inherit from `LlamaDecoder`, this change propagates to many other models via `modular_model_converter.py`.
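A minimal sketch of what such a helper could look like (the actual implementation in this PR may differ in details such as device handling and output shape):

```python
import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    """Build position_ids that restart at 0 at each packed-sequence boundary.

    Example: cu_seq_lens = [0, 3, 5] -> position_ids = [[0, 1, 2, 0, 1]].
    """
    seq_lens = cu_seq_lens.diff()
    position_ids = torch.cat(
        [torch.arange(int(n), device=cu_seq_lens.device) for n in seq_lens]
    )
    return position_ids.unsqueeze(0)
```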
Tests
The solution is tested in `LlamaModelTest::test_attn_mask_position_ids_flash_attn_equality`, which checks that logits in the following cases are consistent with each other (a rough sketch of the check follows the list):

- `position_ids`
- `FlashAttentionKwargs`
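For reference, a minimal sketch of the kind of equality check involved; the names and tolerances are illustrative, and the `cu_seq_lens_*`/`max_length_*` kwargs are assumed to follow `FlashAttentionKwargs`:

```python
import torch

def assert_packed_paths_match(model, input_ids, position_ids, cu_seq_lens, atol=1e-3, rtol=1e-3):
    """Logits from the position_ids path and the FlashAttentionKwargs path should agree."""
    max_len = int(cu_seq_lens.diff().max())
    with torch.no_grad():
        logits_pos = model(input_ids=input_ids, position_ids=position_ids).logits
        logits_fa = model(
            input_ids=input_ids,
            cu_seq_lens_q=cu_seq_lens,
            cu_seq_lens_k=cu_seq_lens,
            max_length_q=max_len,
            max_length_k=max_len,
        ).logits
    torch.testing.assert_close(logits_pos, logits_fa, atol=atol, rtol=rtol)
```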
This test fails on latest `main` without the above fix.

Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.