
Conversation

garrett361
Contributor

@garrett361 garrett361 commented Jan 28, 2025

What does this PR do?

#33932 introduced FlashAttentionKwargs as an alternative to using position_ids for padding-free training. However, the RoPE positional embeddings are not currently applied correctly in the FlashAttentionKwargs code path. This PR ensures that RoPE is applied properly for this path.

Code Notes

The Issue

The issue is that if position_ids is not provided, they are internally generated here:

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

and these are used to generate the rope embeddings here:

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

These RoPE embeddings are then effectively a torch.arange over the packed length, whereas they should be derived non-trivially from the cumulative sequence lengths in FlashAttentionKwargs.
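
For concreteness, with two packed sequences of lengths 3 and 2 the difference looks like this (illustrative values only, not code from this PR):

# position_ids generated internally from cache_position: a plain arange over
# the packed length.
arange_position_ids = [[0, 1, 2, 3, 4]]

# What padding-free training needs, given cu_seq_lens = [0, 3, 5]: positions
# restart at every sequence boundary.
padding_free_position_ids = [[0, 1, 2, 0, 1]]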

The Fix

Introduce a get_position_ids_from_cu_seq_lens helper which converts from FlashAttentionKwargs to position_ids when the former are provided.
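
A minimal sketch of the conversion, assuming cu_seq_lens in the usual cumulative format (the helper in this PR may differ in its details):

import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # cu_seq_lens holds cumulative sequence lengths, e.g. [0, 3, 5] for two
    # packed sequences of lengths 3 and 2.
    seq_lens = cu_seq_lens[1:] - cu_seq_lens[:-1]
    # Positions restart at 0 for each packed sequence: [0, 1, 2, 0, 1].
    position_ids = torch.cat(
        [torch.arange(int(n), device=cu_seq_lens.device) for n in seq_lens]
    )
    return position_ids.unsqueeze(0)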

Because many other models inherit from LlamaDecoder, this change propagates to many other models via modular_model_converter.py.

Tests

The solution is tested in LlamaModelTest::test_attn_mask_position_ids_flash_attn_equality, which checks that the logits in the following cases are consistent with each other (a sketch of the three input formats is given below):

  • No padding-free, just padding and attention masks
  • Padding free via position_ids
  • Padding free via FlashAttentionKwargs

This test fails on latest main without the above fix.
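
For illustration, the three input formats for two sequences of lengths 3 and 2 could be built roughly as follows (a sketch with made-up token ids and the FlashAttentionKwargs key names as I understand them; the test's actual setup may differ):

import torch

seqs = [torch.tensor([10, 11, 12]), torch.tensor([20, 21])]

# 1) Padded batch plus attention mask.
input_ids = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
attention_mask = torch.nn.utils.rnn.pad_sequence(
    [torch.ones_like(s) for s in seqs], batch_first=True
)

# 2) Padding-free via position_ids: one packed row, positions restart per sequence.
packed_ids = torch.cat(seqs).unsqueeze(0)
position_ids = torch.cat([torch.arange(len(s)) for s in seqs]).unsqueeze(0)

# 3) Padding-free via FlashAttentionKwargs: cumulative sequence lengths.
cu_seq_lens = torch.tensor([0, 3, 5], dtype=torch.int32)
flash_attention_kwargs = dict(
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=3,
    max_length_k=3,
)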

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garrett361 garrett361 marked this pull request as draft January 28, 2025 17:35
@Rocketknight1
Member

cc @Abhishek-TAMU @ArthurZucker because this is an update to #33932

@garrett361 garrett361 marked this pull request as ready for review February 3, 2025 15:47
@garrett361
Contributor Author

@Abhishek-TAMU @ArthurZucker I removed the draft status and this work is ready for review. Please let me know if I can answer any questions about this PR. Thank you!

@Rocketknight1
Member

cc @Cyrilvallez as well actually, since I think this relates to RoPE code you touched recently

@Cyrilvallez
Member

Hey @garrett361! Very nice catch, this is indeed quite important! Here are a few thoughts/guidelines:

First, as the function is a helper directly related to FA2, it should be moved to modeling_flash_attention_utils.py.
Second, may I know a bit more about the setting in which you use this? I kind of feel that when using packed tensor format, it should be the responsibility of the user to feed correct inputs to the model. It would avoid one more code path in our modeling.
Depending on where/how you use this, I suspect you could call the function upstream and then feed the correct position_ids to the model.
Let me know what you think!

@garrett361 garrett361 mentioned this pull request Feb 6, 2025
@garrett361
Contributor Author

Very nice catch, this is indeed quite important!

Thanks!

First, as the function is a helper directly related to FA2, it should be moved to modeling_flash_attention_utils.py.

Makes sense to me.

Second, may I know a bit more about the setting in which you use this?

I hit this in the process of testing #35861, which makes related changes for padding-free training. See there for further discussion as well.

I kind of feel that when using packed tensor format, it should be the responsibility of the user to feed correct inputs to the model. It would avoid one more code path in our modeling.

Agree it would be best if all of the needed data (both position_ids and cu_seq_len_x, in the present case) were computed and provided at the outset by, say, the dataloader. But there seem to be a bunch of different code paths available at the moment, and I am concerned about silently incorrect results if we don't have these kinds of helpers in the modeling code.

Examples:

  1. HF's DataCollatorWithFlattening only returns position_ids (see the sketch after this list for deriving cu_seq_lens from these)
  2. trl's DataCollatorForCompletionOnlyLM does what we want and returns all of {position_ids, cu_seq_lens_q, ...}.
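
For reference, deriving cumulative sequence lengths from flattened position_ids could look roughly like this (a sketch; not code from this PR or from either collator):

import torch

def cu_seq_lens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: shape (1, total_tokens), restarting at 0 at each sequence
    # boundary, e.g. [[0, 1, 2, 0, 1]] for sequences of lengths 3 and 2.
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()
    lengths = torch.diff(starts, append=torch.tensor([pos.numel()], device=pos.device))
    # Prepend a leading zero to get cumulative lengths: [0, 3, 5].
    return torch.nn.functional.pad(lengths.cumsum(0), (1, 0)).to(torch.int32)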

What do you suggest for next steps?

@garrett361
Contributor Author

First, as the function is a helper directly related to FA2, it should be moved to modeling_flash_attention_utils.py.

@Cyrilvallez @Rocketknight1 I moved the helper as requested.

Please let me know if any more is needed from my end!

@ArthurZucker
Collaborator

Hey! Sorry, I agree with @Cyrilvallez and I think we should rather update/fix our data collator to make sure it passes the position ids and cu seqlens. We really don't want to add code that is specific to one integration path!

@ArthurZucker
Collaborator

🤗

@garrett361
Contributor Author

garrett361 commented Feb 14, 2025

Ok cool @ArthurZucker, so close this PR and adjust DataCollatorWithFlattening so that it returns {position_ids, cu_seq_lens_q, ...}?

@garrett361
Contributor Author

Closing this: the intended padding-free code path with FlashAttentionKwargs is that both the FlashAttentionKwargs and position_ids are provided to the model.

I plan to open a separate PR which sanity checks this and raises a ValueError if only FlashAttentionKwargs are provided, along with making the FlashAttentionKwargs explicit, properly typed args.
