Skip to content

Commit

Permalink
[LLVMGPU] Cleanup VirtualMMA functions to match refactoring on base M…
Browse files Browse the repository at this point in the history
…MAAttr (#19144)

Apply similar cleanups to what is done in
#19098. For most part we do:
1. Templateize getVectorType and getOpaqueMmaLayout to work on any
intrinsic
2. Use common getOpaqueMmaLayout for VirtualMMA
3. Update getABCElementTypes to be similar to MMAAttr
4. Rename get.*MFMA fn to get.*MMA since MFMA is CDNA specific but in
   reality it does not have to be MFMA instructions

Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu authored Nov 14, 2024
1 parent a70ea83 commit ef241f9
Showing 1 changed file with 33 additions and 51 deletions.
84 changes: 33 additions & 51 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
return {};
}

static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
template <typename MMAIntrinsicType>
static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
MMAIntrinsicType intrinsic) {
OpaqueMmaLayout o;
std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
Expand Down Expand Up @@ -369,9 +370,9 @@ getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
PerDimLayoutAttr::get(context, labels[1], shape[1])};
};

static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
auto opaque = getOpaqueMFMALayout(context, intrinsic);
static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
auto opaque = getOpaqueMMALayout(context, intrinsic);
ConcreteMmaLayout concreteLayout;
concreteLayout.base = opaque;
auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs);
Expand Down Expand Up @@ -452,7 +453,7 @@ void MMAAttr::print(AsmPrinter &p) const {
}

MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
auto layout = getOpaqueMFMALayout(context, type);
auto layout = getOpaqueMMALayout(context, type);
return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
layout.nSize, layout.kSize, layout.aType, layout.bType,
layout.cType);
Expand All @@ -466,9 +467,11 @@ std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
return {getMSize(), getNSize(), getKSize()};
}

static VectorType getVectorType(MLIRContext *context, MMAIntrinsic intrinsic,
template <typename MMAIntrinsicType>
static VectorType getVectorType(MLIRContext *context,
MMAIntrinsicType intrinsic,
MMAFragment fragment) {
auto o = getOpaqueMFMALayout(context, intrinsic);
auto o = getOpaqueMMALayout(context, intrinsic);
auto s = getSingleSubgroupLayout(intrinsic, fragment);
Type elemType = (fragment == MMAFragment::Lhs) ? o.aType
: (fragment == MMAFragment::Rhs) ? o.bType
Expand All @@ -491,7 +494,7 @@ FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
VectorLayoutInterface>>
MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
ConcreteMmaLayout layout =
getConcreteMFMALayout(contract->getContext(), getIntrinsic().getValue());
getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue());
return IREE::GPU::getContractionLayout(contract, layout);
}

Expand Down Expand Up @@ -932,13 +935,13 @@ sliceSwizzledShape(const TileSwizzle &swizzle,

std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
}

std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.mSize * getUnrollM() * getSubgroupsM(),
opaqueLayout.nSize * getUnrollN() * getSubgroupsN(),
opaqueLayout.kSize * getUnrollK()};
Expand Down Expand Up @@ -1228,68 +1231,47 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
return VirtualMMAAttr::get(context, intrinsicAttr);
}

static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context,
VirtualMMAIntrinsic type) {
static std::tuple<Type, Type, Type>
getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
Type f16 = Float16Type::get(context);
Type f32 = Float32Type::get(context);

switch (type) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
// along the k dimension.
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
return {f16, f16, f32};
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
return {f16, f16, f32};
}
assert(false && "unhandled virtual mma layout type.");
return OpaqueMmaLayout{};
return {};
}

std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
}

std::tuple<VectorType, VectorType, VectorType>
VirtualMMAAttr::getABCVectorTypes() const {
// Check https://github.com/ROCm/amd_matrix_instruction_calculator for
// instruction details. Note here we are returning the number elements, while
// amd_matrix_instruction_calculator tells us about the number of 32-bit
// registers. So need to adjust accordingly. All vectors should be 1-D.
auto [A, B, C] = getABCElementTypes();
switch (getIntrinsic().getValue()) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
auto aType = VectorType::get({8}, A);
auto bType = VectorType::get({8}, B);
auto cType = VectorType::get({4}, C);
return {aType, bType, cType};
}
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
auto aType = VectorType::get({8}, A);
auto bType = VectorType::get({8}, B);
auto cType = VectorType::get({16}, C);
return {aType, bType, cType};
}
}
assert(false && "unhandled virtual mma layout type.");
return {VectorType{}, VectorType{}, VectorType{}};
MLIRContext *context = getContext();
VirtualMMAIntrinsic intrinsic = getIntrinsic().getValue();
VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
return {aVecType, bVecType, cVecType};
}

std::tuple<int64_t, int64_t, int64_t> VirtualMMAAttr::getMNKShape() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize};
}

Expand Down

0 comments on commit ef241f9

Please sign in to comment.