diff --git a/examples/brutus.jl b/examples/brutus.jl index 9752f07e..4a7b0090 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -276,7 +276,7 @@ using MLIR.IR, MLIR fptr = IR.context!(IR.Context()) do op = Brutus.code_mlir(pow, Tuple{Int,Int}) - mod = MModule(Location()) + mod = IR.Module(Location()) body = IR.get_body(mod) push!(body, op) diff --git a/src/IR/IR.jl b/src/IR/IR.jl index a3258ba4..bfd83aeb 100644 --- a/src/IR/IR.jl +++ b/src/IR/IR.jl @@ -3,7 +3,6 @@ export OperationState, Location, Context, - MModule, Value, MLIRType, Region, @@ -749,33 +748,34 @@ Base.unsafe_convert(::Type{MlirRegion}, region::Region) = region.region ### Module -mutable struct MModule +mutable struct Module module_::MlirModule - MModule(module_) = begin - @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" + Module(module_) = begin + @assert !mlirIsNull(module_) "cannot create Module with null MlirModule" finalizer(API.mlirModuleDestroy, new(module_)) end end -MModule(loc::Location=Location()) = - MModule(API.mlirModuleCreateEmpty(loc)) +Module(loc::Location=Location()) = + Module(API.mlirModuleCreateEmpty(loc)) +Module(op::Operation) = Module(API.mlirModuleFromOperation(lose_ownership!(op))) get_operation(module_) = Operation(API.mlirModuleGetOperation(module_), false) get_body(module_) = Block(API.mlirModuleGetBody(module_), false) -get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) +get_first_child_op(mod::Module) = get_first_child_op(get_operation(mod)) -Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ -Base.parse(::Type{MModule}, module_) = MModule(API.mlirModuleCreateParse(context(), module_), context()) +Base.convert(::Type{MlirModule}, module_::Module) = module_.module_ +Base.parse(::Type{Module}, module_) = Module(API.mlirModuleCreateParse(context(), module_), context()) macro mlir_str(code) quote ctx = Context() - parse(MModule, code) + parse(Module, code) end end -function Base.show(io::IO, module_::MModule) - println(io, "MModule:") +function Base.show(io::IO, module_::Module) + println(io, "Module:") show(io, get_operation(module_)) end diff --git a/src/IR/Support.jl b/src/IR/Support.jl index f84689e3..6eb3dceb 100644 --- a/src/IR/Support.jl +++ b/src/IR/Support.jl @@ -129,5 +129,5 @@ function verifyall(operation::Operation; debug=false) end end end -verifyall(module_::MModule) = get_operation(module_) |> verifyall +verifyall(module_::IR.Module) = get_operation(module_) |> verifyall diff --git a/test/ir.jl b/test/ir.jl index 8d47760c..f6594638 100644 --- a/test/ir.jl +++ b/test/ir.jl @@ -1,4 +1,4 @@ -using MLIR.Dialects: arith +using MLIR.Dialects: arith, builtin using MLIR.IR, LLVM @testset "operation introspection" begin @@ -11,3 +11,20 @@ using MLIR.IR, LLVM @test IR.get_attribute_by_name(op, "value") |> IR.bool_value end end + +@testset "Module construction from operation" begin + IR.context!(IR.Context()) do + if LLVM.version() >= v"15" + op = builtin.module_(bodyRegion=IR.Region()) + else + op = builtin.module_(body=IR.Region()) + end + mod = IR.Module(op) + op = IR.get_operation(mod) + + @test IR.name(op) == "builtin.module" + + # Only a `module` operation can be used to create a module. + @test_throws AssertionError IR.Module(arith.constant(; value=true, result=MLIRType(Bool))) + end +end