
Conversation

@Kurt232 (Contributor) commented Aug 15, 2025

The helper _flash_attention_forward now falls back to the cumulative_seqlens_q/k keys that may arrive inside TransformersKwargs when the explicit cu_seq_lens_q/k arguments are absent.

A warning is raised on conflict, so users notice any override.

What does this PR do?

I noticed that current Transformers models pass TransformersKwargs through as extra kwargs, but the flash attention function rejects the cumulative_seqlens_q and cumulative_seqlens_k entries because the argument names do not match.

So I updated _flash_attention_forward to ensure compatibility with TransformersKwargs.

Fixes #40193
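For illustration, here is a minimal sketch of the fallback behaviour described above; this is not the code in this PR, and the helper name, warning text, and precedence rule are assumptions:

```python
import logging

logger = logging.getLogger(__name__)

# Minimal sketch (hypothetical helper, not the transformers implementation):
# prefer the explicit cu_seq_lens_q/k arguments, otherwise fall back to
# cumulative_seqlens_q/k arriving via **kwargs, warning if both disagree.
def _resolve_cu_seq_lens(cu_seq_lens_q=None, cu_seq_lens_k=None, **kwargs):
    fallback_q = kwargs.get("cumulative_seqlens_q")
    fallback_k = kwargs.get("cumulative_seqlens_k")

    if cu_seq_lens_q is None:
        cu_seq_lens_q = fallback_q
    elif fallback_q is not None and fallback_q is not cu_seq_lens_q:
        logger.warning("Both cu_seq_lens_q and cumulative_seqlens_q were passed; using cu_seq_lens_q.")

    if cu_seq_lens_k is None:
        cu_seq_lens_k = fallback_k
    elif fallback_k is not None and fallback_k is not cu_seq_lens_k:
        logger.warning("Both cu_seq_lens_k and cumulative_seqlens_k were passed; using cu_seq_lens_k.")

    return cu_seq_lens_q, cu_seq_lens_k
```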

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@Kurt232 Kurt232 changed the title fix to accept cumulative_seqlens from TransformersKwargs in FA #40193 fix to accept cumulative_seqlens from TransformersKwargs in FA Aug 15, 2025
@Cyrilvallez (Member)

cc @vasqu, can you check? We want to have one and only one name for the same objects everywhere, and avoid such warnings and reattributions of names. Probably we need to update the TransformersKwargs.

@Kurt232 (Contributor, Author) commented Aug 18, 2025

I think it is necessary to update TransformersKwargs; my fix is a temporary measure until all members of TransformersKwargs and FlashAttentionKwargs are updated.

@vasqu (Contributor) commented Aug 18, 2025

@Cyrilvallez We have

  • the FlashAttentionKwargs in modeling_flash_attention_utils
  • the TransformersKwargs in generic
  • the PagedAttentionArgs in continuous_batching (+ some dependencies there for other methods 👀)

Tbh, I'd be more in favor of unifying these than keeping this workaround 😅 especially since it's more of a typing issue than a functional issue (the dataclasses clash with what's really supposed to be passed as kwargs).
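For context, these classes only annotate **kwargs and are not instantiated at runtime, which is why a field/parameter name mismatch shows up as a typing/naming problem rather than a functional one. A minimal sketch of the pattern, with an illustrative class name and field set (not the actual transformers definitions):

```python
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class ExampleAttentionKwargs(TypedDict, total=False):
    # Describes what callers may pass as keyword arguments.
    cumulative_seqlens_q: Optional[torch.LongTensor]
    cumulative_seqlens_k: Optional[torch.LongTensor]


def attention_forward(hidden_states: torch.Tensor, **kwargs: Unpack[ExampleAttentionKwargs]) -> torch.Tensor:
    # If an inner function expects cu_seq_lens_q/k instead, the kwargs above are
    # silently dropped or rejected, even though the tensors themselves are fine.
    return hidden_states
```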

@Cyrilvallez (Member) commented Aug 20, 2025

Yes, we don't really want this kind of workaround that clutters the code - @Kurt232, do you want to fix the typings @vasqu mentioned instead? Otherwise we can do it 🤗
But TLDR, let's fix the root cause immediately, instead of first merging this and then reverting

@Cyrilvallez (Member)

Also, @vasqu, can you make sure the current typing would be BC? If you changed it in your refactor, it may need to be updated in flash_attention_forward instead of the Kwargs classes - otherwise it would plainly break BC for all downstream libs that used to pass them.

@vasqu (Contributor) commented Aug 20, 2025

Checked: it was introduced in #33932 and the signature of the function hasn't changed with regard to those kwargs. I.e. the dataclasses would need to be updated imo.

I could search for when the dataclass was changed, but I'm pretty sure the renamings happened somewhere along all of the FA/kwarg changes.

@Kurt232 (Contributor, Author) commented Aug 20, 2025

@Cyrilvallez I'm willing to contribute this; do I just need to create a new PR to fix it?
@vasqu I want to double check: do I just need to rename the kwargs dataclass fields to match the FA function signature, e.g. cumulative_seqlens_q/k -> cu_seq_lens_q/k?

@vasqu (Contributor) commented Aug 20, 2025

@Kurt232 Yes, that's correct

@Cyrilvallez (Member)

@Kurt232 Either use this PR and revert the previous changes, or open a new PR, whichever you want

cumulative_seqlens_q/k -> cu_seq_lens_q/k:
- in the FlashAttentionKwargs in modeling_flash_attention_utils
- in the TransformersKwargs in generic
- in the PagedAttentionArgs in continuous_batching
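
A minimal before/after sketch of the rename in question, field names only (the real classes carry more fields and this is not their actual definition):

```python
from typing import Optional, TypedDict

import torch


class FlashAttentionKwargsBefore(TypedDict, total=False):
    cumulative_seqlens_q: Optional[torch.LongTensor]
    cumulative_seqlens_k: Optional[torch.LongTensor]


class FlashAttentionKwargsAfter(TypedDict, total=False):
    # Renamed to match the cu_seq_lens_q/k parameters of _flash_attention_forward.
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
```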

@Kurt232 Kurt232 force-pushed the fix/args_in_flash_attention_forward branch from fedbf6d to dc0624d on August 21, 2025 08:17
@Kurt232 (Contributor, Author) commented Aug 21, 2025

PagedAttention was added in git@211f2b0 and is used by ContinuousBatchingManager in src/transformers/generation/continuous_batching.py.

At git@242bb2ca, in src/transformers/generation/continuous_batching.py (current):
I checked that it is BC after renaming cumulative_seqlens_q/k -> cu_seq_lens_q/k,
because they are created in ContinuousBatchProcessor.setup_static_tensors:L762, used in ContinuousBatchingManager._model_forward:L1233, and destroyed with ContinuousBatchProcessor.

Please review @vasqu. Thx😄

@vasqu (Contributor) left a comment


Can you add a 🚨 to the title? This is definitely breaking (API-wise) for continuous batching (CB).

I'm not sure if CB is user-facing in this case, so it might not be too bad, but I would check with @ArthurZucker.

@Kurt232 Kurt232 changed the title fix to accept cumulative_seqlens from TransformersKwargs in FA 🚨 fix to accept cumulative_seqlens from TransformersKwargs in FA Aug 21, 2025
unused function arg in `PagedAttentionCache.update`

Co-authored-by: Anton Vlasjuk <[email protected]>
Comment on lines +38 to 39
cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
Collaborator left a comment:

My main issue with this naming is that it is not helpful for newbies; cu does not mean anything!
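
For readers hitting the same confusion: "cu" is short for "cumulative". A small illustrative example (not code from this PR) of how such a tensor relates to per-sequence lengths:

```python
import torch

# Per-sequence token counts for a packed/varlen batch of 3 sequences.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Cumulative sequence lengths: shape (batch_size + 1,), starting at 0; each
# adjacent pair delimits one sequence's slice inside the flattened q/k tensors.
cu_seq_lens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seq_lens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```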

@ArthurZucker (Collaborator)

Don't worry, let's just revert for continuous batching for now, the rest is fine!

@vasqu (Contributor) commented Aug 22, 2025

(we can remove the 🚨 as well then - only CB might've been breaking)

k, v = cache.update(k, v, module.layer_idx, **kwargs)

sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
if implementation is not None:
@Kurt232 (Contributor, Author) commented Aug 22, 2025


I think paged_attention_forward should use cu_seq_lens_q/k instead of cumulative_seqlens_q/k to keep consistency with flash_attn_varlen_func.

https://github.com/huggingface/transformers/blob/29ddcacea3ad9d3cdf6c5d8e51d1d39cbc5e7dfa/src/transformers/modeling_flash_attention_utils.py#L557C1-L578C3
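
For reference, a hedged sketch of how the varlen kernel is typically invoked. The argument names follow the flash-attn package as I recall them and should be verified against the linked wrapper; this is not the transformers code itself:

```python
# Illustrative only: requires the flash-attn package and CUDA tensors packed as
# (total_tokens, num_heads, head_dim); the cumulative sequence lengths delimit
# each sequence inside the packed token dimension.
from flash_attn import flash_attn_varlen_func


def varlen_attention(q, k, v, cu_seq_lens_q, cu_seq_lens_k, max_seqlen_q, max_seqlen_k):
    return flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seq_lens_q,  # note: flash-attn itself spells these cu_seqlens_*
        cu_seqlens_k=cu_seq_lens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        causal=True,
    )
```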

@Kurt232 Kurt232 changed the title 🚨 fix to accept cumulative_seqlens from TransformersKwargs in FA fix to accept cumulative_seqlens from TransformersKwargs in FA Aug 22, 2025
@ArthurZucker (Collaborator) left a comment


Okay! I don't want to fight over this, it's a small nit!

@ArthurZucker ArthurZucker merged commit 14b89fe into huggingface:main Aug 25, 2025
20 of 22 checks passed
@Kurt232 Kurt232 deleted the fix/args_in_flash_attention_forward branch August 25, 2025 09:02


Development

Successfully merging this pull request may close these issues.

_flash_attention_forward can't receive cumulative_seqlens_q and cumulative_seqlens_k in TransformersKwargs from Inputs in forward
