
Conversation

@michaelfeil
Contributor

@michaelfeil michaelfeil commented Dec 18, 2025

I would like native support for a CUDA-free fallback for variable-length flash attention, so that modeling code can be written once for Metal, CPU, and CUDA without requiring the flash-attention kernel.

Actual motivation: I want to get rid of the non-varlen formulation of the models in https://github.com/huggingface/text-embeddings-inference/tree/main/backends/candle/src/models. Their main issue is that you need a varlen flash-attention primitive, and for that I need Metal and CPU support.

The implementation is fully tested against what I have to use on CUDA: https://github.com/Dao-AILab/flash-attention/blob/ac9b5f107f2f19cd0ca6e01548d20d072a46335c/csrc/flash_attn/flash_api.cpp#L515

The unfused implementation uses no fused "flash" attention, but it still allows padding-free variable-length attention on Metal and CUDA.
For CPU, I implemented a "flash" fused version in a similar fashion; that one is, however, more or less auto-generated and peer-reviewed by me.
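
For illustration, here is a minimal sketch of the padding-free (varlen) layout this targets: sequences of different lengths are packed into one `(total_tokens, num_heads, head_dim)` tensor, and the kernel consumes cumulative sequence-length offsets (`cu_seqlens`) instead of a padded `(batch, seq, heads, dim)` tensor. The helper name and exact signature below are illustrative, not the final API:

```rust
use candle::{Device, Result, Tensor};

// Illustrative only: pack sequences of different lengths into a single
// (total_tokens, num_heads, head_dim) tensor and build the cumulative
// sequence-length offsets that a varlen attention primitive consumes.
fn pack_varlen(
    seq_lens: &[usize],
    num_heads: usize,
    head_dim: usize,
    dev: &Device,
) -> Result<(Tensor, Tensor)> {
    let total: usize = seq_lens.iter().sum();
    // In real modeling code q/k/v come from the model; random data here.
    let q = Tensor::randn(0f32, 1f32, (total, num_heads, head_dim), dev)?;
    // cu_seqlens = [0, len_0, len_0 + len_1, ...] as u32 offsets.
    let mut cu = Vec::with_capacity(seq_lens.len() + 1);
    let mut acc = 0u32;
    cu.push(acc);
    for &len in seq_lens {
        acc += len as u32;
        cu.push(acc);
    }
    let cu_seqlens = Tensor::from_vec(cu, seq_lens.len() + 1, dev)?;
    Ok((q, cu_seqlens))
}
```

A varlen kernel then takes q/k/v in this packed layout plus `cu_seqlens_q`/`cu_seqlens_k` and the maximum sequence lengths, which is what makes the padded formulation unnecessary.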

I previously wrote the candle flash attn 3 bindings: https://github.com/michaelfeil/candle-flash-attn-v3.

I wrote up some benches and compared my flash-attn-cpu against the varlen formulation.

F16 Performance (milliseconds)

| Shape | Unfused | Flash | Varlen Flash | Flash vs Unfused | Varlen vs Flash | Varlen vs Unfused |
| --- | --- | --- | --- | --- | --- | --- |
| b4_hq8_hk8_d64_noncausal | 35.11ms | 16.15ms | 1.63ms | 2.17x faster | 9.91x faster | 21.54x faster |
| b4_hq8_hk8_d64_causal | 35.11ms* | ~16.15ms* | ~1.63ms* | 2.17x faster | 9.91x faster | 21.54x faster |
| b3_hq12_hk4_d64_gqa_noncausal | 39.54ms | 18.40ms | 1.93ms | 2.15x faster | 9.53x faster | 20.49x faster |
| b2_hq8_hk8_d64_alibi_causal | ~35.11ms* | ~16.15ms* | ~1.63ms* | 2.17x faster | 9.91x faster | 21.54x faster |

F32 Performance (milliseconds)

| Shape | Unfused | Flash | Varlen Flash | Flash vs Unfused | Varlen vs Flash | Varlen vs Unfused |
| --- | --- | --- | --- | --- | --- | --- |
| b4_hq8_hk8_d64_noncausal | 37.54ms | 10.15ms | 1.66ms | 3.70x faster | 6.12x faster | 22.61x faster |
| b4_hq8_hk8_d64_causal | 37.54ms* | ~10.15ms* | ~1.66ms* | 3.70x faster | 6.12x faster | 22.61x faster |
| b3_hq12_hk4_d64_gqa_noncausal | 39.17ms | 11.00ms | 1.73ms | 3.56x faster | 6.36x faster | 22.63x faster |
| b2_hq8_hk8_d64_alibi_causal | ~37.54ms* | ~10.15ms* | ~1.66ms* | 3.70x faster | 6.12x faster | 22.61x faster |
Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_noncausal/fast: Collecting 100 samples in estima
cpu_varlen_prefill_F16_b4_hq8_hk8_d64_noncausal/fast
                        time:   [1.7158 ms 1.7538 ms 1.7959 ms]
Found 6 outliers among 100 measurements (6.00%)
  4 (4.00%) high mild
  2 (2.00%) high severe
Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_noncausal/unfused: Collecting 100 samples in est
cpu_varlen_prefill_F16_b4_hq8_hk8_d64_noncausal/unfused
                        time:   [37.210 ms 37.677 ms 38.173 ms]
                        change: [−3.0355% −1.1008% +0.7660%] (p = 0.26 > 0.05)
                        No change in performance detected.
Found 6 outliers among 100 measurements (6.00%)
  5 (5.00%) high mild
  1 (1.00%) high severe

Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal/fast: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 6.5s, enable flat sampling, or reduce sample count to 60.
Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal/fast: Collecting 100 samples in estimated
cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal/fast
                        time:   [1.2776 ms 1.2971 ms 1.3181 ms]
Found 2 outliers among 100 measurements (2.00%)
  2 (2.00%) high mild
Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal/unfused: Collecting 100 samples in estima
cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal/unfused
                        time:   [40.374 ms 40.861 ms 41.377 ms]
                        change: [−3.7858% −1.9700% −0.1688%] (p = 0.04 < 0.05)
                        Change within noise threshold.
Found 6 outliers among 100 measurements (6.00%)
  5 (5.00%) high mild
  1 (1.00%) high severe

Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal_wl128/fast: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 6.7s, enable flat sampling, or reduce sample count to 60.
Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal_wl128/fast: Collecting 100 samples in est
cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal_wl128/fast
                        time:   [1.2898 ms 1.3063 ms 1.3235 ms]
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
Benchmarking cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal_wl128/unfused: Collecting 100 samples in 
cpu_varlen_prefill_F16_b4_hq8_hk8_d64_causal_wl128/unfused
                        time:   [42.432 ms 42.887 ms 43.361 ms]
                        change: [−6.2672% −4.6945% −3.0936%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 1 outliers among 100 measurements (1.00%)
  1 (1.00%) high mild

Benchmarking cpu_varlen_prefill_F16_b3_hq12_hk4_d64_gqa_noncausal/fast: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 9.1s, enable flat sampling, or reduce sample count to 50.
Benchmarking cpu_varlen_prefill_F16_b3_hq12_hk4_d64_gqa_noncausal/fast: Collecting 100 samples in e
cpu_varlen_prefill_F16_b3_hq12_hk4_d64_gqa_noncausal/fast
                        time:   [1.8276 ms 1.8551 ms 1.8850 ms]
Found 2 outliers among 100 measurements (2.00%)
  2 (2.00%) high mild
Benchmarking cpu_varlen_prefill_F16_b3_hq12_hk4_d64_gqa_noncausal/unfused: Collecting 100 samples i
cpu_varlen_prefill_F16_b3_hq12_hk4_d64_gqa_noncausal/unfused
                        time:   [47.898 ms 48.537 ms 49.250 ms]
                        change: [+4.4476% +6.9434% +9.2754%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 7 outliers among 100 measurements (7.00%)
  3 (3.00%) high mild
  4 (4.00%) high severe

Benchmarking cpu_varlen_prefill_F16_b2_hq8_hk8_d64_alibi_causal/fast: Collecting 100 samples in est
cpu_varlen_prefill_F16_b2_hq8_hk8_d64_alibi_causal/fast
                        time:   [835.83 µs 848.73 µs 863.39 µs]
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
Benchmarking cpu_varlen_prefill_F16_b2_hq8_hk8_d64_alibi_causal/unfused: Collecting 100 samples in 
cpu_varlen_prefill_F16_b2_hq8_hk8_d64_alibi_causal/unfused
                        time:   [15.789 ms 15.993 ms 16.206 ms]
                        change: [−10.165% −8.3876% −6.4581%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  5 (5.00%) high mild

Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_noncausal/fast: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 7.8s, enable flat sampling, or reduce sample count to 50.
Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_noncausal/fast: Collecting 100 samples in estima
cpu_varlen_prefill_F32_b4_hq8_hk8_d64_noncausal/fast
                        time:   [1.5073 ms 1.5342 ms 1.5632 ms]
                        change: [−18.364% −16.292% −14.233%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  5 (5.00%) high mild
Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_noncausal/unfused: Collecting 100 samples in est
cpu_varlen_prefill_F32_b4_hq8_hk8_d64_noncausal/unfused
                        time:   [37.375 ms 37.812 ms 38.289 ms]
                        change: [+22.062% +24.144% +26.255%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 5 outliers among 100 measurements (5.00%)
  2 (2.00%) high mild
  3 (3.00%) high severe

Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal/fast: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 6.4s, enable flat sampling, or reduce sample count to 60.
Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal/fast: Collecting 100 samples in estimated
cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal/fast
                        time:   [1.2896 ms 1.3091 ms 1.3293 ms]
                        change: [−19.024% −16.568% −14.192%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
  7 (7.00%) high mild
Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal/unfused: Collecting 100 samples in estima
cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal/unfused
                        time:   [79.268 ms 83.888 ms 88.785 ms]
                        change: [+136.62% +151.95% +166.59%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 3 outliers among 100 measurements (3.00%)
  3 (3.00%) high mild

Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal_wl128/fast: Collecting 100 samples in est
cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal_wl128/fast
                        time:   [2.6184 ms 2.9318 ms 3.2694 ms]
                        change: [+68.057% +87.891% +110.23%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 9 outliers among 100 measurements (9.00%)
  8 (8.00%) high mild
  1 (1.00%) high severe
Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal_wl128/unfused: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 6.3s, or reduce sample count to 70.
Benchmarking cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal_wl128/unfused: Collecting 100 samples in 
cpu_varlen_prefill_F32_b4_hq8_hk8_d64_causal_wl128/unfused
                        time:   [50.712 ms 52.088 ms 53.498 ms]
                        change: [+41.970% +46.059% +50.326%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 1 outliers among 100 measurements (1.00%)
  1 (1.00%) high mild

Benchmarking cpu_varlen_prefill_F32_b3_hq12_hk4_d64_gqa_noncausal/fast: Collecting 100 samples in e
cpu_varlen_prefill_F32_b3_hq12_hk4_d64_gqa_noncausal/fast
                        time:   [2.0540 ms 2.1434 ms 2.2444 ms]
                        change: [−3.3653% +0.9261% +5.7661%] (p = 0.71 > 0.05)
                        No change in performance detected.
Found 8 outliers among 100 measurements (8.00%)
  6 (6.00%) high mild
  2 (2.00%) high severe
Benchmarking cpu_varlen_prefill_F32_b3_hq12_hk4_d64_gqa_noncausal/unfused: Warming up for 3.0000 s
Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 5.0s, or reduce sample count to 90.
Benchmarking cpu_varlen_prefill_F32_b3_hq12_hk4_d64_gqa_noncausal/unfused: Collecting 100 samples i
cpu_varlen_prefill_F32_b3_hq12_hk4_d64_gqa_noncausal/unfused
                        time:   [48.690 ms 49.329 ms 50.066 ms]
                        change: [+17.648% +19.585% +21.727%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 3 outliers among 100 measurements (3.00%)
  2 (2.00%) high mild
  1 (1.00%) high severe

Benchmarking cpu_varlen_prefill_F32_b2_hq8_hk8_d64_alibi_causal/fast: Collecting 100 samples in est
cpu_varlen_prefill_F32_b2_hq8_hk8_d64_alibi_causal/fast
                        time:   [842.14 µs 851.89 µs 862.51 µs]
                        change: [−18.364% −16.554% −14.593%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
  2 (2.00%) low mild
  4 (4.00%) high mild
  5 (5.00%) high severe
Benchmarking cpu_varlen_prefill_F32_b2_hq8_hk8_d64_alibi_causal/unfused: Collecting 100 samples in 
cpu_varlen_prefill_F32_b2_hq8_hk8_d64_alibi_causal/unfused
                        time:   [16.212 ms 16.409 ms 16.614 ms]
                        change: [+19.997% +22.608% +25.066%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 4 outliers among 100 measurements (4.00%)
  4 (4.00%) high mild

@michaelfeil michaelfeil changed the title Mf/cpu-varlen-flash-attention feat: implementation varlen-flash-attention on cpu and as general purpose Dec 18, 2025
@@ -0,0 +1,2235 @@
use candle::Result;
Contributor Author


Duplicate of the above test; let me know how to avoid the duplication.

Ideally, I want each test to run for all of unfused, fp32, and fp16.
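
One possible way to avoid the duplication (a rough sketch, not the PR's actual test layout; names are hypothetical): a single helper parameterized over dtype and code path, called from thin `#[test]` wrappers:

```rust
use candle::{DType, Device, Result};

// Hypothetical helper: builds inputs, runs one attention path, and checks
// the output against a reference. `use_unfused` selects the fallback path.
fn run_varlen_case(dtype: DType, use_unfused: bool) -> Result<()> {
    let dev = Device::Cpu;
    // ... construct q/k/v/cu_seqlens in `dtype`, run the kernel, compare ...
    let _ = (dtype, use_unfused, dev);
    Ok(())
}

#[test]
fn varlen_unfused_f32() -> Result<()> {
    run_varlen_case(DType::F32, true)
}

#[test]
fn varlen_unfused_f16() -> Result<()> {
    run_varlen_case(DType::F16, true)
}
```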

use candle::Result;
use candle_nn::varlen_attention::flash_attn_varlen_unfused;

const FA_FEATURE_ENABLED: bool = false; // flash-attn features not available in this workspace
Contributor Author


We can't have the FA2 kernels in this repo; if we did, I could run the tests below.

@michaelfeil
Contributor Author

@EricLBuehler we collaborated on the flash-attn-3 kernels. Curious if you have any high-level feedback here.

@michaelfeil michaelfeil changed the title feat: implementation varlen-flash-attention on cpu and as general purpose feat: implementation varlen-flash-attention on cpu Dec 18, 2025
@ivarflakstad
Member

Hey Michael! Thanks for the PR :)

We now have two different (both valuable and appreciated) cpu flash attention PRs open (ref), so I was wondering if I could ask you and @DrJesseGlass to cooperate a bit and figure out the best way forward 🙇

Other than trying to avoid large merge conflicts, I'm also thinking about getting a somewhat unified interface (though the underlying implementations can be completely different).

Member

@ivarflakstad ivarflakstad left a comment


Looks great!
Haven’t had time to test it myself yet, but here is a preliminary review 🫡

}

match dt {
DType::F32 => {
Member


Perhaps easier to read if we dispatch to explicit f32/f16 functions?

At first I was considering a generic fn but we should just keep things simple :)
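
Something along these lines (a rough sketch; the function names are hypothetical, not the PR's code):

```rust
use candle::{bail, DType, Result, Tensor};

// Sketch of the suggested refactor: a thin dispatcher that matches on dtype
// once and forwards to explicit, monomorphic f32/f16 kernels for readability.
fn flash_attn_varlen_cpu(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    match q.dtype() {
        DType::F32 => flash_attn_varlen_cpu_f32(q, k, v),
        DType::F16 => flash_attn_varlen_cpu_f16(q, k, v),
        dt => bail!("varlen flash attention: unsupported dtype {dt:?}"),
    }
}

fn flash_attn_varlen_cpu_f32(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // f32 inner loops would live here.
    let _ = (k, v);
    Ok(q.clone())
}

fn flash_attn_varlen_cpu_f16(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // f16 inner loops would live here.
    let _ = (k, v);
    Ok(q.clone())
}
```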

let q_base = (q_idx * hq + h) * d;
let q_row = &q_data[q_base..q_base + d];

// online softmax
Member


I don’t think our current cpu softmax is online even. Should be. Same with cuda.
At least our metal softmax is.

Contributor Author


wdym? Are you suggesting that `w = (score - m).exp()` is not correct, or that it would be helpful to have an online-softmax implementation for the CPU path in candle?
I think the above flash-attention CPU kernel (for non-varlen) also uses fused softmax to get memory-efficient attention.
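
For reference, a minimal sketch of the online-softmax recurrence under discussion (plain Rust, illustrative only, not candle's implementation); in flash attention the partial output accumulator is additionally rescaled by `exp(m_old - m_new)` at each step:

```rust
// Online softmax over a stream of scores: one pass, O(1) extra state.
// Equivalent to softmax(scores) but never materializes the exponentiated row.
fn online_softmax(scores: &[f32]) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running max
    let mut l = 0f32;              // running sum of exp(score - m)
    for &s in scores {
        let m_new = m.max(s);
        // rescale the old sum to the new max before adding the new term
        l = l * (m - m_new).exp() + (s - m_new).exp();
        m = m_new;
    }
    scores.iter().map(|&s| (s - m).exp() / l).collect()
}
```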

Member


> or that it would be helpful to have an online-softmax implementation for the CPU path in candle?

This ☝️

Contributor Author


I added a PR here for visibility: #3269. It's not ready yet. The current softmax implementation in candle is indeed not online, but it's also still quite efficient on CPU; I'm trying to figure out how to best measure the problem with the current one before marking the PR as ready.

@michaelfeil
Contributor Author

@ivarflakstad I checked out the other flash attention PR. My goal is to make the modeling code for TEI faster; usually people submit both a "padded" and a "varlen" implementation. I hope that by providing a varlen primitive on CPU and Metal, we can avoid having the padded implementation altogether.

In flash-attn, there are two interfaces: the sdpa-like one (as with @DrJesseGlass) and the varlen one. I agree they should somehow live in the same files, which would lead to conflicts. It would be great to know how to refactor the public API / which files to move them to.

@ivarflakstad
Member

> In flash-attn, there are two interfaces: the sdpa-like one (as with @DrJesseGlass) and the varlen one. I agree they should somehow live in the same files, which would lead to conflicts. It would be great to know how to refactor the public API / which files to move them to.

Well perhaps sdpa / varlen is a decent separation?
We should pull @EricLBuehler into this discussion as well since he wrote the original cpu impl :)

@DrJesseGlass
Contributor

I'm happy to integrate this once it is merged into a revised version of my PR (#3254). varlen attention would fit neatly into my proposed structure:

attention/cpu_flash/
├── mod.rs          # Dispatch logic
├── standard.rs     # Existing mask-based (minimal changes)
├── causal.rs       # B=1 specialized, loop-bound
└── varlen.rs       # Variable-length batched (this work)

My broader goal is to minimize standalone scripts and establish a clean attention module within candle-nn — with an API matching GPU flash-attn so CPU and GPU paths are interchangeable in transformers.

Although my causal loop-bound implementation proves the value of specialized kernels (~3.5% speedup, ~14% memory reduction), it's an early version focused on API and layout agreement. Further optimizations are planned.

Our work is complementary: while varlen can handle fixed-length as a special case, it's likely not as optimizable as a dedicated path. I'm thinking my causal implementation should specialize for B=1 (the common interactive case) while varlen handles batched workloads. This gives both paths room to optimize for their target use case.

This loop-bound implementation here may also inform further optimizations on my side.

@michaelfeil
Contributor Author

@ivarflakstad Happy new year! Any chance we can get this merged? Would love to get it into some future text-embeddings-inference version.
