feat: reduce prefers min dtype based on stride size #3285
base: main
Conversation
Oh yeah, big Mr. Huggingface employee coming here showing the people how to vibe code kernel code with proper results, eh? But seriously, this is really good. We've previously talked about how indexing with smaller dtypes is surprisingly efficient, and I see you've also come up with a similar strided index improvement like I did here:

```metal
if (is_pow2[k]) {
    r = idx & masks[k];
    idx = idx >> shifts[k];
}
```

and I think it should be unrolled statically. I couldn't find the branch, but at some point I had a kernel impl where I calculated the min required index dtype based on the problem size using function constant values. Maybe worth exploring?
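For illustration, here's a rough sketch of what a statically unrolled variant could look like, assuming an upper bound MAX_DIMS and num_dims supplied as a function constant so the compiler can specialize the loop (all names, the macro definitions, and the parameter layout are made up for the sketch, not the PR's actual code):

```metal
#include <metal_stdlib>
using namespace metal;

#define METAL_FUNC inline   // assumed: stands in for the project's helper macro
#define MAX_DIMS 8          // assumed compile-time bound on tensor rank

// num_dims as a function constant lets the compiler drop dead iterations
// once the pipeline is specialized for a concrete rank.
constant uint num_dims [[function_constant(0)]];

METAL_FUNC ulong strided_index_unrolled(
    uint idx,
    constant const uint *dims,
    constant const uint *strides,
    constant const uchar *is_pow2,
    constant const uint *masks,
    constant const uchar *shifts
) {
    ulong strided_i = 0;
    #pragma unroll
    for (uint k = 0; k < MAX_DIMS; k++) {
        if (k >= num_dims) { break; }
        uint r;
        if (is_pow2[k]) {
            // power-of-two dim: mask/shift instead of mod/div
            r = idx & masks[k];
            idx = idx >> shifts[k];
        } else {
            r = idx % dims[k];
            idx = idx / dims[k];
        }
        strided_i += ulong(r) * strides[k];
    }
    return strided_i;
}
```

The min-index-dtype idea would sit on top of this, since function constants set at pipeline creation are also how num_dims would be fed in here.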
```rust
encoder.dispatch_thread_groups(
    MTLSize {
        width: out_length,
        height: 1,
        depth: 1,
    },
    MTLSize {
        width,
        height: 1,
        depth: 1,
    },
);
```
This is functionally the exact same as before, just without named variables.
Let's keep the changes in this PR to a minimum.
```rust
struct Pow2Meta {
    is_pow2: Vec<u8>,
    masks: Vec<u32>,
    shifts: Vec<u8>,
}
```
Not required, just FYI: you could slap a #[repr(C)] on this, create its equal in Metal, and send in the entire guy instead of its parts.
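For reference, a rough sketch of what the Metal-side equivalent could look like, assuming a fixed MAX_DIMS bound (the Rust side would then need fixed-size arrays instead of Vecs for the #[repr(C)] layout to actually line up):

```metal
// Assumed bound so the struct has a static layout that a #[repr(C)] Rust struct
// (with [u8; MAX_DIMS] / [u32; MAX_DIMS] fields) can mirror byte-for-byte.
#define MAX_DIMS 8

struct Pow2Meta {
    uchar is_pow2[MAX_DIMS]; // 1 if dims[k] is a power of two
    uint  masks[MAX_DIMS];   // dims[k] - 1 when is_pow2[k] is set
    uchar shifts[MAX_DIMS];  // log2(dims[k]) when is_pow2[k] is set
};
```

The kernel would then take a single `constant Pow2Meta &meta [[buffer(n)]]` argument instead of three separate buffers.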
```rust
let pow2 = dim.is_power_of_two() && dim > 1;
is_pow2.push(pow2 as u8);
if pow2 {
    masks.push((dim - 1) as u32);
```
I don't see any reason this couldn't be u16 or u64.
If we store it as u64 in this calculation and convert to the appropriate dtype on kernel launch, we could cover the other cases as well.
This does conflict a bit with the #[repr(C)] comment above.
```rust
let max_dim = shape.iter().copied().max().unwrap_or(0);
let max_stride = strides.iter().copied().max().unwrap_or(0);
```
Ok, this is interesting. I was certain you would select the index type based on the maximum index value, aka shape.iter().product().
But looking at the metal fn, you're still calculating the actual index using u64; only the intermediaries are in the smaller index type. Which has me wondering if this part actually improves performance, or if it is all the "pow2" feature.
Also nit: The stride max will always be larger than shape max, so that's the only one you have to calculate.
I have a weird idea that could allow us to use u16 for all indexing (even if the elem count of the tensor is way larger). Fallback to u32, but probably wouldn't need u64 (given that the max stride < 4 294 967 295 hehe).
Totally different strided indexing approach. max_stride remains as the limiting factor.
Instead of calculating the final index and then indexing we incr/decr the actual data pointers for each dim. I think it could work, but it sounds messy so not sure we want to go down that path 🤔
Here's the simplest variant. It adjusts the ptr (returning the adjusted pointer, since a by-value ptr argument wouldn't propagate back to the caller):

```metal
template<typename T, typename IndexT = ushort>
METAL_FUNC device const T *apply_stride_ptr(
    device const T *ptr,
    IndexT idx,
    constant const IndexT &num_dims,
    constant const IndexT *dims,
    constant const IndexT *strides
) {
    // Walk dims from innermost to outermost, offsetting the pointer per dim.
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        ptr += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return ptr;
}
```

Haven't tested. Probably doesn't compile as-is with the const attribute, so some adjustment needed.
A bit tedious if you want to use the same indexing for multiple data sources.
But it would be easy to use pow2 for all IndexT, and simplifies the kernel suffix.
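To make the "multiple data sources" point concrete, here's a hypothetical (untested) sketch of how a binary-op-style kernel might use it, with the variant above that returns the adjusted pointer and IndexT = uint for safety (all names and buffer indices are made up):

```metal
kernel void binary_add_strided(
    device const float *lhs          [[buffer(0)]],
    device const float *rhs          [[buffer(1)]],
    device float *out                [[buffer(2)]],
    constant const uint &num_dims    [[buffer(3)]],
    constant const uint *dims        [[buffer(4)]],
    constant const uint *lhs_strides [[buffer(5)]],
    constant const uint *rhs_strides [[buffer(6)]],
    uint tid [[thread_position_in_grid]]
) {
    // Each data source pays the per-dim walk once; the same idx drives both.
    device const float *l = apply_stride_ptr<float, uint>(lhs, tid, num_dims, dims, lhs_strides);
    device const float *r = apply_stride_ptr<float, uint>(rhs, tid, num_dims, dims, rhs_strides);
    out[tid] = *l + *r;
}
```

The u16 version would need the grid index itself to stay within range (or be decomposed first), which is where it starts to get messy.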
```metal
METAL_FUNC uint get_strided_index_1d(
    uint idx,
    constant const uint *dims,
    constant const uint *strides
) {
    return idx * strides[0];
}
```
As mentioned in my PR where I do the same, I'd prefer to wait with this until we have precompilation for the metal backend. I'll add it.
I also like my approach a tiny bit better than Claude's hehe.
```metal
    return strided_i;
}

METAL_FUNC uint get_strided_index(
```
Looking closely this isn't actually used.
You either use indexer_pow2 or indexer_u64.
Which btw could be just indexer using templating.
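Something like this minimal sketch (signature assumed; the pow2 mask/shift branch could be folded in the same way), so indexer_pow2 / indexer_u64 collapse into indexer<uint> / indexer<ulong>:

```metal
template<typename IndexT>
METAL_FUNC ulong indexer(
    IndexT idx,
    constant const uint &num_dims,
    constant const IndexT *dims,
    constant const IndexT *strides
) {
    ulong strided_i = 0;
    // Walk dims from innermost to outermost; intermediaries stay in IndexT,
    // only the accumulated offset is widened to ulong.
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += ulong(idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}
```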
```metal
    constant uint *masks,
    constant uchar *shifts
) {
    ulong strided_i = 0;
```
ulong, unsigned long, uint64_t, and size_t are all the same. I'm happy using ulong, just thinking about consistency with, for example, the function directly below this one.
```rust
encoder.dispatch_thread_groups(
    MTLSize {
        width: out_length,
        height: 1,
        depth: 1,
    },
    MTLSize {
        width,
        height: 1,
        depth: 1,
    },
);
```
Same here
Just checked which kernels are actually being run in the benchmark:

Good news and bad news. Good news is that I managed to speed this up even more. Bad news is that I've benchmarked the … Good news from the bad news: it simplifies the improvements considerably!
This PR improves the metal_reduce_.*_strided kernels by using a smaller dtype when possible (when the strides fit within a smaller type than a u64).

benched

```
cargo bench --features metal -p candle-core --bench bench_main -- "metal_reduce_.*_strided"
```

summary

The optimizations provide a ~1.7-1.8x speedup on strided reduce operations. This comes from:

- using a smaller index dtype when the strides fit within it (instead of always u64), and
- replacing division/modulo with cheaper mask/shift operations when the dimensions are powers of 2.
Note
This PR is a collaboration between myself and good ol' Claude Code. Claude was first used to profile the reduce kernel, where it noted the performance degradation on strided reduce and the expensive u64 division. It was then used to iterate on some changes, and ultimately we settled on using a smaller dtype where possible and preferring cheaper operations when the dimensions are powers of 2.
Please let me know if there are any questions about the implementation! More than happy to make any changes, or to close the PR if it introduces more complexity than is justified by the performance gains!