Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/TPP/Dialect/Tune/TuneTransformOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TPP_DIALECT_TUNE_TUNETRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

Expand All @@ -11,6 +12,11 @@
namespace mlir {
namespace tune {
void registerTransformDialectExtension(DialectRegistry &registry);

using Handler = std::function<SmallVector<SmallVector<transform::MappedValue>>(
StringRef, SmallVector<SmallVector<transform::MappedValue>>)>;

extern Handler handler;
} // namespace tune
} // namespace mlir

Expand Down
26 changes: 26 additions & 0 deletions include/TPP/Dialect/Tune/TuneTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,30 @@ def TuneSelectOp : Op<Transform_Dialect, "tune.select", [
"$name `from` $options attr-dict `:` type(results)";
}

def TuneCallbackOp : Op<Transform_Dialect, "tune.callback", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Callback to an external handler";
let description = [{
A transform op which will invoke an external handler routine - not further
specified - which is supposed to delegate to the right callback based on
the provided `name`.

Takes any number of arguments of any handle type and produces any number of
results of any handle type. For example,

```mlir
%res:3 = transform.tune.callback @ffi_transform(%op, %param, %value) :
(!transform.any_op, !transform.any_param, !transform.any_value) ->
(!transform.any_param, !transform.any_value, !transform.any_op)
```
}];

let arguments = (ins SymbolRefAttr:$name, Variadic<Transform_AnyHandleOrParamType>:$payloads);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let assemblyFormat =
"$name `(` $payloads `)` attr-dict `:` functional-type(operands, results)";
}

#endif // TUNE_TRANSFORM_OPS
37 changes: 37 additions & 0 deletions lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ using namespace mlir;
#define GET_OP_CLASSES
#include "TPP/Dialect/Tune/TuneTransformOps.cpp.inc"

namespace mlir {
namespace tune {
Handler handler = nullptr;
} // namespace tune
} // namespace mlir

//===----------------------------------------------------------------------===//
// TuneSelectOp
//===----------------------------------------------------------------------===//
Expand All @@ -24,6 +30,37 @@ transform::TuneSelectOp::apply(transform::TransformRewriter &rewriter,
<< "this op does not have interpreted semantics!";
}

//===----------------------------------------------------------------------===//
// TuneCallbackOp
//===----------------------------------------------------------------------===//

void transform::TuneCallbackOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getPayloadsMutable(), // TODO: Make specifiable on the op.
effects);
producesHandle(getOperation()->getOpResults(), effects);
}

DiagnosedSilenceableFailure
transform::TuneCallbackOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (tune::handler == nullptr)
return emitDefiniteFailure()
<< "callback called without a registered callback handler";

SmallVector<SmallVector<MappedValue>> payloads;
detail::prepareValueMappings(payloads, getPayloads(), state);

SmallVector<SmallVector<MappedValue>> res =
tune::handler(getName().getRootReference().getValue(), payloads);

for (auto &&[result, resPayload] : zip_equal(getResults(), res))
results.setMappedValues(llvm::cast<OpResult>(result), resPayload);

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down
84 changes: 84 additions & 0 deletions python/TPPDialects.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/DialectRegistry.h"

#include "mlir/Dialect/Transform/IR/TransformTypes.h"

#include "TPP/Dialect/Check/CheckDialect.h"
#include "TPP/Dialect/Perf/PerfDialect.h"
#include "TPP/Dialect/Tune/TuneDialect.h"
Expand All @@ -12,6 +15,11 @@
#include "TPP/Passes.h"

namespace nb = nanobind;
using namespace mlir;

// Global to hold the python callback handler so that it is avialable to be
// called by the C++-callback handler.
nb::object callback_handler;

NB_MODULE(_tppDialects, m) {
auto checkModule = m.def_submodule("check");
Expand Down Expand Up @@ -66,6 +74,82 @@ NB_MODULE(_tppDialects, m) {
},
"registry");

