From 591f63cd2c61c060271b9cc09fe372ed5c80cb2f Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 12 Nov 2024 23:09:37 -0500 Subject: [PATCH] [python][tuner] Add bindings for `iree_codegen.compilation_info` Stacked on top of https://github.com/iree-org/iree/pull/19128. Signed-off-by: Jakub Kuderski --- .../c/iree/compiler/dialects/iree_codegen.h | 16 ++++++++ .../python/IREECompilerDialectsModule.cpp | 34 +++++++++++++++- .../bindings/python/test/ir/dialects_test.py | 17 ++++++++ .../API/Internal/IREECodegenDialectCAPI.cpp | 40 +++++++++++++++++++ compiler/src/iree/compiler/API/api_exports.c | 8 ++++ .../src/iree/compiler/API/api_exports.def | 4 ++ compiler/src/iree/compiler/API/api_exports.ld | 4 ++ .../iree/compiler/API/api_exports.macos.lst | 4 ++ 8 files changed, 125 insertions(+), 2 deletions(-) diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index d5bfd570c5ccc..ba6ac4c7de9bf 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -51,6 +51,22 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenTranslationInfoAttrGet( MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool +ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID(void); + +struct ireeCodegenCompilationInfoParameters { + MlirAttribute loweringConfig; + MlirAttribute translationInfo; +}; + +MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet( + MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters); + +MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters +ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index 1f8be3facadc5..7ece224519c2f 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -83,8 +83,7 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) { "cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(), "workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(), "configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(), - "Gets an #iree_codegen.translation_info from " - "parameters.") + "Gets an #iree_codegen.translation_info from parameters.") .def_property_readonly( "pass_pipeline", [](MlirAttribute self) -> MlirAttribute { @@ -124,6 +123,37 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) { return parameters.configuration; }); + //===-------------------------------------------------------------------===// + // CodegenCompilationInfoAttr + //===-------------------------------------------------------------------===// + + mlir_attribute_subclass(iree_codegen_module, "CompilationInfoAttr", + ireeAttributeIsACodegenCompilationInfoAttr, + ireeCodegenCompilationInfoAttrGetTypeID) + .def_classmethod( + "get", + [](const py::object &, MlirAttribute loweringConfig, + MlirAttribute translationInfo, MlirContext ctx) { + ireeCodegenCompilationInfoParameters parameters = {}; + parameters.loweringConfig = loweringConfig; + parameters.translationInfo = translationInfo; + return ireeCodegenCompilationInfoAttrGet(ctx, parameters); + }, + "cls"_a, "lowering_config"_a, "translation_info"_a, + "ctx"_a = py::none(), + "Gets an #iree_codegen.compilation_info from parameters.") + .def_property_readonly( + "lowering_config", + [](MlirAttribute self) -> MlirAttribute { + auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self); + return parameters.loweringConfig; + }) + .def_property_readonly( + "translation_info", [](MlirAttribute self) -> MlirAttribute { + auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self); + return parameters.translationInfo; + }); + //===--------------------------------------------------------------------===// auto iree_gpu_module = diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py index caf32fdd4ad05..4e07a642aba2a 100644 --- a/compiler/bindings/python/test/ir/dialects_test.py +++ b/compiler/bindings/python/test/ir/dialects_test.py @@ -215,3 +215,20 @@ def lowering_config_attr(): assert lowering_config is not None assert lowering_config.attributes == attributes + + +@run +def compilation_info(): + attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])}) + lowering_config = iree_gpu.LoweringConfigAttr.get(attributes) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.None_ + ) + translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr) + + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + assert compilation_info is not None + assert compilation_info.lowering_config == lowering_config + assert compilation_info.translation_info == translation_info diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index 13f2257956365..c295d48b01e34 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -9,15 +9,20 @@ #include #include #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" #include "iree/compiler/dialects/iree_codegen.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +using mlir::iree_compiler::IREE::Codegen::CompilationInfoAttr; using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline; using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr; +using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface; using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr; bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr( @@ -109,3 +114,38 @@ ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr) { return parameters; } + +bool ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID() { + return wrap(CompilationInfoAttr::getTypeID()); +} + +MlirAttribute ireeCodegenCompilationInfoAttrGet( + MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters) { + assert(!mlirAttributeIsNull(parameters.loweringConfig) && + "Invalid lowering config attr"); + assert( + !mlirAttributeIsNull(parameters.translationInfo) && + ireeAttributeIsACodegenTranslationInfoAttr(parameters.translationInfo) && + "Invalid translation info attr"); + + auto loweringConfig = llvm::cast( + unwrap(parameters.loweringConfig)); + auto translationInfo = + llvm::cast(unwrap(parameters.translationInfo)); + + mlir::MLIRContext *ctx = unwrap(mlirCtx); + return wrap(CompilationInfoAttr::get(ctx, loweringConfig, translationInfo)); +} + +ireeCodegenCompilationInfoParameters +ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) { + auto compilationInfo = llvm::cast(unwrap(attr)); + ireeCodegenCompilationInfoParameters parameters = {}; + parameters.loweringConfig = wrap(compilationInfo.getLoweringConfig()); + parameters.translationInfo = wrap(compilationInfo.getTranslationInfo()); + return parameters; +} diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index 9105596712ff9..ffb8f086c6783 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -10,6 +10,7 @@ #include +extern void ireeAttributeIsACodegenCompilationInfoAttr(); extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(); extern void ireeAttributeIsACodegenTranslationInfoAttr(); extern void ireeAttributeIsAGPULoweringConfigAttr(); @@ -17,6 +18,9 @@ extern void ireeAttributeIsAGPUMMAAttr(); extern void ireeAttributeIsAGPUMMAIntrinsicAttr(); extern void ireeAttributeIsAGPUPipelineOptionsAttr(); extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(); +extern void ireeCodegenCompilationInfoAttrGet(); +extern void ireeCodegenCompilationInfoAttrGetParameters(); +extern void ireeCodegenCompilationInfoAttrGetTypeID(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGet(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue(); @@ -868,6 +872,7 @@ extern void mlirVectorTypeIsScalable(); uintptr_t __iree_compiler_hidden_force_extern() { uintptr_t x = 0; + x += (uintptr_t)&ireeAttributeIsACodegenCompilationInfoAttr; x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr; x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr; x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr; @@ -875,6 +880,9 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr; x += (uintptr_t)&ireeAttributeIsAGPUPipelineOptionsAttr; x += (uintptr_t)&ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr; + x += (uintptr_t)&ireeCodegenCompilationInfoAttrGet; + x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetParameters; + x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetTypeID; x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet; x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID; x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue; diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index 9844a9f44bded..ed5e12cceb489 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -1,5 +1,6 @@ ; Generated by generate_exports.py: Do not edit. EXPORTS + ireeAttributeIsACodegenCompilationInfoAttr ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr ireeAttributeIsACodegenTranslationInfoAttr ireeAttributeIsAGPULoweringConfigAttr @@ -7,6 +8,9 @@ EXPORTS ireeAttributeIsAGPUMMAIntrinsicAttr ireeAttributeIsAGPUPipelineOptionsAttr ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr + ireeCodegenCompilationInfoAttrGet + ireeCodegenCompilationInfoAttrGetParameters + ireeCodegenCompilationInfoAttrGetTypeID ireeCodegenDispatchLoweringPassPipelineAttrGet ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID ireeCodegenDispatchLoweringPassPipelineAttrGetValue diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index 0c360271ff3c3..0808927de5275 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld +++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -1,6 +1,7 @@ # Generated by generate_exports.py: Do not edit. VER_0 { global: + ireeAttributeIsACodegenCompilationInfoAttr; ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr; ireeAttributeIsACodegenTranslationInfoAttr; ireeAttributeIsAGPULoweringConfigAttr; @@ -8,6 +9,9 @@ VER_0 { ireeAttributeIsAGPUMMAIntrinsicAttr; ireeAttributeIsAGPUPipelineOptionsAttr; ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr; + ireeCodegenCompilationInfoAttrGet; + ireeCodegenCompilationInfoAttrGetParameters; + ireeCodegenCompilationInfoAttrGetTypeID; ireeCodegenDispatchLoweringPassPipelineAttrGet; ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID; ireeCodegenDispatchLoweringPassPipelineAttrGetValue; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index d0683e846f4a6..11169bf3f13d9 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -1,4 +1,5 @@ # Generated by generate_exports.py: Do not edit. +_ireeAttributeIsACodegenCompilationInfoAttr _ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr _ireeAttributeIsACodegenTranslationInfoAttr _ireeAttributeIsAGPULoweringConfigAttr @@ -6,6 +7,9 @@ _ireeAttributeIsAGPUMMAAttr _ireeAttributeIsAGPUMMAIntrinsicAttr _ireeAttributeIsAGPUPipelineOptionsAttr _ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr +_ireeCodegenCompilationInfoAttrGet +_ireeCodegenCompilationInfoAttrGetParameters +_ireeCodegenCompilationInfoAttrGetTypeID _ireeCodegenDispatchLoweringPassPipelineAttrGet _ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID _ireeCodegenDispatchLoweringPassPipelineAttrGetValue