Skip to content

Commit

Permalink
[python][tuner] Add bindings for iree_codegen.translation_info (ire…
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar authored Nov 13, 2024
1 parent ea03080 commit bc23e59
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 1 deletion.
20 changes: 20 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ MLIR_CAPI_EXPORTED
uint32_t
ireeCodegenDispatchLoweringPassPipelineAttrGetValue(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool
ireeAttributeIsACodegenTranslationInfoAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenTranslationInfoAttrGetTypeID(void);

struct ireeCodegenTranslationInfoParameters {
MlirAttribute passPipeline; // DispatchLoweringPassPipelineAttr.
MlirAttribute codegenSpec; // Optional SymbolRefAttr.
const int64_t *workgroupSize; // Optional ArrayRef<int64_t>.
size_t numWorkgroupSizeElements; // Size of the ArrayRef above.
int64_t subgroupSize; // Optional int64_t.
MlirAttribute configuration; // Optional DictionaryAttr.
};

MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenTranslationInfoAttrGet(
MlirContext mlirCtx, ireeCodegenTranslationInfoParameters parameters);

MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
76 changes: 76 additions & 0 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include <optional>
#include <vector>
#include "iree/compiler/dialects/iree_codegen.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
Expand Down Expand Up @@ -50,6 +52,80 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
.attr("DispatchLoweringPassPipeline")(rawValue);
});

//===-------------------------------------------------------------------===//
// CodegenTranslationInfoAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_codegen_module, "TranslationInfoAttr",
ireeAttributeIsACodegenTranslationInfoAttr,
ireeCodegenTranslationInfoAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, MlirAttribute passPipeline,
std::optional<MlirAttribute> codegenSpec,
std::optional<std::vector<int64_t>> workgroupSize,
std::optional<int64_t> subgroupSize,
std::optional<MlirAttribute> configuration, MlirContext ctx) {
ireeCodegenTranslationInfoParameters parameters = {};
parameters.passPipeline = passPipeline;
parameters.codegenSpec =
codegenSpec.value_or(mlirAttributeGetNull());
if (workgroupSize.has_value()) {
parameters.workgroupSize = workgroupSize->data();
parameters.numWorkgroupSizeElements = workgroupSize->size();
}
parameters.subgroupSize = subgroupSize.value_or(0);
parameters.configuration =
configuration.value_or(mlirAttributeGetNull());

return ireeCodegenTranslationInfoAttrGet(ctx, parameters);
},
"cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(),
"workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(),
"configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(),
"Gets an #iree_codegen.translation_info from "
"parameters.")
.def_property_readonly(
"pass_pipeline",
[](MlirAttribute self) -> MlirAttribute {
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
return parameters.passPipeline;
})
.def_property_readonly(
"codegen_spec",
[](MlirAttribute self) -> std::optional<MlirAttribute> {
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
if (mlirAttributeIsNull(parameters.codegenSpec)) {
return std::nullopt;
}
return parameters.codegenSpec;
})
.def_property_readonly(
"workgroup_size",
[](MlirAttribute self) -> std::vector<int64_t> {
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
return {parameters.workgroupSize,
parameters.workgroupSize +
parameters.numWorkgroupSizeElements};
})
.def_property_readonly(
"subgroup_size",
[](MlirAttribute self) -> int64_t {
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
return parameters.subgroupSize;
})
.def_property_readonly(
"configuration",
[](MlirAttribute self) -> std::optional<MlirAttribute> {
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
if (mlirAttributeIsNull(parameters.configuration)) {
return std::nullopt;
}
return parameters.configuration;
});

//===--------------------------------------------------------------------===//

auto iree_gpu_module =
m.def_submodule("iree_gpu", "iree_gpu dialect bindings");

Expand Down
49 changes: 49 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,55 @@ def codegen_dispatch_lowering_pass_pipeline():
assert "LLVMGPUTileAndFuse" in str(pipeline_attr)


@run
def codegen_translation_info_minimal():
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.None_
)
translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr)
assert translation_info is not None
assert str(translation_info) == "#iree_codegen.translation_info<pipeline = None>"
assert translation_info.pass_pipeline == pipeline_attr
assert translation_info.codegen_spec is None
assert translation_info.workgroup_size == []
assert translation_info.subgroup_size == 0
assert translation_info.configuration is None


