[CK_TILE] Remove scratch usage from universal gemm #2001
| return ave_time; | ||
| }; | ||
|
|
||
| const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { |
Can we refactor this to use the previous "run" method? We could create two epilogues:
- GemmEpilogue
- GemmEpilogueSplitK
Then create GemmKernel and GemmKernelSplitK and launch the appropriate one.
You would still have the if k_batch == 1 ... else logic inside. For me it's OK to do this in two stages. I'd rather make an alias template, using GemmEpilogue = ck_tile::CShuffleEpilogue<...>, which you would parameterize with TransposeC (because it is known only in the lambda, from UniversalGEmmProblem) and with memory_operation. All other types are known earlier.
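A minimal sketch of the alias-template idea described above, not the actual ck_tile code. `CShuffleEpilogueSketch`, its parameter list, and the `float` data types are hypothetical stand-ins; the real `ck_tile::CShuffleEpilogue` takes a different set of parameters. The point is only that the types known early are bound once, and the two late-bound parameters (TransposeC and the memory operation) stay open.

```cpp
#include <type_traits>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class memory_operation_enum { set, atomic_add };

// Hypothetical stand-in for the epilogue; the real class has more parameters.
template <typename ADataType, typename CDataType, bool TransposeC, memory_operation_enum MemOp>
struct CShuffleEpilogueSketch
{
    static constexpr bool kTransposeC                       = TransposeC;
    static constexpr memory_operation_enum kMemoryOperation = MemOp;
};

// Types that are already known before the dispatch lambda runs (assumed here).
using ADataType = float;
using CDataType = float;

// Alias template: everything known early is fixed, only the late-bound
// parameters remain to be supplied inside the lambda.
template <bool TransposeC, memory_operation_enum MemOp>
using GemmEpilogue = CShuffleEpilogueSketch<ADataType, CDataType, TransposeC, MemOp>;

// Inside the lambda one would then pick the variant that is needed.
using EpilogueSet    = GemmEpilogue<false, memory_operation_enum::set>;
using EpilogueAtomic = GemmEpilogue<false, memory_operation_enum::atomic_add>;

static_assert(EpilogueSet::kMemoryOperation == memory_operation_enum::set,
              "k_batch == 1 variant overwrites C");
static_assert(!std::is_same<EpilogueSet, EpilogueAtomic>::value,
              "the two variants are distinct types");
```

With this shape the lambda body only has to name `GemmEpilogue<TransposeC, MemOp>` instead of repeating the full parameter list for each memory operation.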
| { | ||
| Run_with_k_batch(has_hot_loop_, | ||
| tail_number_, | ||
| ck_tile::integral_constant<ck_tile::memory_operation_enum, |
The same for each operator
| static constexpr index_t kMPerIteration = kMPerXdl * kMWave; | ||
| static constexpr index_t kNPerIteration = kNPerXdl * kNWave; | ||
| using CLayout = remove_cvref_t<typename Problem::CLayout>; | ||
| static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; |
If we put this information in a class member, you no longer need to have it as an operator() template parameter.
I still see this as a tparam here:
https://github.com/ROCm/composable_kernel/pull/2001/files#diff-a2466bfef61d871813a8a210b406e7db886cc294935562fe0076017c0e7f75aaR127
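To illustrate the point of this thread, here is a small sketch, assuming a simplified epilogue that writes a single float. `ProblemSet`, `ProblemAtomic`, `EpilogueSketch`, and `store_with` are hypothetical names, and the plain `+=` only stands in for an atomic add. Because the memory operation is a class member taken from `Problem`, `operator()` needs no template parameter for it.

```cpp
// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class memory_operation_enum { set, atomic_add };

// Hypothetical problem descriptors carrying the memory operation.
struct ProblemSet
{
    static constexpr memory_operation_enum MemoryOperation = memory_operation_enum::set;
};
struct ProblemAtomic
{
    static constexpr memory_operation_enum MemoryOperation = memory_operation_enum::atomic_add;
};

template <typename Problem>
struct EpilogueSketch
{
    // The memory operation lives on the class, as in the diff line above.
    static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;

    // No memory-operation template parameter on operator().
    void operator()(float& c, float acc) const
    {
        if constexpr(MemoryOperation == memory_operation_enum::set)
        {
            c = acc; // overwrite the output
        }
        else
        {
            c += acc; // stands in for an atomic add of a partial result
        }
    }
};

// Small helper so the behaviour can be checked as a single expression.
template <typename Problem>
float store_with(float c, float acc)
{
    EpilogueSketch<Problem>{}(c, acc);
    return c;
}
```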
952699e to 1bec2e5
aosewski
left a comment
Almost there ;)
Proposed changes
Remove scratch usage from universal gemm by moving the kbatch-related if condition outside of the kernel and passing the memory operation enum as a template parameter.
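The change described above can be sketched roughly as follows. This is a host-only illustration under simplifying assumptions, not the PR's code: `run_gemm_sketch`, `store_partial`, and `launch` are hypothetical names, each split-K slice contributes a partial result of 1.0f, and a plain `+=` stands in for the atomic add. The runtime check on k_batch happens once, before the "kernel" is chosen, and the memory operation reaches it as a template parameter.

```cpp
#include <cassert>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class memory_operation_enum { set, atomic_add };

// The store path is selected at compile time; there is no runtime
// k_batch branch inside this function.
template <memory_operation_enum MemOp>
void store_partial(float& c, float partial)
{
    if constexpr(MemOp == memory_operation_enum::set)
    {
        c = partial;
    }
    else
    {
        c += partial; // stands in for an atomic add
    }
}

// Stand-in for the kernel: each of the k_batch slices writes its partial sum.
template <memory_operation_enum MemOp>
float run_gemm_sketch(int k_batch)
{
    float c = 0.0f; // the output starts zeroed so accumulation is well defined
    for(int slice = 0; slice < k_batch; ++slice)
    {
        store_partial<MemOp>(c, 1.0f);
    }
    return c;
}

// The k_batch condition sits outside the kernel and only picks the instantiation.
float launch(int k_batch)
{
    assert(k_batch >= 1);
    if(k_batch == 1)
    {
        return run_gemm_sketch<memory_operation_enum::set>(k_batch);
    }
    return run_gemm_sketch<memory_operation_enum::atomic_add>(k_batch);
}
```

Here `launch(1)` uses the overwrite variant and `launch(4)` uses the accumulating variant, so the four partial results are summed.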
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files