-
Notifications
You must be signed in to change notification settings - Fork 269
Congma/ck tile/preshuffle b #3645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
080fa14
dc83e28
109bfa1
bc91bb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -40,20 +40,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy | |||||||||||||||||||||||||||||||
| CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| using TileShape = typename Problem::BlockGemmShape; | ||||||||||||||||||||||||||||||||
| #if defined(__gfx11__) | ||||||||||||||||||||||||||||||||
| constexpr index_t scale = 4; | ||||||||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||||||||
| constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; | ||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||
| if constexpr(TileShape::WarpTile::at(I1) == 32) | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| return TileShape::WarpTile::at(I2) * scale / 2; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| static_assert(TileShape::WarpTile::at(I1) == 16); | ||||||||||||||||||||||||||||||||
| return TileShape::WarpTile::at(I2) * scale / 4; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| constexpr index_t k_b_per_load = | ||||||||||||||||||||||||||||||||
| TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return k_b_per_load; | ||||||||||||||||||||||||||||||||
|
Comment on lines
+44
to
+47
|
||||||||||||||||||||||||||||||||
| constexpr index_t k_b_per_load = | |
| TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); | |
| return k_b_per_load; | |
| constexpr index_t base_k_b_per_load = | |
| TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); | |
| #if defined(__gfx11__) | |
| // On gfx11, MakeBFlatDramTileDistribution() uses KRepeatInWave = 2 and asserts | |
| // TileShape::flatKPerWarp == KThdPerWave * KBPerLoad. To keep this invariant valid, | |
| // fold KRepeatInWave into KBPerLoad here. | |
| return base_k_b_per_load * 2; | |
| #else | |
| return base_k_b_per_load; | |
| #endif |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| #include "ck_tile/ops/gemm.hpp" | ||
| #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" | ||
| #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" | ||
| #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" | ||
|
|
||
|
Comment on lines
11
to
15
|
||
| using AddScale = ck_tile::element_wise::AddScale; | ||
| using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd; | ||
|
|
@@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_) | |
| ck_tile::tensor_layout::gemm::RowMajor>>{}; | ||
| } | ||
|
|
||
| template <typename PrecType, ck_tile::index_t M_Warp_Tile> | ||
| constexpr ck_tile::index_t get_k_warp_tile() | ||
| { | ||
| #if CK_TILE_USE_WMMA | ||
| return 16; | ||
| #else | ||
| #if defined(CK_GFX950_SUPPORT) | ||
| constexpr bool is_8bit_float = | ||
| std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>; | ||
| if constexpr(M_Warp_Tile == 32) | ||
| return is_8bit_float ? 64 : 16; | ||
| else | ||
| return is_8bit_float ? 128 : 32; | ||
| #else | ||
| if constexpr(M_Warp_Tile == 32) | ||
| return 16; | ||
| else | ||
| return 32; | ||
| #endif | ||
| #endif | ||
| } | ||
|
|
||
| template <typename A0DataType, | ||
| typename B0DataType, | ||
| typename AccDataType, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -48,7 +48,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test | |||||
| static const ck_tile::index_t M_Warp_Tile = 16; | ||||||
| static const ck_tile::index_t N_Warp_Tile = 16; | ||||||
| static const ck_tile::index_t K_Warp_Tile = | ||||||
| ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>(); | ||||||
| ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>(); | ||||||
|
||||||
| ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>(); | |
| ck_tile::get_k_warp_tile_for_preshuffle_b<BDataType, N_Warp_Tile>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`get_k_warp_tile_for_preshuffle_b` computes `kKPerWarp` as `int` and then calls `max(kKPerWarp, kMfmaMaxK)`. In this codebase `ck_tile::max(T,T)` requires both arguments to be the same type; mixing `int` and `index_t` will fail to compile (template recursion ends up with no viable `max(int, index_t)` overload). Make the intermediate constants `constexpr index_t` (or cast `kKPerWarp` to `index_t`) so the final `max` call is between identical types.