Skip to content

Commit

Permalink
[python][tuner] Add bindings for MMAIntrinsic (#19095)
Browse files Browse the repository at this point in the history
Expose MMA intrinsics to python. This includes both the `MMAIntrinsic`
attribute and the enum values. In addition, expose element/vector types
and shapes from the MMA interface (but not the interface itself).

This is so that the tuner can use the compiler as the SoT for all the
intrinsics and use them to parse target properties and generate
constraints.

The enum values are kept opaque so that we do not have to duplicate the
definitions in TD and C API headers.
  • Loading branch information
kuhar authored Nov 11, 2024
1 parent b133218 commit 48f6dee
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 57 deletions.
42 changes: 35 additions & 7 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
extern "C" {
#endif

enum ireeGPUReorderWorkgroupsStrategyEnum {
ireeGPUReorderWorkgroupsStrategyEnumNone = 0,
ireeGPUReorderWorkgroupsStrategyEnumTranspose = 1,
};
// The following C API is **NOT STABLE** and likely to change in the future.
// It mirrors the IREE GPU Dialect which is not stable itself.

MLIR_CAPI_EXPORTED bool
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value);
MLIR_CAPI_EXPORTED MlirAttribute
ireeGPUReorderWorkgroupsStrategyAttrGet(MlirContext mlirCtx, uint32_t value);

MLIR_CAPI_EXPORTED ireeGPUReorderWorkgroupsStrategyEnum
MLIR_CAPI_EXPORTED uint32_t
ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr);

MLIR_CAPI_EXPORTED
Expand All @@ -54,6 +52,36 @@ ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID(void);

MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx,
uint32_t value);

MLIR_CAPI_EXPORTED uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx,
uint32_t value);

struct ireeGPUMMAInfo {
MlirType aElementType;
MlirType bElementType;
MlirType cElementType;
MlirType aVectorType;
MlirType bVectorType;
MlirType cVectorType;
int64_t mElements;
int64_t nElements;
int64_t kElements;
};

MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
96 changes: 66 additions & 30 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

static const char *kGpuModuleImportPath =
MAKE_MLIR_PYTHON_QUALNAME("dialects.iree_gpu");

namespace py = pybind11;
using namespace mlir::python::adaptors;

Expand All @@ -22,45 +26,23 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
// GPUReorderWorkgroupsStrategyAttr
//===-------------------------------------------------------------------===//

auto strategyEnum =
py::enum_<ireeGPUReorderWorkgroupsStrategyEnum>(
iree_gpu_module, "ReorderWorkgroupsStrategy", py::module_local())
.value("None_", ireeGPUReorderWorkgroupsStrategyEnumNone)
.value("Transpose", ireeGPUReorderWorkgroupsStrategyEnumTranspose)
.def(
"__str__",
[](ireeGPUReorderWorkgroupsStrategyEnum &self) {
switch (self) {
case ireeGPUReorderWorkgroupsStrategyEnumNone:
return "None";
case ireeGPUReorderWorkgroupsStrategyEnumTranspose:
return "Transpose";
default:
llvm::report_fatal_error(
"unknown ReorderWorkgroupsStrategy variant");
}
},
// pybind overloads are tried in the order they were registered.
// As a result, enums used the default __str__ method instead of
// the custom one. Adding py::prepend() fixes this issue.
py::prepend());

mlir_attribute_subclass(iree_gpu_module, "ReorderWorkgroupsStrategyAttr",
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr,
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, ireeGPUReorderWorkgroupsStrategyEnum value,
MlirContext ctx) {
[](const py::object &, uint32_t value, MlirContext ctx) {
return ireeGPUReorderWorkgroupsStrategyAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.reorder_workgroups_strategy from parameters.")
.def_property_readonly(
"value",
[](MlirAttribute self) -> ireeGPUReorderWorkgroupsStrategyEnum {
return ireeGPUReorderWorkgroupsStrategyAttrGetValue(self);
});
.def_property_readonly("raw_value",
ireeGPUReorderWorkgroupsStrategyAttrGetValue)
.def_property_readonly("value", [](MlirAttribute self) -> py::object {
uint32_t rawValue = ireeGPUReorderWorkgroupsStrategyAttrGetValue(self);
return py::module_::import(kGpuModuleImportPath)
.attr("ReorderWorkgroupsStrategy")(rawValue);
});

//===-------------------------------------------------------------------===//
// GPUPipelineOptionsAttr
Expand Down Expand Up @@ -129,4 +111,58 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
return attr;
return std::nullopt;
});

//===-------------------------------------------------------------------===//
// GPUMMAIntrinsicAttr
//===-------------------------------------------------------------------===//
mlir_attribute_subclass(iree_gpu_module, "MMAIntrinsicAttr",
ireeAttributeIsAGPUMMAIntrinsicAttr,
ireeGPUMMAIntrinsicAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, uint32_t value, MlirContext ctx) {
return ireeGPUMMAIntrinsicAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma_intrinsic from parameters.")
.def_property_readonly("raw_value", ireeGPUMMAIntrinsicAttrGetValue)
.def_property_readonly("value",
[](MlirAttribute self) -> py::object {
uint32_t rawValue =
ireeGPUMMAIntrinsicAttrGetValue(self);
return py::module_::import(kGpuModuleImportPath)
.attr("MMAIntrinsic")(rawValue);
})
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
uint32_t value = ireeGPUMMAIntrinsicAttrGetValue(self);
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});

mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, uint32_t value, MlirContext ctx) {
return ireeGPUMMAAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma from parameters.")
.def_property_readonly(
"abc_element_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aElementType, info.bElementType,
info.cElementType);
})
.def_property_readonly(
"abc_vector_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aVectorType, info.bVectorType,
info.cVectorType);
})
.def_property_readonly("mnk_shape", [](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});
}
45 changes: 45 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def gpu_pipeline_options_attr():
reorder_attr = iree_gpu.ReorderWorkgroupsStrategyAttr.get(
iree_gpu.ReorderWorkgroupsStrategy.Transpose, ctx
)
assert reorder_attr.value == iree_gpu.ReorderWorkgroupsStrategy.Transpose

gpu_attr = iree_gpu.PipelineOptionsAttr.get(
True,
False,
Expand Down Expand Up @@ -86,3 +88,46 @@ def gpu_pipeline_options_attr():
# unfortunately not `is`
== iree_gpu.ReorderWorkgroupsStrategy.Transpose
)


@lambda _: _()
def mma_intrinsic_attr():
with ir.Context() as ctx, ir.Location.unknown():
module = ir.Module.create()
with ir.InsertionPoint(module.body):
mma_intrinsic_attr = iree_gpu.MMAIntrinsicAttr.get(
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ctx
)
assert mma_intrinsic_attr is not None
assert (
str(mma_intrinsic_attr)
== "#iree_gpu<mma_intrinsic MFMA_F32_32x32x8_F16>"
)

raw_value = mma_intrinsic_attr.raw_value
assert raw_value == iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
value = mma_intrinsic_attr.value
assert str(value) == "MFMA_F32_32x32x8_F16"
assert int(value) == raw_value

mma_attr = iree_gpu.MMAAttr.get(raw_value, ctx)
assert mma_attr is not None

f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
a_type, b_type, c_type = mma_attr.abc_element_types
assert a_type == f16
assert b_type == f16
assert c_type == f32

vec_4xf16 = ir.VectorType.get((4,), f16)
a_vec_type, b_vec_type, _c_vec_type = mma_attr.abc_vector_types
assert a_vec_type == vec_4xf16
assert b_vec_type == vec_4xf16

M, N, K = mma_attr.mnk_shape
assert M == 32
assert N == 32
assert K == 8

assert mma_intrinsic_attr.mma == mma_attr
94 changes: 75 additions & 19 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
Expand Down Expand Up @@ -84,20 +87,6 @@ MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID());
}

static_assert(
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumNone) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::None) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::Transpose) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
mlir::iree_compiler::IREE::GPU::
getMaxEnumValForReorderWorkgroupsStrategy(),
"ireeGPUReorderWorkgroupsStrategyEnum and "
"mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions "
"have diverged");

bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
return llvm::isa<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
Expand All @@ -109,8 +98,15 @@ MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() {
getTypeID());
}

MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) {
static_assert(
std::is_same_v<
uint32_t,
std::underlying_type_t<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>>,
"Enum type changed");

MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(MlirContext mlirCtx,
uint32_t value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get(
Expand All @@ -119,12 +115,72 @@ MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
value)));
}

ireeGPUReorderWorkgroupsStrategyEnum
ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
uint32_t ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) &&
"attr is not a GPUReorderWorkgroupsStrategyAttr");
return static_cast<ireeGPUReorderWorkgroupsStrategyEnum>(
return static_cast<uint32_t>(
llvm::cast<mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
unwrap(attr))
.getValue());
}

bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(
unwrap(attr));
}

MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::getTypeID());
}

static_assert(
std::is_same_v<uint32_t, std::underlying_type_t<
mlir::iree_compiler::IREE::GPU::MMAIntrinsic>>,
"Enum type changed");

MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, uint32_t value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAIntrinsicAttr(attr) &&
"attr is not a GPUMMAIntrinsicAttr");
return static_cast<uint32_t>(
llvm::cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(unwrap(attr))
.getValue());
}

bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));
}

MlirTypeID ireeGPUMMAAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::getTypeID());
}

MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx, uint32_t value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAAttr(attr) && "attr is not a MMAAttr");
auto mma = llvm::cast<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));

ireeGPUMMAInfo info = {};
auto [aType, bType, cType] = mma.getABCElementTypes();
info.aElementType = wrap(aType);
info.bElementType = wrap(bType);
info.cElementType = wrap(cType);

auto [aVecType, bVecType, cVecType] = mma.getABCVectorTypes();
info.aVectorType = wrap(aVecType);
info.bVectorType = wrap(bVecType);
info.cVectorType = wrap(cVecType);

std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
return info;
}
Loading

0 comments on commit 48f6dee

Please sign in to comment.