Skip to content

Commit bc23e59

Browse files
authored
[python][tuner] Add bindings for iree_codegen.translation_info (iree-org#19128)
1 parent ea03080 commit bc23e59

File tree

9 files changed

+233
-1
lines changed

9 files changed

+233
-1
lines changed

compiler/bindings/c/iree/compiler/dialects/iree_codegen.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ MLIR_CAPI_EXPORTED
3030
uint32_t
3131
ireeCodegenDispatchLoweringPassPipelineAttrGetValue(MlirAttribute attr);
3232

33+
MLIR_CAPI_EXPORTED bool
34+
ireeAttributeIsACodegenTranslationInfoAttr(MlirAttribute attr);
35+
36+
MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenTranslationInfoAttrGetTypeID(void);
37+
38+
struct ireeCodegenTranslationInfoParameters {
39+
MlirAttribute passPipeline; // DispatchLoweringPassPipelineAttr.
40+
MlirAttribute codegenSpec; // Optional SymbolRefAttr.
41+
const int64_t *workgroupSize; // Optional ArrayRef<int64_t>.
42+
size_t numWorkgroupSizeElements; // Size of the ArrayRef above.
43+
int64_t subgroupSize; // Optional int64_t.
44+
MlirAttribute configuration; // Optional DictionaryAttr.
45+
};
46+
47+
MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenTranslationInfoAttrGet(
48+
MlirContext mlirCtx, ireeCodegenTranslationInfoParameters parameters);
49+
50+
MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters
51+
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr);
52+
3353
#ifdef __cplusplus
3454
}
3555
#endif

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
#include <cstdint>
8+
#include <optional>
9+
#include <vector>
810
#include "iree/compiler/dialects/iree_codegen.h"
911
#include "iree/compiler/dialects/iree_gpu.h"
1012
#include "mlir-c/BuiltinAttributes.h"
@@ -50,6 +52,80 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
5052
.attr("DispatchLoweringPassPipeline")(rawValue);
5153
});
5254

55+
//===-------------------------------------------------------------------===//
56+
// CodegenTranslationInfoAttr
57+
//===-------------------------------------------------------------------===//
58+
59+
mlir_attribute_subclass(iree_codegen_module, "TranslationInfoAttr",
60+
ireeAttributeIsACodegenTranslationInfoAttr,
61+
ireeCodegenTranslationInfoAttrGetTypeID)
62+
.def_classmethod(
63+
"get",
64+
[](const py::object &, MlirAttribute passPipeline,
65+
std::optional<MlirAttribute> codegenSpec,
66+
std::optional<std::vector<int64_t>> workgroupSize,
67+
std::optional<int64_t> subgroupSize,
68+
std::optional<MlirAttribute> configuration, MlirContext ctx) {
69+
ireeCodegenTranslationInfoParameters parameters = {};
70+
parameters.passPipeline = passPipeline;
71+
parameters.codegenSpec =
72+
codegenSpec.value_or(mlirAttributeGetNull());
73+
if (workgroupSize.has_value()) {
74+
parameters.workgroupSize = workgroupSize->data();
75+
parameters.numWorkgroupSizeElements = workgroupSize->size();
76+
}
77+
parameters.subgroupSize = subgroupSize.value_or(0);
78+
parameters.configuration =
79+
configuration.value_or(mlirAttributeGetNull());
80+
81+
return ireeCodegenTranslationInfoAttrGet(ctx, parameters);
82+
},
83+
"cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(),
84+
"workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(),
85+
"configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(),
86+
"Gets an #iree_codegen.translation_info from "
87+
"parameters.")
88+
.def_property_readonly(
89+
"pass_pipeline",
90+
[](MlirAttribute self) -> MlirAttribute {
91+
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
92+
return parameters.passPipeline;
93+
})
94+
.def_property_readonly(
95+
"codegen_spec",
96+
[](MlirAttribute self) -> std::optional<MlirAttribute> {
97+
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
98+
if (mlirAttributeIsNull(parameters.codegenSpec)) {
99+
return std::nullopt;
100+
}
101+
return parameters.codegenSpec;
102+
})
103+
.def_property_readonly(
104+
"workgroup_size",
105+
[](MlirAttribute self) -> std::vector<int64_t> {
106+
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
107+
return {parameters.workgroupSize,
108+
parameters.workgroupSize +
109+
parameters.numWorkgroupSizeElements};
110+
})
111+
.def_property_readonly(
112+
"subgroup_size",
113+
[](MlirAttribute self) -> int64_t {
114+
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
115+
return parameters.subgroupSize;
116+
})
117+
.def_property_readonly(
118+
"configuration",
119+
[](MlirAttribute self) -> std::optional<MlirAttribute> {
120+
auto parameters = ireeCodegenTranslationInfoAttrGetParameters(self);
121+
if (mlirAttributeIsNull(parameters.configuration)) {
122+
return std::nullopt;
123+
}
124+
return parameters.configuration;
125+
});
126+
127+
//===--------------------------------------------------------------------===//
128+
53129
auto iree_gpu_module =
54130
m.def_submodule("iree_gpu", "iree_gpu dialect bindings");
55131

