Skip to content

Commit

Permalink
Erase shape_assertion ops (#18167)
Browse files Browse the repository at this point in the history
Just dropping shape_assertion custom call ops.
  • Loading branch information
jpienaar authored Aug 9, 2024
1 parent 6f88125 commit df3d588
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,21 @@ struct HouseholderReflectorRewriter final
}
};

struct ShapeAssertionDrop final
: OpRewritePattern<mlir::stablehlo::CustomCallOp> {
using OpRewritePattern<mlir::stablehlo::CustomCallOp>::OpRewritePattern;
using OpAdaptor = mlir::stablehlo::CustomCallOp::Adaptor;

LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op,
PatternRewriter &rewriter) const final {
if (op.getCallTargetName() != "shape_assertion") {
return rewriter.notifyMatchFailure(op, "not shape_assertion");
}
rewriter.eraseOp(op);
return success();
}
};

//===----------------------------------------------------------------------===//
// Pass Definition.
//===----------------------------------------------------------------------===//
Expand All @@ -237,7 +252,7 @@ struct LegalizeStableHLOCustomCalls final
MLIRContext *ctx = f.getContext();

RewritePatternSet patterns(ctx);
patterns.add<HouseholderReflectorRewriter>(ctx);
patterns.add<HouseholderReflectorRewriter, ShapeAssertionDrop>(ctx);
if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) {
signalPassFailure();
}
Expand Down
4 changes: 4 additions & 0 deletions tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ ALL_SRCS = enforce_glob(
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"shape_assertion.mlir",
"sine.mlir",
"slice.mlir",
"sort.mlir",
Expand Down Expand Up @@ -169,6 +170,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"shape_assertion.mlir",
"sine.mlir",
"slice.mlir",
"sort.mlir",
Expand Down Expand Up @@ -247,6 +249,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"shape_assertion.mlir",
"sine.mlir",
"slice.mlir",
"sort.mlir",
Expand Down Expand Up @@ -381,6 +384,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"shape_assertion.mlir",
"sine.mlir",
"slice.mlir",
"sort.mlir",
Expand Down
10 changes: 10 additions & 0 deletions tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -142,6 +143,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -219,6 +221,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -291,6 +294,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -366,6 +370,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -450,6 +455,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -533,6 +539,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -607,6 +614,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -680,6 +688,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down Expand Up @@ -754,6 +763,7 @@ iree_check_single_backend_test_suite(
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"shape_assertion.mlir"
"sine.mlir"
"slice.mlir"
"sort.mlir"
Expand Down
9 changes: 9 additions & 0 deletions tests/e2e/stablehlo_ops/shape_assertion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
func.func @tensor() {
%0 = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
%1 = util.unfoldable_constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf32>
%4 = stablehlo.compare EQ, %0, %1, NOTYPE : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
stablehlo.custom_call @shape_assertion(%4) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<4xi1>) -> ()
%result = "stablehlo.add"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
check.expect_almost_eq_const(%result, dense<[6.0, 8.0, 10.0, 12.0]> : tensor<4xf32>) : tensor<4xf32>
return
}

0 comments on commit df3d588

Please sign in to comment.