diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index e011b78194b0..0a274373ce8a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -220,88 +220,45 @@ static std::tuple getABCElementTypes(MLIRContext *context, Type i8 = IntegerType::get(context, 8); Type i32 = IntegerType::get(context, 32); switch (intrinsic) { - case MMAIntrinsic::MFMA_F64_16x16x4_F64: { + case MMAIntrinsic::MFMA_F64_16x16x4_F64: return {f64, f64, f64}; - } - case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: return {f32, f32, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x16_F16: { - return {f16, f16, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x8_F16: { + case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::WMMA_F32_16x16x16_F16: + case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: return {f16, f16, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x8_BF16: { - return {bf16, bf16, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x4_BF16: { - return {bf16, bf16, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { - return {bf16, bf16, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { + case MMAIntrinsic::WMMA_F16_16x16x16_F16: + case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: + return {f16, f16, f16}; + case MMAIntrinsic::MFMA_F32_16x16x8_BF16: + case MMAIntrinsic::MFMA_F32_32x32x4_BF16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: + case MMAIntrinsic::WMMA_F32_16x16x16_BF16: return {bf16, bf16, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: { - return {f8E4M3FNUZ, f8E4M3FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: { - return {f8E5M2FNUZ, f8E5M2FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: { - return {f8E4M3FNUZ, f8E5M2FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: { - return {f8E5M2FNUZ, f8E4M3FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: { + case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: + return {bf16, bf16, bf16}; + case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: return {f8E4M3FNUZ, f8E4M3FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: { + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: return {f8E5M2FNUZ, f8E5M2FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: { + case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: return {f8E4M3FNUZ, f8E5M2FNUZ, f32}; - } - case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: { + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: return {f8E5M2FNUZ, f8E4M3FNUZ, f32}; - } - case MMAIntrinsic::MFMA_I32_16x16x32_I8: { - return {i8, i8, i32}; - } - case MMAIntrinsic::MFMA_I32_32x32x16_I8: { - return {i8, i8, i32}; - } - case MMAIntrinsic::MFMA_I32_32x32x8_I8: { - return {i8, i8, i32}; - } - case MMAIntrinsic::MFMA_I32_16x16x16_I8: { - return {i8, i8, i32}; - } - case MMAIntrinsic::WMMA_F32_16x16x16_F16: { - return {f16, f16, f32}; - } - case MMAIntrinsic::WMMA_F16_16x16x16_F16: { - return {f16, f16, f16}; - } - case MMAIntrinsic::WMMA_F32_16x16x16_BF16: { - return {bf16, bf16, f32}; - } - case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: { - return {bf16, bf16, bf16}; - } - case MMAIntrinsic::WMMA_I32_16x16x16_I8: { + case MMAIntrinsic::MFMA_I32_16x16x16_I8: + case MMAIntrinsic::MFMA_I32_32x32x8_I8: + case MMAIntrinsic::MFMA_I32_16x16x32_I8: + case MMAIntrinsic::MFMA_I32_32x32x16_I8: + case MMAIntrinsic::WMMA_I32_16x16x16_I8: return {i8, i8, i32}; } - case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: { - return {f16, f16, f16}; - } - case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: { - return {f16, f16, f32}; - } - } assert(false && "unexpected enum value"); return {}; } @@ -498,11 +455,15 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const { return IREE::GPU::getContractionLayout(contract, layout); } -int64_t MMAAttr::getBlockSize() const { +static int getBlockSize(MMAIntrinsic /*intrinsic*/) { // Not supporting any block size other than 1 at the moment. return 1; } +int64_t MMAAttr::getBlockSize() const { + return IREE::GPU::getBlockSize(getIntrinsic().getValue()); +} + static uint32_t getArchID(MMAIntrinsic intrinsic) { return static_cast(intrinsic) & 0xFF00; } @@ -704,6 +665,31 @@ SmallVector MMAAttr::getVirtualIntrinsics() const { } } +static Value createMmaOp(OpBuilder &builder, Location loc, + MMAIntrinsic intrinsic, Type resultType, Value lhs, + Value rhs, Value acc) { + auto getVecOrSingleElem = [&](Value vec) -> Value { + bool one = llvm::cast(vec.getType()).getNumElements() == 1; + return one ? builder.create(loc, vec, 0) : vec; + }; + auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic); + if (is_AMD_MFMA(intrinsic)) { + // MFMA intrinsics want single-element operands of element type, not vector. + lhs = getVecOrSingleElem(lhs); + rhs = getVecOrSingleElem(rhs); + return builder + .create(loc, resultType, layout.mSize, layout.nSize, + layout.kSize, getBlockSize(intrinsic), lhs, rhs, + acc) + .getResult(); + } + if (is_AMD_WMMA(intrinsic)) { + return builder.create(loc, resultType, lhs, rhs, acc) + .getResult(); + } + return {}; +} + // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute // type. FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, @@ -719,23 +705,9 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, if (cType != resultType) { return failure(); } - auto getVecOrSingleElem = [&](Value vec) -> Value { - bool one = llvm::cast(vec.getType()).getNumElements() == 1; - return one ? builder.create(loc, vec, 0) : vec; - }; - MMAIntrinsic intrinsic = getIntrinsic().getValue(); - if (is_AMD_MFMA(intrinsic)) { - // MFMA intrinsics want single-element operands of element type, not vector. - lhs = getVecOrSingleElem(lhs); - rhs = getVecOrSingleElem(rhs); - auto [m, n, k] = getMNKShape(); - return builder - .create(loc, resultType, m, n, k, getBlockSize(), lhs, - rhs, acc) - .getResult(); - } else if (is_AMD_WMMA(intrinsic)) { - return builder.create(loc, resultType, lhs, rhs, acc) - .getResult(); + if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(), + resultType, lhs, rhs, acc)) { + return value; } return failure(); } @@ -1168,23 +1140,18 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, SmallVector intrinsicsAcc = distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle); - // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation - // to create the target intrinsics. - auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue()); - auto [intrinsicAType, intrinsicBType, intrinsicCType] = - intrinsicMma.getABCVectorTypes(); + MMAIntrinsic intrinsic = getIntrinsic().getValue(); + VectorType intrinCType = + getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc); // Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics. for (int mu = 0; mu < getUnrollM(); ++mu) { for (int nu = 0; nu < getUnrollN(); ++nu) { for (int ku = 0; ku < getUnrollK(); ++ku) { - // Assume intrinsicMma.buildMmaOperation() success: validation should be - // completed prior to mutating IR. Value lhs = intrinsicsLhs[mu * getUnrollK() + ku]; Value rhs = intrinsicsRhs[nu * getUnrollK() + ku]; Value &acc = intrinsicsAcc[mu * getUnrollN() + nu]; - acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs, - rhs, acc); + acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc); } } }