Implement group merging for bwd_weight #3637
base: develop
Conversation
vpietila-amd left a comment
Looks good to me. @bartekxk what do you think?
Commit 964372a
Pull request overview
This PR implements group merging functionality for the backward weight V3 kernel in grouped convolutions. The feature enables multiple groups to be merged and processed together, improving performance for workloads that result in skinny GEMM operations.
Changes:
- Added a `NumGroupsToMerge` template parameter to the V3 kernel implementation, with a default value of 1
- Updated batch stride calculations to account for merged groups (see the sketch after this list)
- Added four new kernel instances with group merging factors of 2 and 4
- Added validation to ensure the number of groups is evenly divisible by the merge factor
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| device_grouped_conv_bwd_weight_v3_xdl_instance.hpp | Adds four new kernel instances with group merging enabled (merge factors of 2 and 4) |
| device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | Implements the core group merging logic including template parameter, stride adjustments, and validation |
```cpp
if constexpr(NumGroupsToMerge > 1)
{
    if(arg.Conv_G_ % NumGroupsToMerge != 0)
    {
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
        {
            std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_="
                      << arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge
                      << std::endl;
        }
        return false;
    }
}
```
Copilot AI (Jan 23, 2026)
The error message describes an internal condition check rather than a user-facing error. Consider rephrasing to explain the constraint in user terms, such as 'Number of groups must be evenly divisible by the merge factor' and include both values for clarity.
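One possible rewording in the spirit of this suggestion (the phrasing below is illustrative only, not the wording adopted by the PR):

```cpp
// Illustrative user-facing rewording; not the message merged in the PR.
std::cout << "Unsupported configuration: the number of groups (Conv_G_="
          << arg.Conv_G_ << ") must be evenly divisible by the merge factor"
          << " (NumGroupsToMerge=" << NumGroupsToMerge << ")" << std::endl;
```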
```cpp
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
// generic instance
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
```
Copilot AI (Jan 23, 2026)
Line 96 is a duplicate of line 95 except for the last three template parameters (F16, F16, 2). This creates two nearly identical kernel instances, which could lead to confusion. Consider adding a comment explaining why both configurations are needed or if this duplication is intentional.
Suggested change (keeps the instance itself unchanged and adds an explanatory comment):

```cpp
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
// Note: this variant is intentionally similar to the previous line, but overrides
// the final template parameters of DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// (here: F16, F16, 2) to select a distinct kernel implementation.
```
Proposed changes
Adds support for merging groups in the bwd weight V3 kernel. Four such kernel instances are also added. For shapes that result in very skinny GEMMs, this leads to a performance improvement. Examples of such shapes and their uplift on MI350X are shown below.
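As a rough, self-contained illustration of why such shapes produce skinny GEMMs (this uses the common im2col GEMM mapping for bwd-weight with invented sizes; it is not the PR's benchmark data, which is not reproduced here):

```cpp
// Illustrative only: a common GEMM mapping for grouped conv bwd-weight with
// made-up sizes; the kernel's actual tiling is more involved.
#include <cstdio>

int main()
{
    const int G = 64;                       // many groups...
    const int K = 16, C = 16, Y = 3, X = 3; // ...each with few channels
    const int N = 8, Ho = 28, Wo = 28;

    // Per-group weight-gradient GEMM under the usual im2col view:
    //   M = K, N_gemm = C*Y*X, K_gemm = N*Ho*Wo
    const int M = K, N_gemm = C * Y * X, K_gemm = N * Ho * Wo;
    std::printf("per-group GEMM: M=%d N=%d K=%d, dispatched %d times\n",
                M, N_gemm, K_gemm, G);

    // M=16 underfills even a 32-wide tile, so each GEMM is "skinny".
    // With a merge factor of 4, one launch covers four groups, cutting the
    // number of skinny dispatches by 4x (how the merged dimensions map onto
    // the tile is kernel-specific and not modeled here).
    const int NumGroupsToMerge = 4;
    std::printf("merged dispatches: %d\n", G / NumGroupsToMerge);
}
```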
Checklist

Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have run `clang-format` on all changed files

Discussion