feat: simplify metal reduce kernels and standardize on u32 indexing #3285
ivarflakstad merged 30 commits into main
Conversation
Oh yeah, big Mr. Hugging Face employee coming here showing the people how to vibe code kernel code with proper results, eh? But seriously, this is really good. We've previously talked about how indexing with smaller dtypes is surprisingly efficient, and I see you've also come up with a strided index improvement similar to the one I did here:

```metal
if (is_pow2[k]) {
    r = idx & masks[k];
    idx = idx >> shifts[k];
}
```

I think it should be unrolled statically. I couldn't find the branch, but at some point I had a kernel impl where I calculated the minimum required index dtype based on the problem size using function constant values. Maybe worth exploring?
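To spell out the trick for readers following along: converting a linear index into a strided offset normally costs one div/mod pair per dimension, but when a dim is a power of two those collapse to a shift and a mask. A rough host-side sketch in Rust (hypothetical helper, not the actual kernel code):

```rust
// Sketch of strided index computation for a reduce kernel, assuming
// row-major dims/strides. For power-of-two dims, the div/mod pair is
// replaced by mask/shift, mirroring the Metal snippet in the thread.
fn strided_index(mut idx: u32, dims: &[u32], strides: &[u32]) -> u32 {
    let mut out = 0u32;
    // Walk dims from the innermost (last) outward.
    for k in (0..dims.len()).rev() {
        let d = dims[k];
        let (r, q) = if d.is_power_of_two() {
            // Fast path: mask/shift instead of mod/div.
            (idx & (d - 1), idx >> d.trailing_zeros())
        } else {
            (idx % d, idx / d)
        };
        out += r * strides[k];
        idx = q;
    }
    out
}
```

In the kernel the loop bound and the pow2-ness of each dim are known at pipeline-creation time, which is why unrolling the loop statically (as suggested above) is attractive.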
Just checked which kernels are actually being run in the benchmark:
Good news and bad news. The good news is that I managed to speed this up even more. The bad news is what I found when I benchmarked. The good news from the bad news: it simplifies the improvements considerably!
Add reduce kernels for large tensors. Use existing reduce macro to add implementations. Remove large arg reduce kernels as we currently only support writing uint arg reduce results.
Just did some quick mafs.
Remove u64 indexed reduce kernels (u32 should suffice). Also max block_dim is 1024, so removing 2048 case from reduce kernels.
thanks for diving into this! and yeah, I kinda assumed we could get away with just a u32 but didn't want to make too big of an assumption on today's hardware. that being said, I do think it's reasonable to avoid u64 for now, and if this becomes a blocker in the future we can revisit
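To put a number on "u32 should suffice": u32 indexing caps a single tensor at 2^32 − 1 elements, which for f32 is roughly 16 GiB of data in one tensor. A tiny sketch of the "pick the smallest index dtype that fits the problem size" idea mentioned earlier in the thread (hypothetical names, not candle's actual API):

```rust
// Hypothetical helper illustrating dtype selection by element count;
// in a Metal kernel this choice could be baked in via function constants.
#[derive(Debug, PartialEq)]
enum IndexDtype {
    U32,
    U64,
}

fn index_dtype_for(num_elements: usize) -> IndexDtype {
    // u32 addresses up to 2^32 - 1 elements (~16 GiB of f32 data),
    // so it covers realistic single-tensor sizes on today's hardware.
    if num_elements <= u32::MAX as usize {
        IndexDtype::U32
    } else {
        IndexDtype::U64
    }
}
```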
would approve the refactor/updates from @ivarflakstad if this wasn't originally opened by me!
looks great, and I agree the standardization on u32 is reasonable and realistic
Note: just reran the benches and got the following values on an MBP M3
| Benchmark | Time | Throughput |
|---|---|---|
| metal_reduce_f32_strided | 2.3433 ms | 1.6670 GiB/s |
| metal_reduce_f16_strided | 2.3631 ms | 846.33 MiB/s |
| metal_reduce_bf16_strided | 2.3545 ms | 849.44 MiB/s |
feat: simplify metal reduce kernels and standardize on u32 indexing (huggingface#3285)

* feat: reduce prefers min dtype based on stride size
* fix: cargo format fix
* Add temporary large/small reduce benchmarks
* Improved / simplified strided indexing
* Begin untangling pow2 concept from indexer
* remove unused get_strided_index_u64
* Make indexer_t handle both cont and strided. indexer.last_dim is a constant
* Remove Pow2Meta
* Remove redundant contiguous_indexer
* Remove redundant strided index impl
* Remove pow2 from kernel call/signature
* Some size_t -> ulong changes
* Use indexer_t for all reduce based kernels
* contiguous indexer last_dim default is 0
* u16 indexing does not provide speedup, and is an unlikely use case
* Let indexer_t dictate indexing dtype
* Introduce finalize concept. Tidy up redundant fns/macros
* Tidying up
* Use store concept instead of finalize
* Simplify arg reduce macros
* Tidying up
* Add reduce kernels for large tensors. Use existing reduce macro to add implementations. Remove large arg reduce kernels as we currently only support writing uint arg reduce results.
* Remove u64 indexed reduce kernels (u32 should suffice). Also max block_dim is 1024, so removing 2048 case from reduce kernels
* Remove suffix from reduce macro
* Remove IDX from reduce macros
* Tidy up reduce kernel call code
* Explicit u32 in reduce call code. Remove IDX from arg reduce macros
* Remove small reduce benchmark

---------

Co-authored-by: Ivar Flakstad <[email protected]>
This PR improves the `metal_reduce_.*_strided` kernels.

Benchmark:

```shell
cargo bench --features metal -p candle-core --bench bench_main -- "metal_reduce_.*_strided"
```

Results