feat: implementation varlen-flash-attention on cpu #3250
Conversation
@@ -0,0 +1,2235 @@
use candle::Result;
Duplicate of the test above; let me know how to avoid the duplication.
Ideally, I want each test to run for all of unfused, fp32 and fp16.
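One possible way to avoid the duplication is a small macro that stamps out one `#[test]` per variant from a shared body; a rough sketch (the `Variant` enum and `run_varlen_case` helper are placeholders, not code from this PR):

```rust
use candle::Result;

/// Hypothetical: which attention path a test case should exercise.
#[derive(Clone, Copy)]
enum Variant {
    Unfused,
    FlashF32,
    FlashF16,
}

/// Hypothetical shared test body: build inputs, run the chosen path, compare.
fn run_varlen_case(_variant: Variant) -> Result<()> {
    // ...shared setup and assertions would go here...
    Ok(())
}

// Generate one #[test] per variant from the shared body.
macro_rules! varlen_tests {
    ($($name:ident => $variant:expr),* $(,)?) => {
        $(
            #[test]
            fn $name() -> Result<()> {
                run_varlen_case($variant)
            }
        )*
    };
}

varlen_tests! {
    varlen_unfused => Variant::Unfused,
    varlen_flash_f32 => Variant::FlashF32,
    varlen_flash_f16 => Variant::FlashF16,
}
```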
use candle::Result;
use candle_nn::varlen_attention::flash_attn_varlen_unfused;

const FA_FEATURE_ENABLED: bool = false; // flash-attn features not available in this workspace
We can't have the FA2 kernels in this repo; if we did, I could run the tests below.
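A common alternative to a hard-coded constant is to gate those tests on a cargo feature so they only compile where the kernels exist; a minimal sketch, assuming a `flash-attn` feature name:

```rust
// Sketch: only build the CUDA flash-attn comparison test when the feature
// is enabled in the workspace; otherwise the test simply does not exist.
#[cfg(feature = "flash-attn")]
#[test]
fn varlen_cpu_matches_flash_attn_cuda() -> candle::Result<()> {
    // ...run the CPU varlen path and compare against the CUDA kernel here...
    Ok(())
}
```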
@EricLBuehler we collaborated on the flash-attn-3 kernels. Curious if you have high-level feedback here.
Hey Michael! Thanks for the PR :) We now have two different (both valuable and appreciated) CPU flash attention PRs open (ref), so I was wondering if I could ask you and @DrJesseGlass to cooperate a bit and figure out the best way forward 🙇 Other than trying to avoid large merge conflicts, I'm also thinking about getting a somewhat unified interface (though the underlying implementations can be completely different).
ivarflakstad left a comment
Looks great!
Haven’t had time to test it myself yet, but here is a preliminary review 🫡
}

match dt {
    DType::F32 => {
Perhaps easier to read if we dispatch to explicit f32/f16 functions?
At first I was considering a generic fn but we should just keep things simple :)
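Something like this, perhaps (a sketch with made-up helper names, not the PR's internals):

```rust
use candle::{DType, Result};

// Sketch: dispatch once on the dtype, then run a concrete (non-generic) kernel.
// `flash_attn_varlen_cpu_f32` / `_f16` are hypothetical names for this example.
fn flash_attn_varlen_cpu(dt: DType) -> Result<()> {
    match dt {
        DType::F32 => flash_attn_varlen_cpu_f32(),
        DType::F16 => flash_attn_varlen_cpu_f16(),
        _ => candle::bail!("flash-attn varlen cpu: unsupported dtype {dt:?}"),
    }
}

fn flash_attn_varlen_cpu_f32() -> Result<()> {
    // ...f32 kernel body...
    Ok(())
}

fn flash_attn_varlen_cpu_f16() -> Result<()> {
    // ...f16 kernel body...
    Ok(())
}
```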
let q_base = (q_idx * hq + h) * d;
let q_row = &q_data[q_base..q_base + d];

// online softmax
I don't think our current CPU softmax is even online. It should be. Same with CUDA.
At least our Metal softmax is.
wdym? Are you suggesting that `w = (score - m).exp()` is not correct, or that you think it would be helpful to have an online-softmax implementation for the CPU path in candle?
I think the above flash-attention CPU kernel (for the non-varlen case) also uses fused softmax to get memory-efficient attention.
or that you think it would be helpful to have a online-softmax implementation for the cpu path in candle?
This ☝️
#3269 I added a PR here, for visibility. Not yet ready. The current implementation in candle is indeed not online-softmax, but it's also still quite efficient on CPU; I'm trying to figure out how to best measure the problem with the current one before marking it as ready.
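For reference, the online-softmax trick keeps a running max and running denominator so the scores are consumed in a single pass; a minimal single-row sketch in plain Rust (not the kernel in this PR):

```rust
/// Online softmax-weighted accumulation over one query row.
/// `scores[i]` is the (already scaled) dot product with key i,
/// `v_rows[i]` is the corresponding value row of length `d`.
fn attend_row_online(scores: &[f32], v_rows: &[&[f32]], d: usize) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running max of the scores
    let mut l = 0f32; // running softmax denominator
    let mut acc = vec![0f32; d]; // running weighted sum of value rows
    for (&s, v) in scores.iter().zip(v_rows) {
        let m_new = m.max(s);
        let correction = (m - m_new).exp(); // rescales the previous state
        let w = (s - m_new).exp();
        l = l * correction + w;
        for (a, &vj) in acc.iter_mut().zip(v.iter()) {
            *a = *a * correction + w * vj;
        }
        m = m_new;
    }
    acc.iter().map(|a| a / l).collect()
}
```

The rescaling by `correction` is what lets the running max be updated mid-stream without revisiting earlier keys, so no separate softmax pass over the full score row is needed.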
@ivarflakstad I checked out the other flash attention PR. My goal is to make the modeling code for TEI faster; usually people submit a "padded" and a "varlen" implementation. I hope that by providing a varlen primitive on CPU and Metal, we can avoid having the padded implementation altogether. In flash-attn, there are two interfaces (sdpa-like, as with @DrJesseGlass) and the varlen one. I agree they should somehow live in the same files, which would lead to conflicts. Would be great to know how to refactor the public API / which files to move them to.
Well perhaps sdpa / varlen is a decent separation?
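For illustration, the split could look roughly like this (hypothetical signatures, not what either PR actually exposes; the varlen shape convention follows the flash-attn style):

```rust
use candle::{Result, Tensor};

/// Dense / padded path: q, k, v shaped (batch, heads, seq_len, head_dim).
pub fn sdpa_cpu(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, causal: bool) -> Result<Tensor> {
    // ...fixed-length implementation...
    todo!()
}

/// Varlen path: q, k, v packed as (total_tokens, heads, head_dim), with
/// cumulative sequence offsets marking where each sequence starts and ends.
pub fn flash_attn_varlen_cpu(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    cu_seqlens_q: &Tensor,
    cu_seqlens_k: &Tensor,
    max_seqlen_q: usize,
    max_seqlen_k: usize,
    scale: f32,
    causal: bool,
) -> Result<Tensor> {
    // ...padding-free implementation...
    todo!()
}
```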
I'm happy to integrate this once it is merged into a revised version of my PR (#3254). varlen attention would fit neatly into my proposed structure: my broader goal is to minimize standalone scripts and establish a clean attention module within candle-nn, with an API matching GPU flash-attn so CPU and GPU paths are interchangeable in transformers.

Although my causal loop-bound implementation proves the value of specialized kernels (~3.5% speedup, ~14% memory reduction), it's an early version focused on API and layout agreement. Further optimizations are planned.

Our work is complementary: while varlen can handle fixed-length as a special case, it's likely not as optimizable as a dedicated path. I'm thinking my causal implementation should specialize for B=1 (the common interactive case) while varlen handles batched workloads. This gives both paths room to optimize for their target use case. The loop-bound implementation here may also inform further optimizations on my side.
@ivarflakstad Happy new year! Any chance we can get this merged? Would love to get it into some kind of future text-embeddings-inference version.
I would like native support for a CUDA-free fallback for variable-length flash attention, so that modeling code can be written for Metal, CPU and CUDA without requiring the kernel.
Actual motivation: I want to get rid of the non-varlen formulation of the models (https://github.com/huggingface/text-embeddings-inference/tree/main/backends/candle/src/models). Their main issue is that you need a varlen flash-attention primitive; for that I need Metal and CPU support.
The implementation is fully tested against what I have to use on CUDA: https://github.com/Dao-AILab/flash-attention/blob/ac9b5f107f2f19cd0ca6e01548d20d072a46335c/csrc/flash_attn/flash_api.cpp#L515
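For readers unfamiliar with the varlen layout: sequences are packed back-to-back along the token dimension and batch boundaries are described by cumulative sequence lengths; a tiny illustration (not code from this PR):

```rust
fn main() {
    // Three sequences of lengths 3, 5 and 2 packed for a varlen call.
    // Tokens are concatenated along one axis; cu_seqlens marks the boundaries.
    let seq_lens = [3u32, 5, 2];
    let mut cu_seqlens = vec![0u32];
    for len in seq_lens {
        cu_seqlens.push(cu_seqlens.last().unwrap() + len);
    }
    assert_eq!(cu_seqlens, vec![0, 3, 8, 10]);
    // q/k/v are then laid out as (total_tokens = 10, num_heads, head_dim),
    // and sequence i occupies rows cu_seqlens[i]..cu_seqlens[i + 1].
}
```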
The unfused impl uses no fused ("flash") attention, but it still allows for padding-free variable-length attention on Metal and CUDA.
For CPU, I implemented a "flash" version in a similar fashion; that one is, however, more or less auto-generated and peer-reviewed by me.
I previously wrote the candle flash attn 3 bindings: https://github.com/michaelfeil/candle-flash-attn-v3.
I wrote up some benches and compared my flash-attn-cpu implementation against the unfused varlen formulation.
F16 Performance (milliseconds)