From bc23e5959d089898715aef215659468ef6b3dd49 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 13 Nov 2024 10:08:40 -0500 Subject: [PATCH] [python][tuner] Add bindings for `iree_codegen.translation_info` (#19128) --- .../c/iree/compiler/dialects/iree_codegen.h | 20 +++++ .../python/IREECompilerDialectsModule.cpp | 76 +++++++++++++++++++ .../bindings/python/test/ir/dialects_test.py | 49 ++++++++++++ .../API/Internal/IREECodegenDialectCAPI.cpp | 67 ++++++++++++++++ compiler/src/iree/compiler/API/api_exports.c | 8 ++ .../src/iree/compiler/API/api_exports.def | 4 + compiler/src/iree/compiler/API/api_exports.ld | 4 + .../iree/compiler/API/api_exports.macos.lst | 4 + .../Dialect/Codegen/IR/IREECodegenAttrs.td | 2 +- 9 files changed, 233 insertions(+), 1 deletion(-) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index 55fb89f1db19..357ac87038be 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -30,6 +30,26 @@ MLIR_CAPI_EXPORTED uint32_t ireeCodegenDispatchLoweringPassPipelineAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool +ireeAttributeIsACodegenTranslationInfoAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenTranslationInfoAttrGetTypeID(void); + +struct ireeCodegenTranslationInfoParameters { + MlirAttribute passPipeline; // DispatchLoweringPassPipelineAttr. + MlirAttribute codegenSpec; // Optional SymbolRefAttr. + const int64_t *workgroupSize; // Optional ArrayRef. + size_t numWorkgroupSizeElements; // Size of the ArrayRef above. + int64_t subgroupSize; // Optional int64_t. + MlirAttribute configuration; // Optional DictionaryAttr. +}; + +MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenTranslationInfoAttrGet( + MlirContext mlirCtx, ireeCodegenTranslationInfoParameters parameters); + +MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters +ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index 8991f281cf95..1f8be3facadc 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -5,6 +5,8 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include +#include +#include #include "iree/compiler/dialects/iree_codegen.h" #include "iree/compiler/dialects/iree_gpu.h" #include "mlir-c/BuiltinAttributes.h" @@ -50,6 +52,80 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) { .attr("DispatchLoweringPassPipeline")(rawValue); }); + //===-------------------------------------------------------------------===// + // CodegenTranslationInfoAttr + //===-------------------------------------------------------------------===// + + mlir_attribute_subclass(iree_codegen_module, "TranslationInfoAttr", + ireeAttributeIsACodegenTranslationInfoAttr, + ireeCodegenTranslationInfoAttrGetTypeID) + .def_classmethod( + "get", + [](const py::object &, MlirAttribute passPipeline, + std::optional codegenSpec, + std::optional> workgroupSize, + std::optional subgroupSize, + std::optional configuration, MlirContext ctx) { + ireeCodegenTranslationInfoParameters parameters = {}; + parameters.passPipeline = passPipeline; + parameters.codegenSpec = + codegenSpec.value_or(mlirAttributeGetNull()); + if (workgroupSize.has_value()) { + parameters.workgroupSize = workgroupSize->data(); + parameters.numWorkgroupSizeElements = workgroupSize->size(); + } + parameters.subgroupSize = subgroupSize.value_or(0); + parameters.configuration = + configuration.value_or(mlirAttributeGetNull()); + + return ireeCodegenTranslationInfoAttrGet(ctx, parameters); + }, + "cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(), + "workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(), + "configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(), + "Gets an #iree_codegen.translation_info from " + "parameters.") + .def_property_readonly( + "pass_pipeline", + [](MlirAttribute self) -> MlirAttribute { + auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self); + return parameters.passPipeline; + }) + .def_property_readonly( + "codegen_spec", + [](MlirAttribute self) -> std::optional { + auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self); + if (mlirAttributeIsNull(parameters.codegenSpec)) { + return std::nullopt; + } + return parameters.codegenSpec; + }) + .def_property_readonly( + "workgroup_size", + [](MlirAttribute self) -> std::vector { + auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self); + return {parameters.workgroupSize, + parameters.workgroupSize + + parameters.numWorkgroupSizeElements}; + }) + .def_property_readonly( + "subgroup_size", + [](MlirAttribute self) -> int64_t { + auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self); + return parameters.subgroupSize; + }) + .def_property_readonly( + "configuration", + [](MlirAttribute self) -> std::optional { + auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self); + if (mlirAttributeIsNull(parameters.configuration)) { + return std::nullopt; + } + return parameters.configuration; + }); + + //===--------------------------------------------------------------------===// + auto iree_gpu_module = m.def_submodule("iree_gpu", "iree_gpu dialect bindings"); diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py index cbd904d3a3f4..378d6ca7d6ee 100644 --- a/compiler/bindings/python/test/ir/dialects_test.py +++ b/compiler/bindings/python/test/ir/dialects_test.py @@ -40,6 +40,55 @@ def codegen_dispatch_lowering_pass_pipeline(): assert "LLVMGPUTileAndFuse" in str(pipeline_attr) +@run +def codegen_translation_info_minimal(): + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.None_ + ) + translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr) + assert translation_info is not None + assert str(translation_info) == "#iree_codegen.translation_info" + assert translation_info.pass_pipeline == pipeline_attr + assert translation_info.codegen_spec is None + assert translation_info.workgroup_size == [] + assert translation_info.subgroup_size == 0 + assert translation_info.configuration is None + + +@run +def codegen_translation_info_with_sizes(): + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.Custom + ) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [64, 4, 1], 32 + ) + assert translation_info is not None + assert translation_info.pass_pipeline == pipeline_attr + assert translation_info.codegen_spec is None + assert translation_info.workgroup_size == [64, 4, 1] + assert translation_info.subgroup_size == 32 + assert translation_info.configuration is None + + +@run +def codegen_translation_info_full(): + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.TransformDialectCodegen + ) + foo_symbol = ir.SymbolRefAttr.get(["foo"]) + configuration = ir.DictAttr.get({"A": ir.IntegerAttr.get(ir.IndexType.get(), 42)}) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, foo_symbol, [128], 32, configuration + ) + assert translation_info is not None + assert translation_info.pass_pipeline == pipeline_attr + assert translation_info.codegen_spec == foo_symbol + assert translation_info.workgroup_size == [128] + assert translation_info.subgroup_size == 32 + assert translation_info.configuration == configuration + + # ====================================================================== # IREE GPU Dialect # ====================================================================== diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index f0877d5374f2..13f225795636 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -4,16 +4,21 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include #include +#include #include #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/dialects/iree_codegen.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" +#include "mlir/IR/BuiltinAttributes.h" using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline; using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr; +using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr; bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr( MlirAttribute attr) { @@ -42,3 +47,65 @@ ireeCodegenDispatchLoweringPassPipelineAttrGetValue(MlirAttribute attr) { return static_cast( llvm::cast(unwrap(attr)).getValue()); } + +bool ireeAttributeIsACodegenTranslationInfoAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirTypeID ireeCodegenTranslationInfoAttrGetTypeID() { + return wrap(TranslationInfoAttr::getTypeID()); +} + +MlirAttribute ireeCodegenTranslationInfoAttrGet( + MlirContext mlirCtx, ireeCodegenTranslationInfoParameters parameters) { + assert(!mlirAttributeIsNull(parameters.passPipeline) && + ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr( + parameters.passPipeline) && + "Invalid pass pipeline attr"); + + assert((mlirAttributeIsNull(parameters.codegenSpec) || + mlirAttributeIsASymbolRef(parameters.codegenSpec)) && + "Invalid codegen spec attr"); + + assert((mlirAttributeIsNull(parameters.configuration) || + mlirAttributeIsADictionary(parameters.configuration)) && + "Invalid configuration attr"); + + DispatchLoweringPassPipeline passPipeline = + llvm::cast( + unwrap(parameters.passPipeline)) + .getValue(); + auto codegenSpec = llvm::cast_if_present( + unwrap(parameters.codegenSpec)); + + llvm::ArrayRef workgroupSize; + if (parameters.workgroupSize) { + workgroupSize = {parameters.workgroupSize, + parameters.numWorkgroupSizeElements}; + } + + std::optional subgroupSize = parameters.subgroupSize; + auto configuration = llvm::cast_if_present( + unwrap(parameters.configuration)); + + mlir::MLIRContext *ctx = unwrap(mlirCtx); + return wrap(TranslationInfoAttr::get(ctx, passPipeline, codegenSpec, + workgroupSize, subgroupSize, + configuration)); +} + +ireeCodegenTranslationInfoParameters +ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr) { + auto translationInfo = llvm::cast(unwrap(attr)); + + ireeCodegenTranslationInfoParameters parameters = {}; + parameters.passPipeline = wrap(translationInfo.getPassPipeline()); + parameters.codegenSpec = wrap(translationInfo.getCodegenSpec()); + llvm::ArrayRef workgroupSize = translationInfo.getWorkgroupSize(); + parameters.workgroupSize = workgroupSize.data(); + parameters.numWorkgroupSizeElements = workgroupSize.size(); + parameters.subgroupSize = translationInfo.getSubgroupSize(); + parameters.configuration = wrap(translationInfo.getConfiguration()); + + return parameters; +} diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index 56628bba92d5..9105596712ff 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -11,6 +11,7 @@ #include extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(); +extern void ireeAttributeIsACodegenTranslationInfoAttr(); extern void ireeAttributeIsAGPULoweringConfigAttr(); extern void ireeAttributeIsAGPUMMAAttr(); extern void ireeAttributeIsAGPUMMAIntrinsicAttr(); @@ -19,6 +20,9 @@ extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGet(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue(); +extern void ireeCodegenTranslationInfoAttrGet(); +extern void ireeCodegenTranslationInfoAttrGetParameters(); +extern void ireeCodegenTranslationInfoAttrGetTypeID(); extern void ireeCompilerEnumeratePlugins(); extern void ireeCompilerEnumerateRegisteredHALTargetBackends(); extern void ireeCompilerErrorDestroy(); @@ -865,6 +869,7 @@ extern void mlirVectorTypeIsScalable(); uintptr_t __iree_compiler_hidden_force_extern() { uintptr_t x = 0; x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr; + x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr; x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr; x += (uintptr_t)&ireeAttributeIsAGPUMMAAttr; x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr; @@ -873,6 +878,9 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet; x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID; x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue; + x += (uintptr_t)&ireeCodegenTranslationInfoAttrGet; + x += (uintptr_t)&ireeCodegenTranslationInfoAttrGetParameters; + x += (uintptr_t)&ireeCodegenTranslationInfoAttrGetTypeID; x += (uintptr_t)&ireeCompilerEnumeratePlugins; x += (uintptr_t)&ireeCompilerEnumerateRegisteredHALTargetBackends; x += (uintptr_t)&ireeCompilerErrorDestroy; diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index 8ce0e5c228fe..9844a9f44bde 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -1,6 +1,7 @@ ; Generated by generate_exports.py: Do not edit. EXPORTS ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr + ireeAttributeIsACodegenTranslationInfoAttr ireeAttributeIsAGPULoweringConfigAttr ireeAttributeIsAGPUMMAAttr ireeAttributeIsAGPUMMAIntrinsicAttr @@ -9,6 +10,9 @@ EXPORTS ireeCodegenDispatchLoweringPassPipelineAttrGet ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID ireeCodegenDispatchLoweringPassPipelineAttrGetValue + ireeCodegenTranslationInfoAttrGet + ireeCodegenTranslationInfoAttrGetParameters + ireeCodegenTranslationInfoAttrGetTypeID ireeCompilerEnumeratePlugins ireeCompilerEnumerateRegisteredHALTargetBackends ireeCompilerErrorDestroy diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index 4f5ea6ba7caf..0c360271ff3c 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld +++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -2,6 +2,7 @@ VER_0 { global: ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr; + ireeAttributeIsACodegenTranslationInfoAttr; ireeAttributeIsAGPULoweringConfigAttr; ireeAttributeIsAGPUMMAAttr; ireeAttributeIsAGPUMMAIntrinsicAttr; @@ -10,6 +11,9 @@ VER_0 { ireeCodegenDispatchLoweringPassPipelineAttrGet; ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID; ireeCodegenDispatchLoweringPassPipelineAttrGetValue; + ireeCodegenTranslationInfoAttrGet; + ireeCodegenTranslationInfoAttrGetParameters; + ireeCodegenTranslationInfoAttrGetTypeID; ireeCompilerEnumeratePlugins; ireeCompilerEnumerateRegisteredHALTargetBackends; ireeCompilerErrorDestroy; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index 8f36b215d2b5..d0683e846f4a 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -1,5 +1,6 @@ # Generated by generate_exports.py: Do not edit. _ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr +_ireeAttributeIsACodegenTranslationInfoAttr _ireeAttributeIsAGPULoweringConfigAttr _ireeAttributeIsAGPUMMAAttr _ireeAttributeIsAGPUMMAIntrinsicAttr @@ -8,6 +9,9 @@ _ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr _ireeCodegenDispatchLoweringPassPipelineAttrGet _ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID _ireeCodegenDispatchLoweringPassPipelineAttrGetValue +_ireeCodegenTranslationInfoAttrGet +_ireeCodegenTranslationInfoAttrGetParameters +_ireeCodegenTranslationInfoAttrGetTypeID _ireeCompilerEnumeratePlugins _ireeCompilerEnumerateRegisteredHALTargetBackends _ireeCompilerErrorDestroy diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index 80b4b12d7d5f..3086c09b2069 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -226,7 +226,7 @@ def IREECodegen_TranslationInfoAttr : }]; let assemblyFormat = [{ - `<` `pipeline` `=` $passPipeline + `<` `pipeline` `=` `` $passPipeline (`codegen_spec` `=` $codegenSpec^)? (`workgroup_size` `=` `[` $workgroupSize^ `]`)? (`subgroup_size` `=` $subgroupSize^)?