@run
def codegen_translation_info_with_sizes():
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.Custom
)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [64, 4, 1], 32
)
assert translation_info is not None
assert translation_info.pass_pipeline == pipeline_attr
assert translation_info.codegen_spec is None
assert translation_info.workgroup_size == [64, 4, 1]
assert translation_info.subgroup_size == 32
assert translation_info.configuration is None


@run
def codegen_translation_info_full():
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.TransformDialectCodegen
)
foo_symbol = ir.SymbolRefAttr.get(["foo"])
configuration = ir.DictAttr.get({"A": ir.IntegerAttr.get(ir.IndexType.get(), 42)})
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, foo_symbol, [128], 32, configuration
)
assert translation_info is not None
assert translation_info.pass_pipeline == pipeline_attr
assert translation_info.codegen_spec == foo_symbol
assert translation_info.workgroup_size == [128]
assert translation_info.subgroup_size == 32
assert translation_info.configuration == configuration


# ======================================================================
# IREE GPU Dialect
# ======================================================================
Expand Down
67 changes: 67 additions & 0 deletions compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include <cstdint>
#include <optional>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinAttributes.h"

using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;

bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
MlirAttribute attr) {
Expand Down Expand Up @@ -42,3 +47,65 @@ ireeCodegenDispatchLoweringPassPipelineAttrGetValue(MlirAttribute attr) {
return static_cast<uint32_t>(
llvm::cast<DispatchLoweringPassPipelineAttr>(unwrap(attr)).getValue());
}

bool ireeAttributeIsACodegenTranslationInfoAttr(MlirAttribute attr) {
return llvm::isa<TranslationInfoAttr>(unwrap(attr));
}

MlirTypeID ireeCodegenTranslationInfoAttrGetTypeID() {
return wrap(TranslationInfoAttr::getTypeID());
}

MlirAttribute ireeCodegenTranslationInfoAttrGet(
MlirContext mlirCtx, ireeCodegenTranslationInfoParameters parameters) {
assert(!mlirAttributeIsNull(parameters.passPipeline) &&
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
parameters.passPipeline) &&
"Invalid pass pipeline attr");

assert((mlirAttributeIsNull(parameters.codegenSpec) ||
mlirAttributeIsASymbolRef(parameters.codegenSpec)) &&
"Invalid codegen spec attr");

assert((mlirAttributeIsNull(parameters.configuration) ||
mlirAttributeIsADictionary(parameters.configuration)) &&
"Invalid configuration attr");

DispatchLoweringPassPipeline passPipeline =
llvm::cast<DispatchLoweringPassPipelineAttr>(
unwrap(parameters.passPipeline))
.getValue();
auto codegenSpec = llvm::cast_if_present<mlir::SymbolRefAttr>(
unwrap(parameters.codegenSpec));

llvm::ArrayRef<int64_t> workgroupSize;
if (parameters.workgroupSize) {
workgroupSize = {parameters.workgroupSize,
parameters.numWorkgroupSizeElements};
}

std::optional<int64_t> subgroupSize = parameters.subgroupSize;
auto configuration = llvm::cast_if_present<mlir::DictionaryAttr>(
unwrap(parameters.configuration));

mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(TranslationInfoAttr::get(ctx, passPipeline, codegenSpec,
workgroupSize, subgroupSize,
configuration));
}

ireeCodegenTranslationInfoParameters
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr) {
auto translationInfo = llvm::cast<TranslationInfoAttr>(unwrap(attr));

ireeCodegenTranslationInfoParameters parameters = {};
parameters.passPipeline = wrap(translationInfo.getPassPipeline());
parameters.codegenSpec = wrap(translationInfo.getCodegenSpec());
llvm::ArrayRef<int64_t> workgroupSize = translationInfo.getWorkgroupSize();
parameters.workgroupSize = workgroupSize.data();
parameters.numWorkgroupSizeElements = workgroupSize.size();
parameters.subgroupSize = translationInfo.getSubgroupSize();
parameters.configuration = wrap(translationInfo.getConfiguration());

return parameters;
}
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <stdint.h>

extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr();
extern void ireeAttributeIsACodegenTranslationInfoAttr();
extern void ireeAttributeIsAGPULoweringConfigAttr();
extern void ireeAttributeIsAGPUMMAAttr();
extern void ireeAttributeIsAGPUMMAIntrinsicAttr();
Expand All @@ -19,6 +20,9 @@ extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
extern void ireeCodegenTranslationInfoAttrGet();
extern void ireeCodegenTranslationInfoAttrGetParameters();
extern void ireeCodegenTranslationInfoAttrGetTypeID();
extern void ireeCompilerEnumeratePlugins();
extern void ireeCompilerEnumerateRegisteredHALTargetBackends();
extern void ireeCompilerErrorDestroy();
Expand Down Expand Up @@ -865,6 +869,7 @@ extern void mlirVectorTypeIsScalable();
uintptr_t __iree_compiler_hidden_force_extern() {
uintptr_t x = 0;
x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr;
x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr;
x += (uintptr_t)&ireeAttributeIsAGPUMMAAttr;
x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr;
Expand All @@ -873,6 +878,9 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
x += (uintptr_t)&ireeCodegenTranslationInfoAttrGet;
x += (uintptr_t)&ireeCodegenTranslationInfoAttrGetParameters;
x += (uintptr_t)&ireeCodegenTranslationInfoAttrGetTypeID;
x += (uintptr_t)&ireeCompilerEnumeratePlugins;
x += (uintptr_t)&ireeCompilerEnumerateRegisteredHALTargetBackends;
x += (uintptr_t)&ireeCompilerErrorDestroy;
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.def
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
; Generated by generate_exports.py: Do not edit.
EXPORTS
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
ireeAttributeIsACodegenTranslationInfoAttr
ireeAttributeIsAGPULoweringConfigAttr
ireeAttributeIsAGPUMMAAttr
ireeAttributeIsAGPUMMAIntrinsicAttr
Expand All @@ -9,6 +10,9 @@ EXPORTS
ireeCodegenDispatchLoweringPassPipelineAttrGet
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
ireeCodegenTranslationInfoAttrGet
ireeCodegenTranslationInfoAttrGetParameters
ireeCodegenTranslationInfoAttrGetTypeID
ireeCompilerEnumeratePlugins
ireeCompilerEnumerateRegisteredHALTargetBackends
ireeCompilerErrorDestroy
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.ld
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
VER_0 {
global:
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
ireeAttributeIsACodegenTranslationInfoAttr;
ireeAttributeIsAGPULoweringConfigAttr;
ireeAttributeIsAGPUMMAAttr;
ireeAttributeIsAGPUMMAIntrinsicAttr;
Expand All @@ -10,6 +11,9 @@ VER_0 {
ireeCodegenDispatchLoweringPassPipelineAttrGet;
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
ireeCodegenTranslationInfoAttrGet;
ireeCodegenTranslationInfoAttrGetParameters;
ireeCodegenTranslationInfoAttrGetTypeID;
ireeCompilerEnumeratePlugins;
ireeCompilerEnumerateRegisteredHALTargetBackends;
ireeCompilerErrorDestroy;
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.macos.lst
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by generate_exports.py: Do not edit.
_ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
_ireeAttributeIsACodegenTranslationInfoAttr
_ireeAttributeIsAGPULoweringConfigAttr
_ireeAttributeIsAGPUMMAAttr
_ireeAttributeIsAGPUMMAIntrinsicAttr
Expand All @@ -8,6 +9,9 @@ _ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
_ireeCodegenDispatchLoweringPassPipelineAttrGet
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
_ireeCodegenTranslationInfoAttrGet
_ireeCodegenTranslationInfoAttrGetParameters
_ireeCodegenTranslationInfoAttrGetTypeID
_ireeCompilerEnumeratePlugins
_ireeCompilerEnumerateRegisteredHALTargetBackends
_ireeCompilerErrorDestroy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def IREECodegen_TranslationInfoAttr :
}];

let assemblyFormat = [{
`<` `pipeline` `=` $passPipeline
`<` `pipeline` `=` `` $passPipeline
(`codegen_spec` `=` $codegenSpec^)?
(`workgroup_size` `=` `[` $workgroupSize^ `]`)?
(`subgroup_size` `=` $subgroupSize^)?
Expand Down

0 comments on commit bc23e59

Please sign in to comment.