Skip to content

Commit ef57e38

Browse files
committed
check verification
1 parent 2decc8b commit ef57e38

File tree

5 files changed

+95
-12
lines changed

5 files changed

+95
-12
lines changed

lib/Dialect/HEIRInterfaces.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ void registerOperandAndResultAttrInterface(DialectRegistry &registry) {
1919

2020
void registerLayoutConversionHoistableInterface(DialectRegistry &registry) {
2121
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
22-
arith::AddIOp::attachInterface<DoNothingHoistingImpl<arith::AddIOp>>(*ctx);
2322
arith::AddFOp::attachInterface<DoNothingHoistingImpl<arith::AddFOp>>(*ctx);
23+
arith::AddIOp::attachInterface<DoNothingHoistingImpl<arith::AddIOp>>(*ctx);
24+
arith::MulFOp::attachInterface<DoNothingHoistingImpl<arith::MulFOp>>(*ctx);
25+
arith::MulIOp::attachInterface<DoNothingHoistingImpl<arith::MulIOp>>(*ctx);
26+
arith::SubFOp::attachInterface<DoNothingHoistingImpl<arith::SubFOp>>(*ctx);
27+
arith::SubIOp::attachInterface<DoNothingHoistingImpl<arith::SubIOp>>(*ctx);
2428
});
2529
}
2630

lib/Dialect/HEIRInterfaces.td

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,21 @@ def LayoutConversionHoistableOpInterface : OpInterface<"LayoutConversionHoistabl
3131
>,
3232
];
3333

34-
let verify = [{
35-
auto attrName = ::mlir::heir::secret::SecretDialect::kKernelAttrName;
36-
if ($_op->hasAttr(attrName)) {
37-
auto attr = $_op->getAttrOfType<::mlir::heir::secret::KernelAttr>(attrName);
38-
return ::mlir::heir::isSupportedKernel($_op, attr.getName()) ? success() : failure();
39-
}
40-
return success();
41-
}];
34+
// TODO(#1888): figure out how to get OpInterface verifier to run
35+
// automatically.
36+
// let verify = [{
37+
// ::mlir::heir::KernelName kernelName = ::mlir::heir::KernelName::Trivial;
38+
// auto attrName = ::mlir::heir::secret::SecretDialect::kKernelAttrName;
39+
// if ($_op->hasAttr(attrName)) {
40+
// auto attr = $_op->getAttrOfType<::mlir::heir::secret::KernelAttr>(attrName);
41+
// kernelName = attr.getName();
42+
// }
43+
// if (!::mlir::heir::isSupportedKernel($_op, kernelName)) {
44+
// return $_op->emitOpError()
45+
// << "has unsupported kernel '" << kernelName << "'";
46+
// }
47+
// return success();
48+
// }];
4249
}
4350

4451
def OperandAndResultAttrInterface : OpInterface<"OperandAndResultAttrInterface"> {

lib/Kernel/Kernel.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#include "Kernel.h"
22

3+
#include <set>
34
#include <string>
45
#include <unordered_map>
56

6-
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
7-
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
8-
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
7+
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
8+
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
910

1011
namespace mlir {
1112
namespace heir {
@@ -15,9 +16,20 @@ static std::unordered_map<KernelName, std::string> correspondingOp = {
1516
{KernelName::MatvecNaive, "linalg.matvec"},
1617
{KernelName::MatvecDiagonal, "linalg.matvec"},
1718
};
19+
20+
std::set<std::string> requiredNontrivial = {"linalg"};
1821
} // namespace
1922

2023
bool isSupportedKernel(Operation *op, KernelName name) {
24+
std::string dialect = std::string(op->getDialect()->getNamespace());
25+
if (name == KernelName::Trivial) {
26+
return requiredNontrivial.count(dialect) == 0;
27+
}
28+
29+
if (correspondingOp.find(name) == correspondingOp.end()) {
30+
return false;
31+
}
32+
2133
std::string actual;
2234
llvm::raw_string_ostream ss(actual);
2335
ss << op->getDialect()->getNamespace() << "." << op->getName().getStringRef();

lib/Transforms/LayoutOptimization/LayoutOptimization.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,26 @@ void LayoutOptimization::runOnOperation() {
9393
WalkResult result =
9494
getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
9595
[&](Operation *op) {
96+
if (auto hoistable =
97+
dyn_cast<LayoutConversionHoistableOpInterface>(op)) {
98+
// TODO(#1888): figure out how to get OpInterface verifier to run
99+
// automatically.
100+
KernelName kernelName = KernelName::Trivial;
101+
auto attrName =
102+
::mlir::heir::secret::SecretDialect::kKernelAttrName;
103+
if (op->hasAttr(attrName)) {
104+
kernelName =
105+
op->getAttrOfType<::mlir::heir::secret::KernelAttr>(
106+
attrName)
107+
.getName();
108+
}
109+
110+
if (!::mlir::heir::isSupportedKernel(op, kernelName)) {
111+
op->emitOpError() << "has unsupported kernel\n";
112+
return WalkResult::interrupt();
113+
}
114+
}
115+
96116
// Attempt to hoist layout conversions before this operation.
97117
OpHoistResult result = hoistOp(op, builder);
98118
if (result == FAILURE) {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: heir-opt --split-input-file --layout-optimization --verify-diagnostics %s
2+
3+
// Valid
4+
func.func @main(%arg0: tensor<512xf32>) -> tensor<512xf32> {
5+
%0 = arith.addf %arg0, %arg0 {secret.kernel = #secret.kernel<name="Trivial", force=false>} : tensor<512xf32>
6+
func.return %0 : tensor<512xf32>
7+
}
8+
9+
// -----
10+
11+
// Bad kernel name
12+
func.func @main(%arg0: tensor<512xf32>) -> tensor<512xf32> {
13+
// expected-error@below {{has unsupported kernel}}
14+
%0 = arith.addf %arg0, %arg0 {secret.kernel = #secret.kernel<name="MatvecNaive", force=false>} : tensor<512xf32>
15+
func.return %0 : tensor<512xf32>
16+
}
17+
18+
// -----
19+
20+
// Good kernel name
21+
func.func @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512xf32>) -> tensor<512xf32> {
22+
%cst = tensor.empty() : tensor<512xf32>
23+
%0 = linalg.matvec
24+
{secret.kernel = #secret.kernel<name="MatvecDiagonal", force=false>}
25+
ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512xf32>)
26+
outs(%cst : tensor<512xf32>) -> tensor<512xf32>
27+
func.return %0 : tensor<512xf32>
28+
}
29+
// -----
30+
31+
// TODO(#1888): re-enable when matvec is hoistable
32+
// Missing required kernel
33+
// func.func @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512xf32>) -> tensor<512xf32> {
34+
// %cst = tensor.empty() : tensor<512xf32>
35+
// // expected-error@REENABLEME {{has unsupported kernel}}
36+
// %0 = linalg.matvec
37+
// ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512xf32>)
38+
// outs(%cst : tensor<512xf32>) -> tensor<512xf32>
39+
// func.return %0 : tensor<512xf32>
40+
// }

0 commit comments

Comments
 (0)