diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 6620b050f80d..e011b78194b0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -306,8 +306,9 @@ static std::tuple getABCElementTypes(MLIRContext *context, return {}; } -static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, - MMAIntrinsic intrinsic) { +template +static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context, + MMAIntrinsicType intrinsic) { OpaqueMmaLayout o; std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic); auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs); @@ -369,9 +370,9 @@ getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) { PerDimLayoutAttr::get(context, labels[1], shape[1])}; }; -static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, - MMAIntrinsic intrinsic) { - auto opaque = getOpaqueMFMALayout(context, intrinsic); +static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context, + MMAIntrinsic intrinsic) { + auto opaque = getOpaqueMMALayout(context, intrinsic); ConcreteMmaLayout concreteLayout; concreteLayout.base = opaque; auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs); @@ -452,7 +453,7 @@ void MMAAttr::print(AsmPrinter &p) const { } MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) { - auto layout = getOpaqueMFMALayout(context, type); + auto layout = getOpaqueMMALayout(context, type); return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize, layout.nSize, layout.kSize, layout.aType, layout.bType, layout.cType); @@ -466,9 +467,11 @@ std::tuple MMAAttr::getMNKShape() const { return {getMSize(), getNSize(), getKSize()}; } -static VectorType getVectorType(MLIRContext *context, MMAIntrinsic intrinsic, +template +static VectorType getVectorType(MLIRContext *context, + MMAIntrinsicType intrinsic, MMAFragment fragment) { - auto o = getOpaqueMFMALayout(context, intrinsic); + auto o = getOpaqueMMALayout(context, intrinsic); auto s = getSingleSubgroupLayout(intrinsic, fragment); Type elemType = (fragment == MMAFragment::Lhs) ? o.aType : (fragment == MMAFragment::Rhs) ? o.bType @@ -491,7 +494,7 @@ FailureOr> MMAAttr::getContractionLayout(vector::ContractionOp contract) const { ConcreteMmaLayout layout = - getConcreteMFMALayout(contract->getContext(), getIntrinsic().getValue()); + getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue()); return IREE::GPU::getContractionLayout(contract, layout); } @@ -932,13 +935,13 @@ sliceSwizzledShape(const TileSwizzle &swizzle, std::tuple DataTiledMMAAttr::getABCElementTypes() const { MLIRContext *ctx = getContext(); - auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue()); + auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue()); return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType}; } std::tuple DataTiledMMAAttr::getMNKShape() const { MLIRContext *ctx = getContext(); - auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue()); + auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue()); return {opaqueLayout.mSize * getUnrollM() * getSubgroupsM(), opaqueLayout.nSize * getUnrollN() * getSubgroupsN(), opaqueLayout.kSize * getUnrollK()}; @@ -1228,68 +1231,47 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context, return VirtualMMAAttr::get(context, intrinsicAttr); } -static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context, - VirtualMMAIntrinsic type) { +static std::tuple +getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) { Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context); Type f16 = Float16Type::get(context); Type f32 = Float32Type::get(context); switch (type) { - case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { - return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; - } - case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: { - return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32}; - } + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + return {f8E4M3FNUZ, f8E4M3FNUZ, f32}; + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: + return {f8E4M3FNUZ, f8E4M3FNUZ, f32}; // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved // along the k dimension. - case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { - return OpaqueMmaLayout{16, 16, 32, f16, f16, f32}; - } - case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { - return OpaqueMmaLayout{32, 32, 16, f16, f16, f32}; - } + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + return {f16, f16, f32}; + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: + return {f16, f16, f32}; } assert(false && "unhandled virtual mma layout type."); - return OpaqueMmaLayout{}; + return {}; } std::tuple VirtualMMAAttr::getABCElementTypes() const { MLIRContext *ctx = getContext(); - auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue()); + auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue()); return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType}; } std::tuple VirtualMMAAttr::getABCVectorTypes() const { - // Check https://github.com/ROCm/amd_matrix_instruction_calculator for - // instruction details. Note here we are returning the number elements, while - // amd_matrix_instruction_calculator tells us about the number of 32-bit - // registers. So need to adjust accordingly. All vectors should be 1-D. - auto [A, B, C] = getABCElementTypes(); - switch (getIntrinsic().getValue()) { - case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { - auto aType = VectorType::get({8}, A); - auto bType = VectorType::get({8}, B); - auto cType = VectorType::get({4}, C); - return {aType, bType, cType}; - } - case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: - case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { - auto aType = VectorType::get({8}, A); - auto bType = VectorType::get({8}, B); - auto cType = VectorType::get({16}, C); - return {aType, bType, cType}; - } - } - assert(false && "unhandled virtual mma layout type."); - return {VectorType{}, VectorType{}, VectorType{}}; + MLIRContext *context = getContext(); + VirtualMMAIntrinsic intrinsic = getIntrinsic().getValue(); + VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs); + VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs); + VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc); + return {aVecType, bVecType, cVecType}; } std::tuple VirtualMMAAttr::getMNKShape() const { MLIRContext *ctx = getContext(); - auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue()); + auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue()); return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize}; }