compiler/bindings/python/test/ir/dialects_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,55 @@ def codegen_dispatch_lowering_pass_pipeline():
4040
assert "LLVMGPUTileAndFuse" in str(pipeline_attr)
4141

4242

43+
@run
44+
def codegen_translation_info_minimal():
45+
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
46+
iree_codegen.DispatchLoweringPassPipeline.None_
47+
)
48+
translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr)
49+
assert translation_info is not None
50+
assert str(translation_info) == "#iree_codegen.translation_info<pipeline = None>"
51+
assert translation_info.pass_pipeline == pipeline_attr
52+
assert translation_info.codegen_spec is None
53+
assert translation_info.workgroup_size == []
54+
assert translation_info.subgroup_size == 0
55+
assert translation_info.configuration is None
56+
57+
58+
@run
59+
def codegen_translation_info_with_sizes():
60+
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
61+
iree_codegen.DispatchLoweringPassPipeline.Custom
62+
)
63+
translation_info = iree_codegen.TranslationInfoAttr.get(
64+
pipeline_attr, None, [64, 4, 1], 32
65+
)
66+
assert translation_info is not None
67+
assert translation_info.pass_pipeline == pipeline_attr
68+
assert translation_info.codegen_spec is None
69+
assert translation_info.workgroup_size == [64, 4, 1]
70+
assert translation_info.subgroup_size == 32
71+
assert translation_info.configuration is None
72+
73+
74+
@run
75+
def codegen_translation_info_full():
76+
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
77+
iree_codegen.DispatchLoweringPassPipeline.TransformDialectCodegen
78+
)
79+
foo_symbol = ir.SymbolRefAttr.get(["foo"])
80+
configuration = ir.DictAttr.get({"A": ir.IntegerAttr.get(ir.IndexType.get(), 42)})
81+
translation_info = iree_codegen.TranslationInfoAttr.get(
82+
pipeline_attr, foo_symbol, [128], 32, configuration
83+
)
84+
assert translation_info is not None
85+
assert translation_info.pass_pipeline == pipeline_attr
86+
assert translation_info.codegen_spec == foo_symbol
87+
assert translation_info.workgroup_size == [128]
88+
assert translation_info.subgroup_size == 32
89+
assert translation_info.configuration == configuration
90+
91+
4392
# ======================================================================
4493
# IREE GPU Dialect
4594
# ======================================================================

compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,21 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include <cassert>
78
#include <cstdint>
9+
#include <optional>
810
#include <type_traits>
911
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1012
#include "iree/compiler/dialects/iree_codegen.h"
13+
#include "mlir-c/BuiltinAttributes.h"
1114
#include "mlir-c/IR.h"
1215
#include "mlir/CAPI/IR.h"
1316
#include "mlir/CAPI/Support.h"
17+
#include "mlir/IR/BuiltinAttributes.h"
1418

1519
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
1620
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
21+
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
1722

