Rename unroll_n_to_subgroups to subgroups_n (#19102)
"Unroll" usually means "generate more instructions", so the terminology
being changed here, `unroll_n_to_subgroups`, created confusion.

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Nov 13, 2024
1 parent e10231c commit cb5d1ab
Showing 16 changed files with 48 additions and 49 deletions.
@@ -184,10 +184,10 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
//
// That does simplify the below adjustments for narrow M/N, as we don't need
// to think about unroll-to-subgroups when making the narrowing adjustment.
- int unrollMToSubgroups = 1;
- int unrollNToSubgroups = *wgp.getSimdsPerWgp();
- int unrollM = totalUnrollM / unrollMToSubgroups;
- int unrollN = totalUnrollN / unrollNToSubgroups;
+ int subgroupsM = 1;
+ int subgroupsN = *wgp.getSimdsPerWgp();
+ int unrollM = totalUnrollM / subgroupsM;
+ int unrollN = totalUnrollN / subgroupsN;

//
// Step 3: Adjust the unrolling factors when there is a narrow dimension.
@@ -201,15 +201,14 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
}
if (narrowDim.isN()) {
std::swap(unrollM, unrollN);
- std::swap(unrollMToSubgroups, unrollNToSubgroups);
- assert(unrollNToSubgroups == 1);
+ std::swap(subgroupsM, subgroupsN);
+ assert(subgroupsN == 1);
unrollN = std::min(unrollN, static_cast<int>(llvm::divideCeil(
narrowDim.size, intrinsicMma.getNSize())));
}

return DataTiledMMAAttr::get(ctx, intrinsicMma.getIntrinsic(), unrollM,
- unrollMToSubgroups, unrollN, unrollNToSubgroups,
- unrollK);
+ subgroupsM, unrollN, subgroupsN, unrollK);
}

static FailureOr<MaterializeEncodingInfo>
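(For context, a standalone sketch of the factor split that this hunk renames: the total unrolling along N is divided into an on-thread factor, unroll_n, and a cross-subgroup factor, now called subgroups_n. The concrete numbers below are illustrative assumptions, chosen to line up with the unroll_m = 8, unroll_n = 2, subgroups_n = 4 layouts that appear in the test expectations further down.)

#include <cassert>
#include <cstdio>

// Standalone illustration of the renamed split in chooseDataTiledMMAAttr: the
// total N unrolling is divided into an on-thread factor (unrollN) and a
// cross-subgroup factor (subgroupsN, formerly unrollNToSubgroups).
int main() {
  // Assumed inputs; in the real code simdsPerWgp comes from the target's
  // wgp.getSimdsPerWgp() and the totals from earlier sizing steps.
  int simdsPerWgp = 4;
  int totalUnrollM = 8;
  int totalUnrollN = 8;

  int subgroupsM = 1;
  int subgroupsN = simdsPerWgp;
  int unrollM = totalUnrollM / subgroupsM;  // 8 / 1 = 8, all on one thread
  int unrollN = totalUnrollN / subgroupsN;  // 8 / 4 = 2 on-thread, rest across subgroups
  assert(subgroupsM == 1);

  std::printf("unroll_m=%d subgroups_m=%d unroll_n=%d subgroups_n=%d\n",
              unrollM, subgroupsM, unrollN, subgroupsN);
  return 0;
}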
@@ -56,5 +56,5 @@ func.func @matmul_lowering_WMMA_F32_16x16x16_F16() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = WMMA_F32_16x16x16_F16, unroll_m = 4, unroll_n_to_subgroups = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = WMMA_F32_16x16x16_F16, unroll_m = 4, subgroups_n = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
@@ -56,5 +56,5 @@ func.func @matmul_lowering_MFMA_i32_16x16x16_i8() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x16_I8, unroll_m = 4, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x16_I8, unroll_m = 4, unroll_n = 2, subgroups_n = 4, unroll_k = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
@@ -56,7 +56,7 @@ func.func @matmul_lowering_MFMA_f32_16x16x8_bf16() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x8_BF16, unroll_m = 4, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x8_BF16, unroll_m = 4, unroll_n = 2, subgroups_n = 4, unroll_k = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----
@@ -115,5 +115,5 @@ func.func @matmul_lowering_MFMA_f64_16x16x4_f64() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F64_16x16x4_F64, unroll_m = 4, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F64_16x16x4_F64, unroll_m = 4, subgroups_n = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
@@ -366,7 +366,7 @@ func.func @matmul_lowering_MFMA_F32_16x16x4_F32() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----
@@ -426,7 +426,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----
@@ -622,7 +622,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----
@@ -700,7 +700,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits
// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, subgroups_n = 4>

// -----

@@ -773,11 +773,11 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits
// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 4>

// -----

- // Custom {simds_per_wgp = 1} => implied default {unroll_n_to_subgroups = 1} (omitted in output) and {unroll_n = 8} instead of {unroll_n_to_subgroups = 4}.
+ // Custom {simds_per_wgp = 1} => implied default {subgroups_n = 1} (omitted in output) and {unroll_n = 8} instead of {subgroups_n = 4}.

#target_gfx942_except_simds_per_wgp_1 = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
iree.gpu.target = #iree_gpu.target<
@@ -919,7 +919,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192() at
// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 4, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 4, unroll_n = 2, subgroups_n = 4, unroll_k = 2>

// -----

@@ -992,7 +992,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096() at
// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 4, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 4, subgroups_n = 4, unroll_k = 2>

// -----

@@ -1065,7 +1065,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768() a
// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 4, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 4, subgroups_n = 4, unroll_k = 2>

// -----

@@ -1128,7 +1128,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----
@@ -1188,5 +1188,5 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_BF16, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_BF16, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
@@ -152,8 +152,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
if (mma.getUnrollM() > 1) {
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
- if (mma.getUnrollMToSubgroups() > 1) {
- expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
+ if (mma.getSubgroupsM() > 1) {
+ expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsM()});
}
break;
case IREE::GPU::MMAFragment::Rhs:
@@ -169,8 +169,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
if (mma.getUnrollN() > 1) {
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
- if (mma.getUnrollNToSubgroups() > 1) {
- expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
+ if (mma.getSubgroupsN() > 1) {
+ expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsN()});
}
break;
case IREE::GPU::MMAFragment::Acc:
@@ -179,14 +179,14 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
if (mma.getUnrollN() > 1) {
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
- if (mma.getUnrollNToSubgroups() > 1) {
- expand(swizzle, 1, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
+ if (mma.getSubgroupsN() > 1) {
+ expand(swizzle, 1, {Kind::CrossThread, mma.getSubgroupsN()});
}
if (mma.getUnrollM() > 1) {
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
- if (mma.getUnrollMToSubgroups() > 1) {
- expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
+ if (mma.getSubgroupsM() > 1) {
+ expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsM()});
}
break;
}
@@ -939,8 +939,8 @@ std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
- return {opaqueLayout.mSize * getUnrollM() * getUnrollMToSubgroups(),
- opaqueLayout.nSize * getUnrollN() * getUnrollNToSubgroups(),
+ return {opaqueLayout.mSize * getUnrollM() * getSubgroupsM(),
+ opaqueLayout.nSize * getUnrollN() * getSubgroupsN(),
opaqueLayout.kSize * getUnrollK()};
}

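(As a quick check of the renamed getters, a standalone sketch of the shape computation above, using the 16x16x32 intrinsic tile and the unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2 factors from the MFMA_I32_16x16x32_I8 tests; treat the numbers as illustrative.)

#include <cstdint>
#include <cstdio>

// Mirrors DataTiledMMAAttr::getMNKShape: each intrinsic dimension is scaled by
// its on-thread unroll factor and, for M and N, by the cross-subgroup factor.
int main() {
  int64_t mSize = 16, nSize = 16, kSize = 32;  // MFMA_I32_16x16x32_I8 tile
  int64_t unrollM = 8, subgroupsM = 1;
  int64_t unrollN = 2, subgroupsN = 4;
  int64_t unrollK = 2;

  long long M = mSize * unrollM * subgroupsM;  // 16 * 8 * 1 = 128
  long long N = nSize * unrollN * subgroupsN;  // 16 * 2 * 4 = 128
  long long K = kSize * unrollK;               // 32 * 2 = 64
  std::printf("M=%lld N=%lld K=%lld\n", M, N, K);
  return 0;
}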
@@ -279,9 +279,9 @@ def IREEGPU_DataTiledMMAAttr :
let parameters = (ins
"::mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr":$intrinsic,
DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, on the same thread.">:$unroll_m,
- DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, distributed across this many more threads.">:$unroll_m_to_subgroups,
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, distributed across this many more threads.">:$subgroups_m,
DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, on the same thread.">:$unroll_n,
- DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, distributed across this many more threads.">:$unroll_n_to_subgroups,
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, distributed across this many more threads.">:$subgroups_n,
DefaultValuedParameter<"int64_t", "1", "Unrolling along the K dimension, on the same thread, with interleaved layout.">:$unroll_k
);
}
@@ -29,21 +29,21 @@ module {

module {
func.func @test_data_tiled_mfma_f32_16x16x4_f32() attributes {
- mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_subgroups = 2, unroll_k = 1>} {
+ mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, subgroups_m = 2, unroll_k = 1>} {
return
}
}
// CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32
- // CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_subgroups = 2>
+ // CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, subgroups_m = 2>

module {
func.func @test_data_tiled_mfma_f32_16x16x16_f16() attributes {
- mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n_to_subgroups = 2, unroll_k = 2>} {
+ mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, subgroups_n = 2, unroll_k = 2>} {
return
}
}
// CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x16_f16
- // CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_n_to_subgroups = 2, unroll_k = 2>
+ // CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, subgroups_n = 2, unroll_k = 2>

module {
func.func @test_data_tiled_mfma_i32_16x16x32_i8() attributes {
@@ -281,7 +281,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma(%lhs: tensor<?x?x2x4x16x1x4xf32>, %
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
- kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m_to_subgroups = 2, unroll_n_to_subgroups = 2, unroll_k = 4>
+ kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, unroll_k = 4>
} : tensor<?x?x2x4x16x1x4xf32>, tensor<?x?x2x4x16x1x4xf32> into tensor<?x?x2x2x4x16x4x1xf32>
return %0 : tensor<?x?x2x2x4x16x4x1xf32>
}
@@ -294,7 +294,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma(%lhs: tensor<?x?x2x4x16x1x4xf32>, %
// CHECK: iree_gpu.multi_mma %arg0, %arg1, %arg2
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
- // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m_to_subgroups = 2, unroll_n_to_subgroups = 2, unroll_k = 4>
+ // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, unroll_k = 4>
// CHECK-SAME: : tensor<?x?x2x4x16x1x4xf32>, tensor<?x?x2x4x16x1x4xf32> into tensor<?x?x2x2x4x16x4x1xf32>


@@ -53,8 +53,8 @@ setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target,
// single subgroup.
const int64_t targetSubgroupSize = dataTiledMmaAttr.getSubgroupSize();
int64_t flatWorkgroupSize = targetSubgroupSize *
- dataTiledMmaAttr.getUnrollMToSubgroups() *
- dataTiledMmaAttr.getUnrollNToSubgroups();
+ dataTiledMmaAttr.getSubgroupsM() *
+ dataTiledMmaAttr.getSubgroupsN();
std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};

// Set all workgroup and reduction tile sizes to 1, since the data tiled
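(A minimal sketch of the workgroup-size computation in the hunk above: one subgroup is launched per subgroups_m x subgroups_n tile, flattened into the x dimension. The 64-wide subgroup is an assumption here; the real code queries it from the data-tiled MMA attribute.)

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t targetSubgroupSize = 64;         // assumed subgroup width
  int64_t subgroupsM = 1, subgroupsN = 4;  // factors seen in the tests above

  int64_t flatWorkgroupSize = targetSubgroupSize * subgroupsM * subgroupsN;
  std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};
  std::printf("workgroup size = {%lld, %lld, %lld}\n",
              (long long)workgroupSize[0], (long long)workgroupSize[1],
              (long long)workgroupSize[2]);  // {256, 1, 1} with these values
  return 0;
}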
@@ -625,8 +625,8 @@ distributeMultiMmaOp(RewriterBase &rewriter, IREE::GPU::MultiMmaOp mmaOp,
if (auto dataTiledMma = dyn_cast<DataTiledMMAAttr>(newKind)) {
newKind = DataTiledMMAAttr::get(
context, dataTiledMma.getIntrinsic(), dataTiledMma.getUnrollM(),
- /*unroll_m_to_subgroups=*/1, dataTiledMma.getUnrollN(),
- /*unroll_n_to_subgroups=*/1, dataTiledMma.getUnrollK());
+ /*subgroups_m=*/1, dataTiledMma.getUnrollN(),
+ /*subgroups_n=*/1, dataTiledMma.getUnrollK());
}
auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
loc, lhsSlice, rhsSlice, accSlice, mmaOp.getIndexingMaps(),

