diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index 489b8a0cb670..db39c0b15742 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -590,3 +590,96 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
 // CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
 // CHECK: return {{.*}} %[[R_SIMD]]
+
+// -----
+
+// Non-native MFMA_F32_32x32x16_F16, i.e. CDNA3 V_MFMA_F32_32x32x8_F16 with unrolled_k = 2.
+// This non-native layout maximizes the width of reads from shared memory to registers.
+
+#map1 = affine_map<(m, n, k) -> (m, k)>
+#map2 = affine_map<(m, n, k) -> (k, n)>
+#map3 = affine_map<(m, n, k) -> (m, n)>
+
+// A: shape = 32x16, layout = layoutA
+#layout_a = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [32, 2],
+  element_tile = [1, 8],
+
+  subgroup_strides = [1, 1],
+  thread_strides = [1, 32]
+>
+
+// B: shape = 16x32, layout = layoutB
+#layout_b = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [2, 32],
+  element_tile = [8, 1],
+
+  subgroup_strides = [1, 1],
+  thread_strides = [32, 1]
+>
+
+// C: shape = 32x32, layout = layoutC
+#layout_c = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [4, 1],
+  thread_tile = [2, 32],
+  element_tile = [4, 1],
+
+  subgroup_strides = [1, 1],
+  thread_strides = [32, 1]
+>
+
+func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
+  %A = iree_vector_ext.to_layout %a to layout(#layout_a) : vector<32x16xf16>
+  %B = iree_vector_ext.to_layout %b to layout(#layout_b) : vector<16x32xf16>
+  %C = iree_vector_ext.to_layout %c to layout(#layout_c) : vector<32x32xf32>
+
+  %output = vector.contract {
+    indexing_maps = [#map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>,
+    iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
+  } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>
+
+  %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>
+  return %O : vector<32x32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// Notable things to look out for (see the per-lane example below):
+// 1. We read 8xf16 instead of 4xf16 for the lhs and rhs operands.
+// 2. Each 8xf16 read is sliced into two 4xf16 vectors per operand, one for each of the 2 MFMAs.
+// 3. The result of the first MFMA becomes the second MFMA's accumulator.
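+//
+// Worked example (per lane): with element_tile = [1, 8] on the LHS, each lane
+// holds one 8xf16 register; elements [0..3] feed the first native 32x32x8 MFMA
+// and elements [4..7] feed the second, which takes the first MFMA's result as
+// its accumulator. One virtual 32x32x16 MFMA therefore lowers to two chained
+// amdgpu.mfma ops.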
+
+// CHECK-LABEL: func @contract_to_vmfma_32x32x16_mm
+// CHECK: %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK: %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_1]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
+// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x4x1x4x1xf32> -> vector<32x32xf32>
+// CHECK: return {{.*}} %[[R_SIMD]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 93a2ca762b51..e53f915434aa 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -263,6 +263,14 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
   case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
     return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
   }
+  // VMFMA (virtual MFMA) instructions, i.e. 2 native MFMA instructions
+  // interleaved along the K dimension.
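+  // Their opaque shapes below therefore use the doubled K of the virtual
+  // intrinsic (32 instead of the native 16, and 16 instead of the native 8).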
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16: {
+    return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
+  }
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
+    return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
+  }
   }
   llvm_unreachable("unhandled mfma layout type");
   return OpaqueMmaLayout{};
@@ -412,12 +420,14 @@ MMAAttr::getABCVectorTypes() const {
   }
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
     auto aType = VectorType::get({8}, getAType());
     auto bType = VectorType::get({8}, getBType());
     auto cType = VectorType::get({4}, getCType());
     return std::make_tuple(aType, bType, cType);
   }
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
     auto aType = VectorType::get({8}, getAType());
     auto bType = VectorType::get({8}, getBType());
@@ -461,7 +471,9 @@ int64_t MMAAttr::getBlockSize() const {
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16:
   case MMAIntrinsic::WMMA_F32_16x16x16_F16:
@@ -484,7 +496,9 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
     return 64;
   }
@@ -549,6 +563,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
     return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
             /*element=*/{4, 1}};
   }
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
@@ -563,6 +578,7 @@
     return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
             /*element=*/{4, 1}};
   }
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8:
     switch (fragment) {
     case MMAFragment::Lhs:
@@ -616,6 +632,19 @@ MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
   return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
 }

+// Returns the virtual intrinsics that are based on the queried op's intrinsic.
+SmallVector<MMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
+  switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+    return {MMAIntrinsic::VMFMA_F32_16x16x32_F16};
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+    return {MMAIntrinsic::VMFMA_F32_32x32x16_F16};
+  default:
+    return {};
+  }
+  return {};
+}
+
 // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
 // type.
 FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -643,6 +672,37 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
                                rhs, acc)
         .getResult();
   }
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
+    // Generate the MFMAs by unrolling along the K dimension.
+    const int64_t unrollKFactor = 2;
+    auto [m, n, k] = getMNKShape();
+    // Compute the actual/native intrinsic's K size.
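+    // E.g. for VMFMA_F32_32x32x16_F16, k = 16 and unrollKFactor = 2, so
+    // nativeKSize = 8, matching the underlying MFMA_F32_32x32x8_F16.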
+    int64_t nativeKSize = k / unrollKFactor;
+
+    auto [aType, bType, cType] = getABCVectorTypes();
+    if (aType.getShape()[0] != bType.getShape()[0]) {
+      // Currently we only support the case where lhs and rhs have the same
+      // vector width.
+      return failure();
+    }
+    int64_t vectorWidth = aType.getShape()[0] / unrollKFactor;
+    for (int i = 0; i < unrollKFactor; i++) {
+      int64_t offset = vectorWidth * i;
+      Value slicedLhs = builder.create<vector::ExtractStridedSliceOp>(
+          loc, lhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
+          ArrayRef<int64_t>{1});
+      Value slicedRhs = builder.create<vector::ExtractStridedSliceOp>(
+          loc, rhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
+          ArrayRef<int64_t>{1});
+      acc = builder
+                .create<amdgpu::MFMAOp>(loc, resultType, m, n, nativeKSize,
+                                        getBlockSize(), slicedLhs, slicedRhs,
+                                        acc)
+                .getResult();
+    }
+    return acc;
+  }
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
   case MMAIntrinsic::MFMA_F32_16x16x16_F16:
   case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index d04e9fefe5b9..bbb79628e1d3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -216,6 +216,8 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> {
     MMASingleSubgroupLayout getASingleSubgroupLayout() const;
     MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
     MMASingleSubgroupLayout getCSingleSubgroupLayout() const;
+
+    SmallVector<MMAIntrinsic> getVirtualIntrinsics() const;
  }];
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 9d4ac2e9a4e1..1afdf0d235be 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -98,7 +98,13 @@ class IREEGPU_I32MmaEnumAttr
   let genSpecializedAttr = 0;
 }

-// Format: <kind>_<output-type>_<M>x<N>x<K>_<input-type>
+// Format: <virtual><kind>_<output-type>_<M>x<N>x<K>_<input-type>
+//
+// "virtual": Prefixing an intrinsic with "V" marks a non-native MFMA that
+//            emulates a larger MMA with smaller native ones. This is useful
+//            for interleaving reads along the K dimension, so that we can
+//            issue wider reads or align layouts between matmuls.
+//
 // Values: 0xABCD where:
 // * A = vendor:
 //   * 0 = AMD
@@ -121,6 +127,8 @@ class IREEGPU_I32MmaEnumAttr
 def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
 def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
 def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
+def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0x0912>;
+def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 0x0913>;
 def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
 def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
 def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
@@ -145,6 +153,8 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
     MFMA_F32_16x16x4_F32,
     MFMA_F32_16x16x16_F16,
     MFMA_F32_32x32x8_F16,
+    VMFMA_F32_16x16x32_F16,
+    VMFMA_F32_32x32x16_F16,
     MFMA_F32_16x16x16_BF16,
     MFMA_F32_32x32x8_BF16,
     MFMA_F32_16x16x32_F8E4M3FNUZ,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ede2d0bcf7b8..b4567e32938d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -309,15 +309,32 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType, rhsElemType, initElemType};

+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
   SmallVector<GPUMatmulShapeType> intrinsics;
   intrinsics.reserve(target.getWgp().getMma().size());
+  SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+  MLIRContext *context = op.getContext();
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Also store info on any virtual intrinsics based on the current mma.
+    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+         mma.getVirtualIntrinsics()) {
+      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+    }
   }
+
   if (intrinsics.empty())
     return failure();
@@ -379,7 +396,6 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
     reductionTileSizes[filterDim] = 1;
   }

-  MLIRContext *context = op.getContext();
   Builder b(context);
   SmallVector<NamedAttribute> attrs;
   attrs.emplace_back(StringAttr::get(context, "workgroup"),
@@ -395,8 +411,8 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   // for later access in the pipeline.
   SmallVector<NamedAttribute> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index],
-      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+      context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+      schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
@@ -506,15 +522,32 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType, rhsElemType, initElemType};

+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
   SmallVector<GPUMatmulShapeType> intrinsics;
   intrinsics.reserve(target.getWgp().getMma().size());
+  SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+  MLIRContext *context = op.getContext();
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Also store info on any virtual intrinsics based on the current mma.
+    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+         mma.getVirtualIntrinsics()) {
+      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+    }
   }
+
   if (intrinsics.empty())
     return failure();
@@ -627,7 +660,6 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   LLVM_DEBUG(debugPrintContractionInfo("Reduction tile sizes", op.getNumLoops(),
                                        *contractionDims, reductionTileSizes));

-  MLIRContext *context = op.getContext();
   Builder b(context);
   SmallVector<NamedAttribute> attrs;
   attrs.emplace_back(StringAttr::get(context, "workgroup"),
@@ -643,8 +675,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   // for later access in the pipeline.
   SmallVector<NamedAttribute> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index],
-      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+      context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+      schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
@@ -709,15 +741,32 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   Value kMatrix = op.getKey();
   Value vMatrix = op.getValue();

+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
   SmallVector<GPUMatmulShapeType> intrinsics;
   intrinsics.reserve(target.getWgp().getMma().size());
+  SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+  MLIRContext *context = op.getContext();
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Also store info on any virtual intrinsics based on the current mma.
+    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+         mma.getVirtualIntrinsics()) {
+      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+    }
   }
+
   if (intrinsics.empty())
     return failure();
@@ -826,7 +875,6 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize;

-  MLIRContext *context = op.getContext();
   SmallVector<NamedAttribute> attrs;
   attrs.emplace_back(StringAttr::get(context, "workgroup"),
                      b.getI64ArrayAttr(workgroupTileSizes));
@@ -878,8 +926,8 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   // for later access in the pipeline.
   SmallVector<NamedAttribute> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index],
-      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+      context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+      schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 7e1ab62101b3..cedec2d21f2f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -655,6 +655,69 @@ hal.executable public @contract_schedule_considering_read_layout {

 // -----

+// This test ensures that we can generate and decompose the right instructions from virtual MFMAs (VMFMAs).
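+// With the VMFMA_F32_32x32x16_F16 schedule used here, each virtual MFMA is
+// expected to lower to a single 8xf16 read per operand followed by two chained
+// 32x32x8 amdgpu.mfma ops (see the CHECK lines further down).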
+
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
+#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @virtual_intrinsic_256x256x256_f16_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @virtual_intrinsic_256x256x256_f16_f32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @virtual_intrinsic_256x256x256_f16_f32() attributes {translation_info = #translation} {
+      %cst = arith.constant 0.000000e+00 : f32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %5 = tensor.empty() : tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f16_f32
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x4x1x4x1xf32>)
+
+// Validate that the VMFMA is decomposed into a coalesced read and 2 MFMAs:
+
+// CHECK: %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK: %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+
+// Ensure the right number of instructions is being generated.
+
+// CHECK-COUNT-14: vector.extract_strided_slice
+// CHECK-NEXT: amdgpu.mfma
+// CHECK: scf.yield
+
+// -----
+
 #config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
 #translation = #iree_codegen.translation_info, subgroup_m_count = 2, subgroup_n_count = 1>}>
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index cd6f8ebea6d3..dd387f31141f 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -347,6 +347,10 @@ def get_rocm_test_compilation_infos(
             MMASchedule("MFMA_I32_32x32x16_I8", 2, 2, 1, 1, 2),
             MMASchedule("MFMA_I32_32x32x16_I8", 4, 1, 1, 2, 2),
             MMASchedule("MFMA_I32_32x32x16_I8", 4, 2, 2, 2, 2),
+            MMASchedule("VMFMA_F32_16x16x32_F16", 1, 1, 1, 1, 1),
+            MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
+            MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
+            MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
         ]
     elif intrinsic == "WMMA":
         schedules = [
@@ -393,13 +397,17 @@ def get_rocm_test_compilation_infos(
         wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
         wg_tile_k = schedule.k_tile_count * 8
     elif (
-        schedule.intrinsic == "MFMA_I32_16x16x32_I8"
+        schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
+        or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
         or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
     ):
         wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
         wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
         wg_tile_k = schedule.k_tile_count * 32
-    elif schedule.intrinsic == "MFMA_I32_32x32x16_I8":
+    elif (
+        schedule.intrinsic == "VMFMA_F32_32x32x16_F16"
+        or schedule.intrinsic == "MFMA_I32_32x32x16_I8"
+    ):
         wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
         wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
         wg_tile_k = schedule.k_tile_count * 16
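The tile-size arithmetic above generalizes as follows. This is an informal, standalone Python sketch for readers of the schedule list; its dataclass, field names, and example values are illustrative assumptions and are not taken from generate_e2e_matmul_tests.py.

from dataclasses import dataclass


# Illustrative stand-in for the generator's schedule description; the field
# names and their order are assumptions made for this sketch only.
@dataclass
class VmfmaSchedule:
    intrinsic: str
    m_count: int
    n_count: int
    m_tile_count: int
    n_tile_count: int
    k_tile_count: int


def workgroup_tile_sizes(s: VmfmaSchedule):
    # VMFMA_F32_16x16x32_F16 tiles M/N by 16 and K by 32;
    # VMFMA_F32_32x32x16_F16 tiles M/N by 32 and K by 16.
    if s.intrinsic == "VMFMA_F32_16x16x32_F16":
        mn, k = 16, 32
    elif s.intrinsic == "VMFMA_F32_32x32x16_F16":
        mn, k = 32, 16
    else:
        raise ValueError(f"unhandled intrinsic: {s.intrinsic}")
    wg_tile_m = s.m_count * s.m_tile_count * mn
    wg_tile_n = s.n_count * s.n_tile_count * mn
    wg_tile_k = s.k_tile_count * k
    return wg_tile_m, wg_tile_n, wg_tile_k


# A minimal single-subgroup schedule yields a 32x32 workgroup tile with K = 16.
print(workgroup_tile_sizes(VmfmaSchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1)))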