1823
bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
1924
MlirAttribute attr) {
@@ -42,3 +47,65 @@ ireeCodegenDispatchLoweringPassPipelineAttrGetValue(MlirAttribute attr) {
4247
return static_cast<uint32_t>(
4348
llvm::cast<DispatchLoweringPassPipelineAttr>(unwrap(attr)).getValue());
4449
}
50+
51+
bool ireeAttributeIsACodegenTranslationInfoAttr(MlirAttribute attr) {
52+
return llvm::isa<TranslationInfoAttr>(unwrap(attr));
53+
}
54+
55+
MlirTypeID ireeCodegenTranslationInfoAttrGetTypeID() {
56+
return wrap(TranslationInfoAttr::getTypeID());
57+
}
58+
59+
MlirAttribute ireeCodegenTranslationInfoAttrGet(
60+
MlirContext mlirCtx, ireeCodegenTranslationInfoParameters parameters) {
61+
assert(!mlirAttributeIsNull(parameters.passPipeline) &&
62+
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
63+
parameters.passPipeline) &&
64+
"Invalid pass pipeline attr");
65+
66+
assert((mlirAttributeIsNull(parameters.codegenSpec) ||
67+
mlirAttributeIsASymbolRef(parameters.codegenSpec)) &&
68+
"Invalid codegen spec attr");
69+
70+
assert((mlirAttributeIsNull(parameters.configuration) ||
71+
mlirAttributeIsADictionary(parameters.configuration)) &&
72+
"Invalid configuration attr");
73+
74+
DispatchLoweringPassPipeline passPipeline =
75+
llvm::cast<DispatchLoweringPassPipelineAttr>(
76+
unwrap(parameters.passPipeline))
77+
.getValue();
78+
auto codegenSpec = llvm::cast_if_present<mlir::SymbolRefAttr>(
79+
unwrap(parameters.codegenSpec));
80+
81+
llvm::ArrayRef<int64_t> workgroupSize;
82+
if (parameters.workgroupSize) {
83+
workgroupSize = {parameters.workgroupSize,
84+
parameters.numWorkgroupSizeElements};
85+
}
86+
87+
std::optional<int64_t> subgroupSize = parameters.subgroupSize;
88+
auto configuration = llvm::cast_if_present<mlir::DictionaryAttr>(
89+
unwrap(parameters.configuration));
90+
91+
mlir::MLIRContext *ctx = unwrap(mlirCtx);
92+
return wrap(TranslationInfoAttr::get(ctx, passPipeline, codegenSpec,
93+
workgroupSize, subgroupSize,
94+
configuration));
95+
}
96+
97+
ireeCodegenTranslationInfoParameters
98+
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr) {
99+
auto translationInfo = llvm::cast<TranslationInfoAttr>(unwrap(attr));
100+
101+
ireeCodegenTranslationInfoParameters parameters = {};
102+
parameters.passPipeline = wrap(translationInfo.getPassPipeline());
103+
parameters.codegenSpec = wrap(translationInfo.getCodegenSpec());
104+
llvm::ArrayRef<int64_t> workgroupSize = translationInfo.getWorkgroupSize();
105+
parameters.workgroupSize = workgroupSize.data();
106+
parameters.numWorkgroupSizeElements = workgroupSize.size();
107+
parameters.subgroupSize = translationInfo.getSubgroupSize();
108+
parameters.configuration = wrap(translationInfo.getConfiguration());
109+
110+
return parameters;
111+
}

compiler/src/iree/compiler/API/api_exports.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <stdint.h>
1212

