
Conversation

Contributor

@drbh drbh commented Jan 7, 2026

This PR improves the metal_reduce_.*_strided kernels by using a smaller index dtype when possible (i.e. when the strides fit in a type smaller than u64).

benched

cargo bench --features metal -p candle-core --bench bench_main -- "metal_reduce_.*_strided"

summary

| Benchmark    | main               | feature branch      | Speedup      |
|--------------|--------------------|---------------------|--------------|
| f32_strided  | 5.62ms (711 MiB/s) | 3.12ms (1.25 GiB/s) | 1.80x faster |
| f16_strided  | 3.98ms (503 MiB/s) | 2.35ms (851 MiB/s)  | 1.69x faster |
| bf16_strided | 3.95ms (505 MiB/s) | 2.35ms (849 MiB/s)  | 1.68x faster |

The optimizations provide a ~1.7-1.8x speedup on strided reduce operations. The gains come from:

  1. Smaller index types (u16/u32 vs u64) - faster division operations
  2. Power-of-two optimization - bitwise & and >> instead of % and / for pow2 dimensions (see the sketch below)
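
As a rough illustration of point 2 (not the exact kernel code; METAL_FUNC is the crate's inline helper macro), the mask and shift for a power-of-two dim are precomputed on the host, and the per-dim decomposition becomes:

METAL_FUNC uint decompose_dim_pow2(thread uint &idx, uint mask, ushort shift) {
    uint r = idx & mask; // same as idx % dim, since mask == dim - 1
    idx = idx >> shift;  // same as idx / dim, since shift == log2(dim)
    return r;
}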

Note

This PR is a collaboration between me and good ol' Claude Code. Claude was first used to profile the reduce kernel, where it flagged the performance degradation on strided reduce and the expensive u64 division. It was then used to iterate on changes, and we ultimately settled on using a smaller index dtype where possible and preferring cheaper operations when the dimensions are powers of 2.

Please let me know if there are any questions about the implementation! More than happy to make any changes, or to close the PR if it introduces more complexity than the performance gains justify!

@drbh drbh requested a review from ivarflakstad January 7, 2026 00:00
Member

ivarflakstad commented Jan 7, 2026

Oh yeah big mr huggingface employee coming here showing the people how to vibe code kernel code with proper results, eh?

But seriously, this is really good. We've previously talked about how indexing with smaller dtypes is surprisingly efficient, and I see you've also come up with a strided index improvement similar to the one I did here.
I find this part really interesting (and I suspect it is a substantial part of the performance):

if (is_pow2[k]) {
  r = idx & masks[k];
  idx = idx >> shifts[k];
}

and I think it should be unrolled statically like the get_strided_index_1d / get_strided_idx<1> variants.
(edit: at least if all are pow2)
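
Something like this untested sketch (names made up; it assumes every dim is pow2, N is a compile-time dim count, and METAL_FUNC is the usual inline macro):

template<uint N>
METAL_FUNC ulong get_strided_index_pow2(
    uint idx,
    constant const uint *strides,
    constant const uint *masks,
    constant const uchar *shifts
) {
    ulong strided_i = 0;
    #pragma unroll
    for (uint d = 0; d < N; d++) {
        uint k = N - 1 - d;
        strided_i += ulong(idx & masks[k]) * ulong(strides[k]);
        idx >>= shifts[k];
    }
    return strided_i;
}

Calling it as get_strided_index_pow2<2>(...) etc. lets the compiler fully unroll the loop.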

I couldn't find the branch but at some point I had a kernel impl where I calculated the min required index dtype based on the problem size using function constant values. Maybe worth exploring?
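
For the function-constant idea, roughly this shape (untested sketch; the constant and kernel names are made up, and the host would set the constants from the problem size at pipeline-creation time so the dead branches get compiled out):

constant bool INDEX_FITS_U16 [[function_constant(0)]];
constant bool INDEX_FITS_U32 [[function_constant(1)]];

template<typename IndexT>
METAL_FUNC ulong strided_index(
    uint idx,
    constant const uint *dims,
    constant const uint *strides,
    uint num_dims
) {
    ulong strided_i = 0;
    IndexT i = IndexT(idx);
    for (uint d = 0; d < num_dims; d++) {
        uint k = num_dims - 1 - d;
        strided_i += ulong(i % IndexT(dims[k])) * ulong(strides[k]);
        i /= IndexT(dims[k]);
    }
    return strided_i;
}

kernel void copy_strided_example(
    device const float *src [[buffer(0)]],
    device float *dst [[buffer(1)]],
    constant const uint *dims [[buffer(2)]],
    constant const uint *strides [[buffer(3)]],
    constant const uint &num_dims [[buffer(4)]],
    uint tid [[thread_position_in_grid]]
) {
    // The ternary on function constants is resolved at pipeline creation.
    ulong i = INDEX_FITS_U16 ? strided_index<ushort>(tid, dims, strides, num_dims)
            : INDEX_FITS_U32 ? strided_index<uint>(tid, dims, strides, num_dims)
                             : strided_index<ulong>(tid, dims, strides, num_dims);
    dst[tid] = src[i];
}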

Comment on lines +105 to +116
encoder.dispatch_thread_groups(
    MTLSize {
        width: out_length,
        height: 1,
        depth: 1,
    },
    MTLSize {
        width,
        height: 1,
        depth: 1,
    },
);
Member

This is functionally the exact same as before, just without named variables.
Let's keep the changes in this PR to a minimum.

Comment on lines 35 to 39
struct Pow2Meta {
    is_pow2: Vec<u8>,
    masks: Vec<u32>,
    shifts: Vec<u8>,
}
Member

Not required, just FYI: you could slap a #[repr(C)] on this and create its equal in Metal, and then we could send in the entire guy instead of its parts.

let pow2 = dim.is_power_of_two() && dim > 1;
is_pow2.push(pow2 as u8);
if pow2 {
    masks.push((dim - 1) as u32);
Member

I don't see any reason this couldn't be u16 or u64.
If we store it as u64 in this calculation and convert to the appropriate dtype on kernel launch, we could cover the other cases as well.
This does conflict a bit with the #[repr(C)] comment above.

Comment on lines +15 to +16
let max_dim = shape.iter().copied().max().unwrap_or(0);
let max_stride = strides.iter().copied().max().unwrap_or(0);
Member

Ok, this is interesting. I was certain you would select the index type based on the maximum index value, aka shape.iter().product().

But looking at the Metal fn, you're still calculating the actual index as u64 while the intermediaries are in the smaller index type, which has me wondering whether this part actually improves performance, or whether it is all the "pow2" feature.

Also nit: The stride max will always be larger than shape max, so that's the only one you have to calculate.

I have a weird idea that could allow us to use u16 for all indexing (even if the elem count of the tensor is way larger). Fall back to u32, but we probably wouldn't need u64 (given that the max stride < 4 294 967 295 hehe).
Totally different strided indexing approach; max_stride remains the limiting factor.
Instead of calculating the final index and then indexing, we incr/decr the actual data pointers for each dim. I think it could work, but it sounds messy so I'm not sure we want to go down that path 🤔

Here's the simplest variant. It adjusts the ptr, returns nothing.

template<typename T, typename IndexT = ushort>
METAL_FUNC void apply_stride_ptr(
    // ptr is taken by reference so the adjustment is visible to the caller
    device const T * thread &ptr,
    IndexT idx,
    constant const IndexT &num_dims,
    constant const IndexT *dims,
    constant const IndexT *strides
) {
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        ptr += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
}

Haven't tested. It probably doesn't compile as-is with the const/address-space attributes, so some adjustment is needed.

A bit tedious if you want to use the same indexing for multiple data sources.

But it would be easy to use pow2 for all IndexT, and simplifies the kernel suffix.
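
A hypothetical call site (names made up) would look something like this, with each thread walking its own pointer and then just dereferencing it:

kernel void strided_copy_example(
    device const float *src [[buffer(0)]],
    device float *dst [[buffer(1)]],
    constant const ushort &num_dims [[buffer(2)]],
    constant const ushort *dims [[buffer(3)]],
    constant const ushort *strides [[buffer(4)]],
    uint tid [[thread_position_in_grid]]
) {
    device const float *ptr = src;
    apply_stride_ptr<float, ushort>(ptr, ushort(tid), num_dims, dims, strides);
    dst[tid] = *ptr;
}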

Comment on lines 49 to 55
METAL_FUNC uint get_strided_index_1d(
    uint idx,
    constant const uint *dims,
    constant const uint *strides
) {
    return idx * strides[0];
}
Member

As mentioned in my PR where I do the same, I'd prefer to hold off on this until we have precompilation for the metal backend. I'll add it.
I also like my approach a tiny bit better than Claude's hehe.

    return strided_i;
}

METAL_FUNC uint get_strided_index(
Member

Looking closely, this isn't actually used.
You either use indexer_pow2 or indexer_u64.
Which, btw, could be a single templated indexer.

    constant uint *masks,
    constant uchar *shifts
) {
    ulong strided_i = 0;
Member

ulong, unsigned long, uint64_t, and size_t are all the same type. I'm happy using ulong, just thinking about consistency with, for example, the function directly below this one.

Comment on lines +207 to +218
encoder.dispatch_thread_groups(
    MTLSize {
        width: out_length,
        height: 1,
        depth: 1,
    },
    MTLSize {
        width,
        height: 1,
        depth: 1,
    },
);
Member

Same here

Member

ivarflakstad commented Jan 7, 2026

Just checked which kernels are actually being run in the benchmark: fast_sum_f32_strided, fast_sum_f16_strided, and fast_sum_bf16_strided. So we've only hit the IndexType::U32 => "" case and only profiled uint indexing.

@ivarflakstad
Member

Good news and bad news.

The good news is that I managed to speed this up even more.

The bad news is that I've benchmarked the pow2 concept in isolation and there is no measurable performance benefit.
I'm assuming that the 3 lookups (is_pow2[k], masks[k], shifts[k]) counteract the benefit of the mask and shift operations. I even removed the is_pow2[k] check and there was still no measurable difference.
It seems all the gains come from using u32 and static unrolling.

Good news from the bad news: it simplifies the improvements considerably!