transformTuneModule.def(
"register_callback_handler", [&](nb::callable callable) {
callback_handler =
nb::steal(callable); // TODO: should we ever release this?

// Register a C++ callback that will
// 1) wrap its arguments,
// 2) call a Python callback with the wrapped-up arguments,
// 3) and unwrap the results that the Python callback returned.
tune::handler =
[&](mlir::StringRef name,
SmallVector<SmallVector<transform::MappedValue>> args)
-> SmallVector<SmallVector<transform::MappedValue>> {
// Wrap up the arguments to prepare for passing them to Python.
nb::list pyArgs;
for (auto handleAssociatedValues : args) {
nb::list pyAssociatedValues;

for (auto associatedValue : handleAssociatedValues) {
if (auto *op = dyn_cast<Operation *>(associatedValue)) {
// std::cout << "CALLBACK: pushing op\n";
pyAssociatedValues.append(wrap(op));
} else if (auto value = dyn_cast<Value>(associatedValue)) {
pyAssociatedValues.append(wrap(value));
} else if (auto paramAttr =
dyn_cast<transform::Param>(associatedValue)) {
pyAssociatedValues.append(wrap(paramAttr));
}
}

pyArgs.append(pyAssociatedValues);
}

// The callback to Python code.
auto res = callback_handler(nb::str(name.data()), *pyArgs);

// Needing to do this import here is ... not ideal.
// The below commented-out code is potentially a better solution...
nb::handle mlir_ir = nb::module_::import_("mlir.ir");
nb::handle Operation = mlir_ir.attr("Operation");
nb::handle Value = mlir_ir.attr("Value");
nb::handle Attribute = mlir_ir.attr("Attribute");

// Unwrap the results to prepare for passing them to C++.
SmallVector<SmallVector<transform::MappedValue>> results;
if (nb::isinstance<nb::list>(res) || nb::isinstance<nb::tuple>(res)) {
for (auto assocList : res) {
SmallVector<transform::MappedValue> associatedValues;
for (auto elt : assocList) {
// The following is probably preferable but is broken...
// if (nb::isinstance<MlirValue>(elt)) {
// If `elt` is of the wrong type, isinstance call will crash.
if (nb::isinstance(elt, Value)) {
auto val = nb::cast<MlirValue>(elt);
associatedValues.push_back(unwrap(val));
// The following is probably preferable but is broken...
//} else if (nb::isinstance<MlirOperation>(elt)) {
// If `elt` is of the wrong type, isinstance call will crash.
} else if (nb::isinstance(elt, Operation)) {
auto op = nb::cast<MlirOperation>(elt);
associatedValues.push_back(unwrap(op));
// The following is probably preferable but is broken...
//} else if (nb::isinstance<MlirAttribute>(elt)) {
// If `elt` is of the wrong type, isinstance call will crash.
} else if (nb::isinstance(elt, Attribute)) {
auto param = nb::cast<MlirAttribute>(elt);
associatedValues.push_back(unwrap(param));
}
}
results.push_back(associatedValues);
}
}
return results;
};
});

mlir::tpp::registerTppCompilerPasses();
mlir::tpp::registerTppPassBundlePasses();
}
37 changes: 34 additions & 3 deletions python/mlir/dialects/transform/tune.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
from ..._mlir_libs import get_dialect_registry
from ..._mlir_libs._tppDialects.transform.tune import register_dialect_extension
from ..._mlir_libs._tppDialects.transform.tune import (
register_dialect_extension,
register_callback_handler,
)
from ..._mlir_libs._tppDialects.transform import tune

tune._callback = None
from ..._mlir_libs import _tppDialects

_tppDialects._callback = lambda: print("callbacked")

register_dialect_extension(get_dialect_registry())

from ...ir import ArrayAttr, SymbolRefAttr, Attribute, Type
from .._tune_transform_ops_gen import TuneSelectOp
from ...ir import ArrayAttr, SymbolRefAttr, Attribute, Type, Operation, Value
from ...dialects import transform
from .._tune_transform_ops_gen import *

from collections.abc import Sequence
from typing import Union


def callback(
results: Type,
name: Union[str, Attribute],
*payloads: Union[
transform.AnyOpType, transform.AnyParamType, transform.AnyValueType
],
loc=None,
ip=None
):
if isinstance(name, str):
name = SymbolRefAttr.get([name])

return TuneCallbackOp(
results_=results,
name=name,
payloads=payloads,
loc=loc,
ip=ip,
)


def select(
selected: Type, # transform.any_param or transform.param<...>
name: Union[str, Attribute],
Expand Down
52 changes: 51 additions & 1 deletion python/mlir/tpp/sched/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable, Union, Dict

from mlir.dialects import transform
from mlir.dialects.transform import structured
from mlir.dialects.transform import structured, tune


# Wrapper to addresss verbosity.
Expand All @@ -10,3 +12,51 @@ def apply_registered_pass(*args, **kwargs):
# Wrapper to addresss verbosity.
def match(*args, **kwargs):
return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs)


# Global mapping callback names to Python-function callback functions.
HANDLER_MAPPING: Dict[str, Callable] = {}


# The python function that actually gets called from C++ to deal with
# transform.tune.callback callbacks.
def callback_handler(name, *args):
if (handler := HANDLER_MAPPING.get(name)) is None:
raise RuntimeError(f"callback '{name}' requested but was not registered")
return handler(*args)


tune.register_callback_handler(callback_handler)


# Decorator to register named Python callback functions. Return types need to be
# provided as part of the signature.
def callback(function: Callable):
if function.__name__ in HANDLER_MAPPING:
raise RuntimeError("tried to register a callback with the same name twice")
HANDLER_MAPPING[function.__name__] = function
results_type = function.__annotations__.get("return", ())

def wrapper(
*args: Union[
transform.AnyOpType, transform.AnyValueType, transform.AnyParamType
]
):
return transform.tune.callback(results_type, function.__name__, *args)

return wrapper


# Decorator to register named Python callback function and immediately call it.
# Return types need to be provided as part of the signature.
def call_with(
*args: Union[transform.AnyOpType, transform.AnyValueType, transform.AnyParamType]
):
def decorator(function: Callable):
if function.__name__ in HANDLER_MAPPING:
raise RuntimeError("tried to register a callback with the same name twice")
HANDLER_MAPPING[function.__name__] = function
results_type = function.__annotations__.get("return", ())
return transform.tune.callback(results_type, function.__name__, *args)

return decorator
Loading
Loading