diff --git a/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir index c5d86a60d7a8..3762ac9d4568 100644 --- a/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir @@ -47,7 +47,7 @@ stream.executable public @add_dispatch_0 { // CHECK-SAME: interface = @io, // CHECK-SAME: ordinal = 0 : index // CHECK-SAME: } -// CHECK: module { +// CHECK: module attributes {vm.toplevel} { // CHECK-NEXT: vm.module public @module { // CHECK-NEXT: vm.func private @add_dispatch_0( // CHECK-SAME: %[[SCRATCHPAD:.+]]: !vm.buffer, %[[CONSTANTS:.+]]: !vm.buffer, diff --git a/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp index 6ff48916b793..1828e9da601e 100644 --- a/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp +++ b/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp @@ -24,18 +24,25 @@ VMConversionTarget::nestModuleForConversion(mlir::ModuleOp outerModuleOp) { outerModuleOp.getBodyRegion().getBlocks().push_back(new Block()); outerModuleOp.push_back(innerModuleOp); } + + outerModuleOp->setAttr("vm.toplevel", + UnitAttr::get(outerModuleOp.getContext())); return std::make_pair(outerModuleOp, innerModuleOp); } +// static +bool VMConversionTarget::isTopLevelModule(mlir::ModuleOp moduleOp) { + return !moduleOp->getParentOp() || moduleOp->hasAttr("vm.toplevel"); +} + VMConversionTarget::VMConversionTarget(MLIRContext *context) : ConversionTarget(*context) { addLegalDialect(); // NOTE: we need to allow the outermost std.module to be legal to support the // double-nesting (module { vm.module { ... } }). - addDynamicallyLegalOp(+[](mlir::ModuleOp op) { - return !op->getParentOp() || !isa(op->getParentOp()); - }); + addDynamicallyLegalOp( + +[](mlir::ModuleOp op) { return isTopLevelModule(op); }); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/VM/Conversion/ConversionTarget.h b/iree/compiler/Dialect/VM/Conversion/ConversionTarget.h index 57884559fdde..0b6782f02f64 100644 --- a/iree/compiler/Dialect/VM/Conversion/ConversionTarget.h +++ b/iree/compiler/Dialect/VM/Conversion/ConversionTarget.h @@ -27,10 +27,14 @@ class VMConversionTarget : public ConversionTarget { // Example: // module { func @foo() { ... } } // -> - // module { module { func @foo() { ... } } } + // module attributes {vm.toplevel} { module { func @foo() { ... } } } static std::pair nestModuleForConversion( mlir::ModuleOp outerModuleOp); + // Returns whether this is the outer module as setup via + // nestModuleForConversion. Use for patterns which need to distinguish. + static bool isTopLevelModule(mlir::ModuleOp moduleOp); + VMConversionTarget(MLIRContext *context); }; diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp index 78e1a3599589..a1d28f019715 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h" #include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h" #include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" @@ -32,7 +33,7 @@ class ModuleOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // Do not attempt to convert the top level module. // This mechanism can only support rewriting non top-level modules. - if (!srcOp->getParentOp() || !isa(srcOp->getParentOp())) { + if (VMConversionTarget::isTopLevelModule(srcOp)) { return failure(); } diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD index 6fcb0b66641c..3f5f15d8b427 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD @@ -24,6 +24,7 @@ iree_lit_test_suite( "control_flow_ops.mlir", "conversion_ops.mlir", "func_attrs.mlir", + "nesting.mlir", "structural_ops.mlir", ], include = ["*.mlir"], diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt index 2aa352e6b95c..a09bece21892 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt @@ -21,6 +21,7 @@ iree_lit_test_suite( "control_flow_ops.mlir" "conversion_ops.mlir" "func_attrs.mlir" + "nesting.mlir" "structural_ops.mlir" DATA iree::tools::IreeFileCheck diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/nesting.mlir b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/nesting.mlir new file mode 100644 index 000000000000..3e07cf63706c --- /dev/null +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/nesting.mlir @@ -0,0 +1,21 @@ +// RUN: iree-opt -split-input-file -pass-pipeline='builtin.module(test-iree-convert-std-to-vm)' %s | IreeFileCheck %s + +// Note that checks are ambiguous between "module" and "vm.module" so we rely +// on vm.module printing as `vm.module public @foo` + +// CHECK-LABEL: module @outerBuiltinModule +module @outerBuiltinModule { + // CHECK-NEXT: module @innerBuiltinModule attributes {vm.toplevel} + module @innerBuiltinModule attributes {vm.toplevel} { + // CHECK-NEXT: vm.module public @outerVmModule + module @outerVmModule { + // CHECK-NEXT: vm.module public @deeplyNested + module @deeplyNested { + // CHECK: vm.func private @foo + func @foo() { + return + } + } + } + } +} diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/structural_ops.mlir b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/structural_ops.mlir index a062c55d9ba3..40e84f1900ed 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/structural_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/structural_ops.mlir @@ -40,22 +40,3 @@ module { } } - -// ----- - -// CHECK: module -module { - // CHECK: module - module { - // CHECK: module - module { - // CHECK-LABEL: vm.module public @deeplyNested - module @deeplyNested { - // CHECK: vm.func private @foo - func @foo() { - return - } - } - } - } -}