Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVMGPU] Cleanup VirtualMMA functions to match refactoring on base MMAAttr #19144

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
return {};
}

static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
template <typename MMAIntrinsicType>
static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
MMAIntrinsicType intrinsic) {
OpaqueMmaLayout o;
std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
Expand Down Expand Up @@ -369,9 +370,9 @@ getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
PerDimLayoutAttr::get(context, labels[1], shape[1])};
};

static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
auto opaque = getOpaqueMFMALayout(context, intrinsic);
static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
auto opaque = getOpaqueMMALayout(context, intrinsic);
ConcreteMmaLayout concreteLayout;
concreteLayout.base = opaque;
auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs);
Expand Down Expand Up @@ -452,7 +453,7 @@ void MMAAttr::print(AsmPrinter &p) const {
}

MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
auto layout = getOpaqueMFMALayout(context, type);
auto layout = getOpaqueMMALayout(context, type);
return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
layout.nSize, layout.kSize, layout.aType, layout.bType,
layout.cType);
Expand All @@ -466,9 +467,11 @@ std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
return {getMSize(), getNSize(), getKSize()};
}

static VectorType getVectorType(MLIRContext *context, MMAIntrinsic intrinsic,
template <typename MMAIntrinsicType>
static VectorType getVectorType(MLIRContext *context,
MMAIntrinsicType intrinsic,
MMAFragment fragment) {
auto o = getOpaqueMFMALayout(context, intrinsic);
auto o = getOpaqueMMALayout(context, intrinsic);
auto s = getSingleSubgroupLayout(intrinsic, fragment);
Type elemType = (fragment == MMAFragment::Lhs) ? o.aType
: (fragment == MMAFragment::Rhs) ? o.bType
Expand All @@ -491,7 +494,7 @@ FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
VectorLayoutInterface>>
MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
ConcreteMmaLayout layout =
getConcreteMFMALayout(contract->getContext(), getIntrinsic().getValue());
getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue());
return IREE::GPU::getContractionLayout(contract, layout);
}

Expand Down Expand Up @@ -932,13 +935,13 @@ sliceSwizzledShape(const TileSwizzle &swizzle,

std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
}

std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.mSize * getUnrollM() * getSubgroupsM(),
opaqueLayout.nSize * getUnrollN() * getSubgroupsN(),
opaqueLayout.kSize * getUnrollK()};
Expand Down Expand Up @@ -1228,68 +1231,47 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
return VirtualMMAAttr::get(context, intrinsicAttr);
}

static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context,
VirtualMMAIntrinsic type) {
static std::tuple<Type, Type, Type>
getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
Type f16 = Float16Type::get(context);
Type f32 = Float32Type::get(context);

switch (type) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
// along the k dimension.
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
}
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
return {f16, f16, f32};
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
return {f16, f16, f32};
}
assert(false && "unhandled virtual mma layout type.");
return OpaqueMmaLayout{};
return {};
}

std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
}

std::tuple<VectorType, VectorType, VectorType>
VirtualMMAAttr::getABCVectorTypes() const {
// Check https://github.com/ROCm/amd_matrix_instruction_calculator for
// instruction details. Note here we are returning the number elements, while
// amd_matrix_instruction_calculator tells us about the number of 32-bit
// registers. So need to adjust accordingly. All vectors should be 1-D.
auto [A, B, C] = getABCElementTypes();
switch (getIntrinsic().getValue()) {
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
auto aType = VectorType::get({8}, A);
auto bType = VectorType::get({8}, B);
auto cType = VectorType::get({4}, C);
return {aType, bType, cType};
}
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
auto aType = VectorType::get({8}, A);
auto bType = VectorType::get({8}, B);
auto cType = VectorType::get({16}, C);
return {aType, bType, cType};
}
}
assert(false && "unhandled virtual mma layout type.");
return {VectorType{}, VectorType{}, VectorType{}};
MLIRContext *context = getContext();
VirtualMMAIntrinsic intrinsic = getIntrinsic().getValue();
VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
return {aVecType, bVecType, cVecType};
}

std::tuple<int64_t, int64_t, int64_t> VirtualMMAAttr::getMNKShape() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue());
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize};
}

Expand Down
Loading