fix to accept cumulative_seqlens from TransformersKwargs in FA #40194
Conversation
cc @vasqu, can you check? We want to have one and only one name for the same objects everywhere, and avoid such warnings and reattributions of names. Probably we need to update the `TransformersKwargs`.
I think it is necessary to update
@Cyrilvallez We have
Tbh, I'd be more pro making it more unified than having this workaround 😅 especially since it's more of a typing issue than a functional issue (dataclasses clash with what's really supposed to be passed as kwargs)
Also, @vasqu can you make sure the current typing would be BC? If you changed it in your refactor, it may need to be updated in
Checked, it was introduced in #33932 and the signature of the function hasn't changed with regard to those kwargs, i.e. the dataclasses would need to be updated imo. Could search when the dataclass was changed, but pretty sure the renamings happened sometime during all of the fa/kwarg changes.
@Cyrilvallez I'm willing to contribute this, so do I just need to create a new PR to fix it?
@Kurt232 Yes, that's correct
@Kurt232 Either use this PR and revert the previous changes, or open a new PR, whichever you want
`cumulative_seqlens_q/k` -> `cu_seq_lens_q/k`:
- in the `FlashAttentionKwargs` in `modeling_flash_attention_utils`
- in the `TransformersKwargs` in `generic`
- in the `PagedAttentionArgs` in `continuous_batching`

It is **BC**, because they are created in `ContinuousBatchProcessor.setup_static_tensors:L762`, used in `ContinuousBatchingManager._model_forward:L1233`, and destroyed with `ContinuousBatchProcessor`.
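For reference, a rough sketch of what the unified kwargs look like after the rename (this shows the shape of the TypedDict, not the exact definitions in the library):

```python
from typing import Optional, TypedDict

import torch


class FlashAttentionKwargs(TypedDict, total=False):
    """Sketch of the variable-length FA kwargs after the rename: one canonical name everywhere."""

    cu_seq_lens_q: Optional[torch.LongTensor]  # was cumulative_seqlens_q
    cu_seq_lens_k: Optional[torch.LongTensor]  # was cumulative_seqlens_k
    max_length_q: Optional[int]
    max_length_k: Optional[int]
```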
Force-pushed from fedbf6d to dc0624d.
At 242bb2ca, please review @vasqu. Thx 😄
Can you add a 🚨 to the title? This is definitely breaking (API-wise) for continuous batching (CB).
I'm not sure if CB is user-facing in this case, so it might not be too bad, but I would check with @ArthurZucker.
unused function arg in `PagedAttentionCache.update`
Co-authored-by: Anton Vlasjuk <[email protected]>
cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q.
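As a small worked example of that indexing (shapes and values below are illustrative only):

```python
import torch

# Three packed sequences of lengths 2, 4 and 3 -> cumulative boundaries [0, 2, 6, 9].
cu_seq_lens_q = torch.tensor([0, 2, 6, 9], dtype=torch.int32)

# Packed queries: (total_tokens, num_heads, head_dim).
q = torch.randn(9, 8, 64)

# Tokens of sequence i live in q[cu_seq_lens_q[i] : cu_seq_lens_q[i + 1]].
second_sequence = q[cu_seq_lens_q[1] : cu_seq_lens_q[2]]  # shape (4, 8, 64)
```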
my main issue with this naming is that it is not helpful for newbies, `cu` does not mean anything!
Don't worry, let's just revert for continuous batching for now, the rest is fine!
(we can remove the 🚨 as well then - only CB might've been breaking)
    k, v = cache.update(k, v, module.layer_idx, **kwargs)

    sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
    if implementation is not None:
I think `paged_attention_forward` should use `cu_seq_lens_q/k` instead of `cumulative_seqlens_q/k` to keep coherency with `flash_attn_varlen_func`.
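For context, `flash_attn_varlen_func` itself receives the cumulative lengths as `cu_seqlens_q`/`cu_seqlens_k`, so a varlen wrapper would forward the unified kwargs roughly as below (a sketch assuming the flash-attn 2 API, not the actual `paged_attention_forward` implementation):

```python
# Sketch: forwarding the unified kwargs to flash-attn's varlen kernel.
# Assumes flash-attn 2.x is installed; the keyword names below follow its public API.
from flash_attn import flash_attn_varlen_func


def varlen_attention(q, k, v, cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, causal=True):
    # cu_seq_lens_* are (batch_size + 1,) int32 tensors of cumulative sequence lengths.
    return flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seq_lens_q,
        cu_seqlens_k=cu_seq_lens_k,
        max_seqlen_q=max_length_q,
        max_seqlen_k=max_length_k,
        causal=causal,
    )
```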
Okay! I don't want to fight over this, it's a small nit!
The helper `_flash_attention_forward` now falls back to the `cumulative_seqlens_q/k` keys that may arrive inside `TransformersKwargs` when the explicit `cu_seq_lens_q/k` arguments are absent.
A warning is raised on conflict, ensuring users notice any override.
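A minimal sketch of that fallback, with a hypothetical helper name (the real logic lives inside `_flash_attention_forward`):

```python
import logging

logger = logging.getLogger(__name__)


def _resolve_cu_seq_lens(cu_seq_lens_q=None, cu_seq_lens_k=None, **kwargs):
    """Sketch: prefer the explicit cu_seq_lens_* args, fall back to legacy cumulative_seqlens_* kwargs."""
    legacy_q = kwargs.pop("cumulative_seqlens_q", None)
    legacy_k = kwargs.pop("cumulative_seqlens_k", None)

    # Warn on conflict so the user notices which value wins.
    if cu_seq_lens_q is not None and legacy_q is not None and cu_seq_lens_q is not legacy_q:
        logger.warning("Both cu_seq_lens_q and cumulative_seqlens_q were passed; using cu_seq_lens_q.")
    if cu_seq_lens_k is not None and legacy_k is not None and cu_seq_lens_k is not legacy_k:
        logger.warning("Both cu_seq_lens_k and cumulative_seqlens_k were passed; using cu_seq_lens_k.")

    cu_seq_lens_q = cu_seq_lens_q if cu_seq_lens_q is not None else legacy_q
    cu_seq_lens_k = cu_seq_lens_k if cu_seq_lens_k is not None else legacy_k
    return cu_seq_lens_q, cu_seq_lens_k
```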
What does this PR do?
I noticed that current transformers passes `TransformersKwargs` as extra kwargs, but the flash attention function rejects the `cumulative_seqlens_q` and `cumulative_seqlens_k` args as unmatched argument names.
So I updated `_flash_attention_forward` to ensure compatibility with `TransformersKwargs`.

Fixes #40193
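To illustrate the mismatch with a deliberately simplified stand-in (not the real signature):

```python
import torch


def _fa_forward_sketch(q, k, v, cu_seq_lens_q=None, cu_seq_lens_k=None):
    """Simplified stand-in for a varlen FA forward that only knows the cu_seq_lens_* names."""
    return q  # attention math omitted


q = k = v = torch.randn(9, 8, 64)
kwargs = {
    "cumulative_seqlens_q": torch.tensor([0, 9], dtype=torch.int32),
    "cumulative_seqlens_k": torch.tensor([0, 9], dtype=torch.int32),
}

try:
    _fa_forward_sketch(q, k, v, **kwargs)
except TypeError as err:
    # TypeError: _fa_forward_sketch() got an unexpected keyword argument 'cumulative_seqlens_q'
    print(err)
```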
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.