1313
extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr();
14+
extern void ireeAttributeIsACodegenTranslationInfoAttr();
1415
extern void ireeAttributeIsAGPULoweringConfigAttr();
1516
extern void ireeAttributeIsAGPUMMAAttr();
1617
extern void ireeAttributeIsAGPUMMAIntrinsicAttr();
@@ -19,6 +20,9 @@ extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr();
1920
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
2021
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
2122
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
23+
extern void ireeCodegenTranslationInfoAttrGet();
24+
extern void ireeCodegenTranslationInfoAttrGetParameters();
25+
extern void ireeCodegenTranslationInfoAttrGetTypeID();
2226
extern void ireeCompilerEnumeratePlugins();
2327
extern void ireeCompilerEnumerateRegisteredHALTargetBackends();
2428
extern void ireeCompilerErrorDestroy();
@@ -865,6 +869,7 @@ extern void mlirVectorTypeIsScalable();
865869
uintptr_t __iree_compiler_hidden_force_extern() {
866870
uintptr_t x = 0;
867871
x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
872+
x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr;
868873
x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr;
869874
x += (uintptr_t)&ireeAttributeIsAGPUMMAAttr;
870875
x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr;
@@ -873,6 +878,9 @@ uintptr_t __iree_compiler_hidden_force_extern() {
873878
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet;
874879
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
875880
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
881+
x += (uintptr_t)&ireeCodegenTranslationInfoAttrGet;
882+
x += (uintptr_t)&ireeCodegenTranslationInfoAttrGetParameters;
883+
x += (uintptr_t)&ireeCodegenTranslationInfoAttrGetTypeID;
876884
x += (uintptr_t)&ireeCompilerEnumeratePlugins;
877885
x += (uintptr_t)&ireeCompilerEnumerateRegisteredHALTargetBackends;
878886
x += (uintptr_t)&ireeCompilerErrorDestroy;

compiler/src/iree/compiler/API/api_exports.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
; Generated by generate_exports.py: Do not edit.
22
EXPORTS
33
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
4+
ireeAttributeIsACodegenTranslationInfoAttr
45
ireeAttributeIsAGPULoweringConfigAttr
56
ireeAttributeIsAGPUMMAAttr
67
ireeAttributeIsAGPUMMAIntrinsicAttr
@@ -9,6 +10,9 @@ EXPORTS
910
ireeCodegenDispatchLoweringPassPipelineAttrGet
1011
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1112
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
13+
ireeCodegenTranslationInfoAttrGet
14+
ireeCodegenTranslationInfoAttrGetParameters
15+
ireeCodegenTranslationInfoAttrGetTypeID
1216
ireeCompilerEnumeratePlugins
1317
ireeCompilerEnumerateRegisteredHALTargetBackends
1418
ireeCompilerErrorDestroy

compiler/src/iree/compiler/API/api_exports.ld

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
VER_0 {
33
global:
44
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
5+
ireeAttributeIsACodegenTranslationInfoAttr;
56
ireeAttributeIsAGPULoweringConfigAttr;
67
ireeAttributeIsAGPUMMAAttr;
78
ireeAttributeIsAGPUMMAIntrinsicAttr;
@@ -10,6 +11,9 @@ VER_0 {
1011
ireeCodegenDispatchLoweringPassPipelineAttrGet;
1112
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
1213
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
14+
ireeCodegenTranslationInfoAttrGet;
15+
ireeCodegenTranslationInfoAttrGetParameters;
16+
ireeCodegenTranslationInfoAttrGetTypeID;
1317
ireeCompilerEnumeratePlugins;
1418
ireeCompilerEnumerateRegisteredHALTargetBackends;
1519
ireeCompilerErrorDestroy;

compiler/src/iree/compiler/API/api_exports.macos.lst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by generate_exports.py: Do not edit.
22
_ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
3+
_ireeAttributeIsACodegenTranslationInfoAttr
34
_ireeAttributeIsAGPULoweringConfigAttr
45
_ireeAttributeIsAGPUMMAAttr
56
_ireeAttributeIsAGPUMMAIntrinsicAttr
@@ -8,6 +9,9 @@ _ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
89
_ireeCodegenDispatchLoweringPassPipelineAttrGet
910
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1011
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
12+
_ireeCodegenTranslationInfoAttrGet
13+
_ireeCodegenTranslationInfoAttrGetParameters
14+
_ireeCodegenTranslationInfoAttrGetTypeID
1115
_ireeCompilerEnumeratePlugins
1216
_ireeCompilerEnumerateRegisteredHALTargetBackends
1317
_ireeCompilerErrorDestroy

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def IREECodegen_TranslationInfoAttr :
226226
}];
227227

228228
let assemblyFormat = [{
229-
`<` `pipeline` `=` $passPipeline
229+
`<` `pipeline` `=` `` $passPipeline
230230
(`codegen_spec` `=` $codegenSpec^)?
231231
(`workgroup_size` `=` `[` $workgroupSize^ `]`)?
232232
(`subgroup_size` `=` $subgroupSize^)?

0 commit comments

Comments
 (0)