Skip to content

Commit

Permalink
[python][tuner] Add bindings for iree_codegen.compilation_info
Browse files Browse the repository at this point in the history
Stacked on top of iree-org#19128.

Signed-off-by: Jakub Kuderski <[email protected]>
  • Loading branch information
kuhar committed Nov 13, 2024
1 parent bc23e59 commit 6684e30
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 2 deletions.
16 changes: 16 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenTranslationInfoAttrGet(
MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool
ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID(void);

struct ireeCodegenCompilationInfoParameters {
MlirAttribute loweringConfig;
MlirAttribute translationInfo;
};

MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet(
MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters);

MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
34 changes: 32 additions & 2 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
"cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(),
"workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(),
"configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(),
"Gets an #iree_codegen.translation_info from "
"parameters.")
"Gets an #iree_codegen.translation_info from parameters.")
.def_property_readonly(
"pass_pipeline",
[](MlirAttribute self) -> MlirAttribute {
Expand Down Expand Up @@ -124,6 +123,37 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
return parameters.configuration;
});

//===-------------------------------------------------------------------===//
// CodegenCompilationInfoAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_codegen_module, "CompilationInfoAttr",
ireeAttributeIsACodegenCompilationInfoAttr,
ireeCodegenCompilationInfoAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, MlirAttribute loweringConfig,
MlirAttribute translationInfo, MlirContext ctx) {
ireeCodegenCompilationInfoParameters parameters = {};
parameters.loweringConfig = loweringConfig;
parameters.translationInfo = translationInfo;
return ireeCodegenCompilationInfoAttrGet(ctx, parameters);
},
"cls"_a, "lowering_config"_a, "translation_info"_a,
"ctx"_a = py::none(),
"Gets an #iree_codegen.compilation_info from parameters.")
.def_property_readonly(
"lowering_config",
[](MlirAttribute self) -> MlirAttribute {
auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self);
return parameters.loweringConfig;
})
.def_property_readonly(
"translation_info", [](MlirAttribute self) -> MlirAttribute {
auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self);
return parameters.translationInfo;
});

//===--------------------------------------------------------------------===//

auto iree_gpu_module =
Expand Down
17 changes: 17 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,20 @@ def lowering_config_attr():
assert lowering_config is not None

assert lowering_config.attributes == attributes


@run
def compilation_info():
attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])})
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.None_
)
translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr)

compilation_info = iree_codegen.CompilationInfoAttr.get(
lowering_config, translation_info
)
assert compilation_info is not None
assert compilation_info.lowering_config == lowering_config
assert compilation_info.translation_info == translation_info
40 changes: 40 additions & 0 deletions compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
#include <optional>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

using mlir::iree_compiler::IREE::Codegen::CompilationInfoAttr;
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;

bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
Expand Down Expand Up @@ -109,3 +114,38 @@ ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr) {

return parameters;
}

bool ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr) {
return llvm::isa<CompilationInfoAttr>(unwrap(attr));
}

MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID() {
return wrap(CompilationInfoAttr::getTypeID());
}

MlirAttribute ireeCodegenCompilationInfoAttrGet(
MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters) {
assert(!mlirAttributeIsNull(parameters.loweringConfig) &&
"Invalid lowering config attr");
assert(
!mlirAttributeIsNull(parameters.translationInfo) &&
ireeAttributeIsACodegenTranslationInfoAttr(parameters.translationInfo) &&
"Invalid translation info attr");

auto loweringConfig = llvm::cast<LoweringConfigAttrInterface>(
unwrap(parameters.loweringConfig));
auto translationInfo =
llvm::cast<TranslationInfoAttr>(unwrap(parameters.translationInfo));

mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(CompilationInfoAttr::get(ctx, loweringConfig, translationInfo));
}

ireeCodegenCompilationInfoParameters
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) {
auto compilationInfo = llvm::cast<CompilationInfoAttr>(unwrap(attr));
ireeCodegenCompilationInfoParameters parameters = {};
parameters.loweringConfig = wrap(compilationInfo.getLoweringConfig());
parameters.translationInfo = wrap(compilationInfo.getTranslationInfo());
return parameters;
}
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@

#include <stdint.h>

extern void ireeAttributeIsACodegenCompilationInfoAttr();
extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr();
extern void ireeAttributeIsACodegenTranslationInfoAttr();
extern void ireeAttributeIsAGPULoweringConfigAttr();
extern void ireeAttributeIsAGPUMMAAttr();
extern void ireeAttributeIsAGPUMMAIntrinsicAttr();
extern void ireeAttributeIsAGPUPipelineOptionsAttr();
extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr();
extern void ireeCodegenCompilationInfoAttrGet();
extern void ireeCodegenCompilationInfoAttrGetParameters();
extern void ireeCodegenCompilationInfoAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
Expand Down Expand Up @@ -868,13 +872,17 @@ extern void mlirVectorTypeIsScalable();

uintptr_t __iree_compiler_hidden_force_extern() {
uintptr_t x = 0;
x += (uintptr_t)&ireeAttributeIsACodegenCompilationInfoAttr;
x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr;
x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr;
x += (uintptr_t)&ireeAttributeIsAGPUMMAAttr;
x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr;
x += (uintptr_t)&ireeAttributeIsAGPUPipelineOptionsAttr;
x += (uintptr_t)&ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
x += (uintptr_t)&ireeCodegenCompilationInfoAttrGet;
x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetParameters;
x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetTypeID;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.def
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
; Generated by generate_exports.py: Do not edit.
EXPORTS
ireeAttributeIsACodegenCompilationInfoAttr
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
ireeAttributeIsACodegenTranslationInfoAttr
ireeAttributeIsAGPULoweringConfigAttr
ireeAttributeIsAGPUMMAAttr
ireeAttributeIsAGPUMMAIntrinsicAttr
ireeAttributeIsAGPUPipelineOptionsAttr
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
ireeCodegenCompilationInfoAttrGet
ireeCodegenCompilationInfoAttrGetParameters
ireeCodegenCompilationInfoAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGet
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.ld
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# Generated by generate_exports.py: Do not edit.
VER_0 {
global:
ireeAttributeIsACodegenCompilationInfoAttr;
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
ireeAttributeIsACodegenTranslationInfoAttr;
ireeAttributeIsAGPULoweringConfigAttr;
ireeAttributeIsAGPUMMAAttr;
ireeAttributeIsAGPUMMAIntrinsicAttr;
ireeAttributeIsAGPUPipelineOptionsAttr;
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
ireeCodegenCompilationInfoAttrGet;
ireeCodegenCompilationInfoAttrGetParameters;
ireeCodegenCompilationInfoAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGet;
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.macos.lst
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Generated by generate_exports.py: Do not edit.
_ireeAttributeIsACodegenCompilationInfoAttr
_ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
_ireeAttributeIsACodegenTranslationInfoAttr
_ireeAttributeIsAGPULoweringConfigAttr
_ireeAttributeIsAGPUMMAAttr
_ireeAttributeIsAGPUMMAIntrinsicAttr
_ireeAttributeIsAGPUPipelineOptionsAttr
_ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
_ireeCodegenCompilationInfoAttrGet
_ireeCodegenCompilationInfoAttrGetParameters
_ireeCodegenCompilationInfoAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGet
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
Expand Down

0 comments on commit 6684e30

Please sign in to comment.