Grouped conv bwd weight with grouped gemm #2304
Conversation
...tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
group_id = index_t((left + right) / 2);
}

if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
This will likely result in scratch... It would be better to keep HasMainKBlockLoop as a template parameter. Can we assume that all groups have the same HasMainKBlockLoop value?
HasMainKBlockLoop can be different for each GEMM, so we can't assume that. I don't see any scratch with this implementation.
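The lookup discussed above can be sketched on the host side. This is a minimal illustration, not the kernel code: `GemmKernelArg`, `block_start_`, and `block_end_` are hypothetical stand-ins, assuming each grouped-GEMM group records its contiguous workgroup range and its own runtime `has_main_k_block_loop_` flag (since the flag may differ between groups, it cannot remain a template parameter).

```cpp
#include <cassert>
#include <vector>

// Hypothetical per-group argument record: each group owns a contiguous
// range of workgroups and carries its own main-K-loop flag at runtime.
struct GemmKernelArg
{
    int block_start_;            // first workgroup id owned by this group
    int block_end_;              // one past the last owned workgroup id
    bool has_main_k_block_loop_; // may differ per group
};

// Binary search over the sorted, non-overlapping block ranges to find
// which group a given workgroup id belongs to (mirrors the in-kernel loop).
inline int FindGroupId(const std::vector<GemmKernelArg>& args, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(args.size()) - 1;
    int group_id = (left + right) / 2;
    while(!(block_id >= args[group_id].block_start_ &&
            block_id < args[group_id].block_end_))
    {
        if(block_id < args[group_id].block_start_)
            right = group_id - 1;
        else
            left = group_id + 1;
        group_id = (left + right) / 2;
    }
    return group_id;
}
```

Each workgroup would then branch on `args[group_id].has_main_k_block_loop_` at runtime instead of instantiating one kernel per flag value.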
this->conv_params.push_back(
    {2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
    {2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
    {2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
    {2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1}});
How many groups would each case test here? Do you also cover the case with just one group? And a case that exceeds MaxKernelArgsNum?
It covers 4, 9 (but with ZTilde * YTilde * XTilde = 36), 36, and 32 groups, so each case. I will add a description.
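The group counts quoted in the reply can be reproduced from the tilde decomposition used for backward-data convolution: per spatial dimension, `Tilde = Stride / gcd(Stride, Dilation)`, and the raw number of decomposed GEMMs is the product of the tilde values (some of the decomposed GEMMs can be empty when the filter is smaller than the stride, which is why one case runs 9 non-empty GEMMs out of 36). A minimal sketch, assuming the strides and dilations are taken from the test parameters above:

```cpp
#include <cassert>
#include <numeric> // std::gcd (C++17)

// Raw number of GEMM groups produced by the backward-data tilde
// decomposition of a 2D convolution: product over spatial dimensions
// of Stride / gcd(Stride, Dilation).
inline int NumGemmGroups2D(int stride_y, int stride_x, int dilation_y, int dilation_x)
{
    const int y_tilde = stride_y / std::gcd(stride_y, dilation_y);
    const int x_tilde = stride_x / std::gcd(stride_x, dilation_x);
    return y_tilde * x_tilde;
}
```

For the four test cases above (dilation {1, 1} throughout), strides {2, 2}, {6, 6}, {6, 6}, and {4, 8} give 4, 36, 36, and 32 decomposed GEMMs respectively.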
aosewski
left a comment
Lovely :D
Proposed changes
Replace the multi-kernel launch with a grouped GEMM to avoid launching many kernels on a single stream, which can leave a large fraction of workgroups idle.
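The packing step behind this change can be illustrated with a small host-side sketch (a toy model, not the PR's implementation): instead of one kernel launch per group, every group's workgroups are concatenated into a single grid, and each group records its `[block_start, block_end)` range for the in-kernel lookup.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Toy model of the grouped launch: pack each group's workgroup count into
// one flat grid and record the half-open block range owned by each group.
// A single kernel launch with grid size equal to the final offset then
// covers all groups, instead of one serialized launch per group.
std::vector<std::pair<int, int>> PackGroups(const std::vector<int>& blocks_per_group)
{
    std::vector<std::pair<int, int>> ranges;
    int offset = 0;
    for(int b : blocks_per_group)
    {
        ranges.emplace_back(offset, offset + b); // [block_start, block_end)
        offset += b;
    }
    return ranges;
}
```

The grid size of the single launch is `ranges.back().second`, so small groups share one dispatch rather than each occupying the stream with a mostly idle launch.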
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] Ran clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered