feat: simplify metal reduce kernels and standardize on u32 indexing (#3285)

Merged

ivarflakstad merged 30 commits into main from sized-metal-reduce on Jan 18, 2026
Conversation

drbh (Contributor) commented Jan 7, 2026

This PR improves the `metal_reduce_.*_strided` kernels.

Benchmark

cargo bench --features metal -p candle-core --bench bench_main -- "metal_reduce_.*_strided"

Results

| Benchmark    | main                | new                      | Speedup      |
|--------------|---------------------|--------------------------|--------------|
| f32_strided  | 5.62ms (711 MiB/s)  | 2.3433 ms (1.6670 GiB/s) | 2.39x faster |
| f16_strided  | 3.98ms (503 MiB/s)  | 2.3631 ms (851.33 MiB/s) | 1.69x faster |
| bf16_strided | 3.95ms (505 MiB/s)  | 2.3545 ms (849.44 MiB/s) | 1.68x faster |

drbh requested a review from ivarflakstad on January 7, 2026
ivarflakstad (Member) commented Jan 7, 2026

Oh yeah big mr huggingface employee coming here showing the people how to vibe code kernel code with proper results, eh?

But seriously, this is really good. We've previously talked about how indexing with smaller dtypes is surprisingly efficient, and I see you've also come up with a strided-index improvement similar to the one I did here.
I find this part really interesting (and I suspect it is a substantial part of the performance):

```c
if (is_pow2[k]) {
  r = idx & masks[k];
  idx = idx >> shifts[k];
}
```

and I think it should be unrolled statically like the `get_strided_index_1d` / `get_strided_idx<1>` variants.
(edit: at least if all dims are pow2)

I couldn't find the branch but at some point I had a kernel impl where I calculated the min required index dtype based on the problem size using function constant values. Maybe worth exploring?

ivarflakstad (Member) commented Jan 7, 2026

Just checked which kernels are actually being run in the benchmark: `fast_sum_f32_strided`, `fast_sum_f16_strided`, and `fast_sum_bf16_strided`. So we've only hit the `IndexType::U32 => ""` case and only profiled uint indexing.

ivarflakstad (Member) commented

Good news and bad news.

Good news is that I managed to speed this up even more.

Bad news is that I've benchmarked the pow2 concept in isolation and there is no measurable performance benefit.
I'm assuming this is because the three lookups (`is_pow2[k]`, `masks[k]`, `shifts[k]`) cost more than the mask and shift operations save. I even removed the `is_pow2[k]` check and there was still no measurable difference.
It seems all the gains come from using u32 and static unrolling.

Good news from the bad news - simplifies the improvements considerably!

ivarflakstad (Member) commented Jan 13, 2026

Just did some quick mafs.
If a tensor is so large that u32 is not enough to index into it, that means the underlying buffer has more than 4,294,967,295 elements. Even if we were to use fp8 for this tensor it would still be (4,294,967,295 bytes / 1024^3 ≈) 4 GiB of data in a single tensor. For metal kernels that's kinda silly.
In other words I'm removing the kernels that use u64 to index. We'll use a different approach if the need arises.

drbh (Contributor, author) commented Jan 13, 2026

> Just did some quick mafs. If a tensor is so large that u32 is not enough to index into it, that means the underlying buffer has more than 4,294,967,295 elements. Even if we were to use fp8 for this tensor it would still be ≈4 GiB of data in a single tensor. For metal kernels that's kinda silly. In other words I'm removing the kernels that use u64 to index. We'll use a different approach if the need arises.

Thanks for diving into this! And yeah, I kinda assumed we could get away with just a u32, but didn't want to make too big an assumption about today's hardware. That being said, I do think it's reasonable to avoid u64 for now, and if this becomes a blocker in the future we can revisit.

drbh changed the title from "feat: reduce prefers min dtype based on stride size" to "feat: simplify metal reduce kernels and standardize on u32 indexing" on Jan 13, 2026
drbh (Contributor, author) left a comment
Would approve the refactor/updates from @ivarflakstad if this wasn't originally opened by me!

Looks great, and I agree the standardization on u32 is reasonable and realistic.

Note: I just reran the benches and got the following values on an MBP M3:

| Benchmark                 | Time      | Throughput   |
|---------------------------|-----------|--------------|
| metal_reduce_f32_strided  | 2.3433 ms | 1.6670 GiB/s |
| metal_reduce_f16_strided  | 2.3631 ms | 846.33 MiB/s |
| metal_reduce_bf16_strided | 2.3545 ms | 849.44 MiB/s |

ivarflakstad merged commit cc8ec5e into main on Jan 18, 2026
10 checks passed
ivarflakstad deleted the sized-metal-reduce branch on January 18, 2026
slckl pushed a commit to slckl/candle that referenced this pull request Mar 1, 2026
…uggingface#3285)

* feat: reduce prefers min dtype based on stride size

* fix: cargo format fix

* Add temporary large/small reduce benchmarks

* Improved / simplified strided indexing

* Begin untangling pow2 concept from indexer

* remove unused get_strided_index_u64

* Make indexer_t handle both cont and strided. indexer.last_dim is a constant

* Remove Pow2Meta

* Remove redundant contiguous_indexer

* Remove redundant strided index impl

* Remove pow2 from kernel call/signature

* Some size_t -> ulong changes

* Use indexer_t for all reduce based kernels

* contiguous indexer last_dim default is 0

* u16 indexing does not provide speedup, and is an unlikely use case

* Let indexer_t dictate indexing dtype

* Introduce finalize concept. Tidy up redundant fns/macros

* Tidying up

* Use store concept instead of finalize

* Simplify arg reduce macros

* Tidying up

* Add reduce kernels for large tensors. Use existing reduce macro to add implementations

Remove large arg reduce kernels as we currently only support writing uint arg reduce results.

* Remove u64 indexed reduce kernels (u32 should suffice). Also max block_dim is 1024, so removing 2048 case from reduce kernels

* Remove suffix from reduce macro

* Remove IDX from reduce macros

* Tidy up reduce kernel call code

* Explicit u32 in reduce call code. Remove IDX from arg reduce macros

* Remove small reduce benchmark

---------

Co-authored-by: Ivar Flakstad <[email protected]>