From d239921eb2d372d89d6d4aa1e32ddfa11842d6d2 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 2 Oct 2023 14:44:15 +0200 Subject: [PATCH] add LLVM.jl-like context state handling. (#17) --- examples/brutus.jl | 63 +++++++------- src/Dialects.jl | 50 +++++------ src/IR/IR.jl | 207 +++++++++++++++++++++++---------------------- src/IR/Pass.jl | 108 ++++++++++++----------- src/IR/state.jl | 43 ++++++++++ src/MLIR.jl | 9 +- 6 files changed, 269 insertions(+), 211 deletions(-) create mode 100644 src/IR/state.jl diff --git a/examples/brutus.jl b/examples/brutus.jl index 36df2877..bba5835b 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -8,13 +8,13 @@ using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} function cmpi_pred(predicate) - function(ctx, ops; loc=Location(ctx)) - arith.cmpi(ctx, predicate, ops; loc) + function(ops; loc=Location()) + arith.cmpi(predicate, ops; loc) end end function single_op_wrapper(fop) - (ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc)) + (block::Block, args::Vector{Value}; loc=Location()) -> push!(block, fop(args; loc)) end const intrinsics_to_mlir = Dict([ @@ -24,15 +24,15 @@ const intrinsics_to_mlir = Dict([ Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), Base.mul_int => single_op_wrapper(arith.muli), Base.mul_float => single_op_wrapper(arith.mulf), - Base.not_int => function(ctx, block, args; loc=Location(ctx)) + Base.not_int => function(block, args; loc=Location()) arg = only(args) - ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result - push!(block, arith.xori(ctx, Value[arg, ones]; loc)) + ones = push!(block, arith.constant(-1, IR.get_type(arg); loc)) |> IR.get_result + push!(block, arith.xori(Value[arg, ones]; loc)) end, ]) "Generates a block argument for each phi node present in the block." -function prepare_block(ctx, ir, bb) +function prepare_block(ir, bb) b = Block() for sidx in bb.stmts @@ -41,7 +41,7 @@ function prepare_block(ctx, ir, bb) inst isa Core.PhiNode || continue type = stmt[:type] - IR.push_argument!(b, MLIRType(ctx, type), Location(ctx)) + IR.push_argument!(b, MLIRType(type), Location()) end return b @@ -68,7 +68,7 @@ function collect_value_arguments(ir, from, to) end """ - code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation + code_mlir(f, types::Type{Tuple}) -> IR.Operation Returns a `func.func` operation corresponding to the ircode of the provided method. This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. @@ -78,25 +78,26 @@ This only supports a few Julia Core primitives and scalar types of type $BrutusS handful of primitives. A better to perform this conversion would to create a dialect representing Julia IR and progressively lower it to base MLIR dialects. """ -function code_mlir(f, types; ctx=Context()) +function code_mlir(f, types) + ctx = context() ir, ret = Core.Compiler.code_ircode(f, types) |> only @assert first(ir.argtypes) isa Core.Const values = Vector{Value}(undef, length(ir.stmts)) for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",)) - IR.get_or_load_dialect!(ctx, dialect) + IR.get_or_load_dialect!(dialect) end blocks = [ - prepare_block(ctx, ir, bb) + prepare_block(ir, bb) for bb in ir.cfg.blocks ] current_block = entry_block = blocks[begin] for argtype in types.parameters - IR.push_argument!(entry_block, MLIRType(ctx, argtype), Location(ctx)) + IR.push_argument!(entry_block, MLIRType(argtype), Location()) end function get_value(x)::Value @@ -106,7 +107,7 @@ function code_mlir(f, types; ctx=Context()) elseif x isa Core.Argument IR.get_argument(entry_block, x.n - 1) elseif x isa BrutusScalar - IR.get_result(push!(current_block, arith.constant(ctx, x))) + IR.get_result(push!(current_block, arith.constant(x))) else error("could not use value $x inside MLIR") end @@ -126,7 +127,7 @@ function code_mlir(f, types; ctx=Context()) if !(val_type <: BrutusScalar) error("type $val_type is not supported") end - out_type = MLIRType(ctx, val_type) + out_type = MLIRType(val_type) called_func = first(inst.args) if called_func isa GlobalRef # TODO: should probably use something else here @@ -136,8 +137,8 @@ function code_mlir(f, types; ctx=Context()) fop! = intrinsics_to_mlir[called_func] args = get_value.(@view inst.args[begin+1:end]) - loc = Location(ctx, string(line.file), line.line, 0) - res = IR.get_result(fop!(ctx, current_block, args; loc)) + loc = Location(string(line.file), line.line, 0) + res = IR.get_result(fop!(current_block, args; loc)) values[sidx] = res elseif inst isa PhiNode @@ -147,9 +148,9 @@ function code_mlir(f, types; ctx=Context()) elseif inst isa GotoNode args = get_value.(collect_value_arguments(ir, block_id, inst.label)) dest = blocks[inst.label] - loc = Location(ctx, string(line.file), line.line, 0) + loc = Location(string(line.file), line.line, 0) brop = LLVM.version() >= v"15" ? cf.br : std.br - push!(current_block, brop(ctx, dest, args; loc)) + push!(current_block, brop(dest, args; loc)) elseif inst isa GotoIfNot false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) cond = get_value(inst.cond) @@ -159,15 +160,15 @@ function code_mlir(f, types; ctx=Context()) other_dest = blocks[other_dest] dest = blocks[inst.dest] - loc = Location(ctx, string(line.file), line.line, 0) + loc = Location(string(line.file), line.line, 0) cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br - cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc) + cond_br = cond_brop(cond, other_dest, dest, true_args, false_args; loc) push!(current_block, cond_br) elseif inst isa ReturnNode line = ir.linetable[stmt[:line]] retop = LLVM.version() >= v"15" ? func.return_ : std.return_ - loc = Location(ctx, string(line.file), line.line, 0) - push!(current_block, retop(ctx, [get_value(inst.val)]; loc)) + loc = Location(string(line.file), line.line, 0) + push!(current_block, retop([get_value(inst.val)]; loc)) elseif Meta.isexpr(inst, :code_coverage_effect) # Skip else @@ -189,15 +190,15 @@ function code_mlir(f, types; ctx=Context()) IR.get_type(IR.get_argument(entry_block, i)) for i in 1:IR.num_arguments(entry_block) ] - result_types = [MLIRType(ctx, ret)] + result_types = [MLIRType(ret)] - ftype = MLIRType(ctx, input_types => result_types) + ftype = MLIRType(input_types => result_types) op = IR.create_operation( LLVM15 ? "func.func" : "builtin.func", - Location(ctx); + Location(); attributes = [ - NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), - NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), + NamedAttribute("sym_name", IR.Attribute(string(func_name))), + NamedAttribute(LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), ], owned_regions = Region[region], result_inference=false, @@ -254,13 +255,13 @@ using MLIR.IR, MLIR ctx = Context() # IR.enable_multithreading!(ctx, false) -op = Brutus.code_mlir(pow, Tuple{Int, Int}; ctx) +op = Brutus.code_mlir(pow, Tuple{Int, Int}) -mod = MModule(ctx, Location(ctx)) +mod = MModule(Location()) body = IR.get_body(mod) push!(body, op) -pm = IR.PassManager(ctx) +pm = IR.PassManager() opm = IR.OpPassManager(pm) # IR.enable_ir_printing!(pm) diff --git a/src/Dialects.jl b/src/Dialects.jl index 4cb400eb..cd6f4244 100644 --- a/src/Dialects.jl +++ b/src/Dialects.jl @@ -9,13 +9,13 @@ for (f, t) in Iterators.product( (:i, :f), ) fname = Symbol(f, t) - @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end for fname in (:xori, :andi, :ori) - @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end @@ -25,37 +25,37 @@ for (f, t) in Iterators.product( (:si, :ui, :f), ) fname = Symbol(f, t) - @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop for f in (:index_cast, :index_castui) - @eval function $f(context, operand; loc=Location(context)) + @eval function $f(operand; loc=Location()) IR.create_operation( $(string("arith.", f)), loc; operands=[operand], - results=[IR.IndexType(context)], + results=[IR.IndexType()], ) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop -function extf(context, operand, type; loc=Location(context)) +function extf(operand, type; loc=Location()) IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop -function constant(context, value, type=MLIRType(context, typeof(value)); loc=Location(context)) +function constant(value, type=MLIRType(typeof(value)); loc=Location()) IR.create_operation( "arith.constant", loc; results=[type], attributes=[ - IR.NamedAttribute(context, "value", - Attribute(context, value, type)), + IR.NamedAttribute("value", + Attribute(value, type)), ], ) end @@ -73,15 +73,15 @@ module Predicates const uge = 9 end -function cmpi(context, predicate, operands; loc=Location(context)) +function cmpi(predicate, operands; loc=Location()) IR.create_operation( "arith.cmpi", loc; operands, - results=[MLIRType(context, Bool)], + results=[MLIRType(Bool)], attributes=[ - IR.NamedAttribute(context, "predicate", - Attribute(context, predicate)) + IR.NamedAttribute("predicate", + Attribute(predicate)) ], ) end @@ -93,20 +93,20 @@ module std using ...IR -function return_(context, operands; loc=Location(context)) +function return_(operands; loc=Location()) IR.create_operation("std.return", loc; operands, result_inference=false) end -function br(context, dest, operands; loc=Location(context)) +function br(dest, operands; loc=Location()) IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false) end function cond_br( - context, cond, + cond, true_dest, false_dest, true_dest_operands, false_dest_operands; - loc=Location(context), + loc=Location(), ) IR.create_operation( "std.cond_br", @@ -114,8 +114,8 @@ function cond_br( successors=[true_dest, false_dest], operands=[cond, true_dest_operands..., false_dest_operands...], attributes=[ - IR.NamedAttribute(context, "operand_segment_sizes", - IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + IR.NamedAttribute("operand_segment_sizes", + IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) ], result_inference=false, ) @@ -128,7 +128,7 @@ module func using ...IR -function return_(context, operands; loc=Location(context)) +function return_(operands; loc=Location()) IR.create_operation("func.return", loc; operands, result_inference=false) end @@ -138,24 +138,24 @@ module cf using ...IR -function br(context, dest, operands; loc=Location(context)) +function br(dest, operands; loc=Location()) IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) end function cond_br( - context, cond, + cond, true_dest, false_dest, true_dest_operands, false_dest_operands; - loc=Location(context), + loc=Location(), ) IR.create_operation( "cf.cond_br", loc; operands=[cond, true_dest_operands..., false_dest_operands...], successors=[true_dest, false_dest], attributes=[ - IR.NamedAttribute(context, "operand_segment_sizes", - IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + IR.NamedAttribute("operand_segment_sizes", + IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) ], result_inference=false, ) diff --git a/src/IR/IR.jl b/src/IR/IR.jl index 3606d329..21f4f7ff 100644 --- a/src/IR/IR.jl +++ b/src/IR/IR.jl @@ -1,7 +1,3 @@ -module IR - -import ..API: API - export Operation, OperationState, @@ -90,38 +86,53 @@ end ### Context -mutable struct Context +struct Context context::MlirContext end + function Context() context = API.mlirContextCreate() @assert !mlirIsNull(context) "cannot create Context with null MlirContext" - finalizer(Context(context)) do context - API.mlirContextDestroy(context.context) + context = Context(context) + activate(context) + context +end + +function dispose(ctx::Context) + deactivate(ctx) + API.mlirContextDestroy(ctx.context) +end + +function Context(f::Core.Function) + ctx = Context() + try + f(ctx) + finally + dispose(ctx) end end Base.convert(::Type{MlirContext}, c::Context) = c.context -num_loaded_dialects(context) = API.mlirContextGetNumLoadedDialects(context) -function get_or_load_dialect!(context, handle::DialectHandle) - mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context) +num_loaded_dialects() = API.mlirContextGetNumLoadedDialects(context()) +function get_or_load_dialect!(handle::DialectHandle) + mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context()) if mlirIsNull(mlir_dialect) error("could not load dialect from handle $handle") else Dialect(mlir_dialect) end end -function get_or_load_dialect!(context, dialect::String) - get_or_load_dialect!(context, DialectHandle(Symbol(dialect))) +function get_or_load_dialect!(dialect::String) + get_or_load_dialect!(DialectHandle(Symbol(dialect))) end -function enable_multithreading!(context, enable=true) - API.mlirContextEnableMultithreading(context, enable) - context +function enable_multithreading!(enable=true) + API.mlirContextEnableMultithreading(context(), enable) + context() end -is_registered_operation(context, opname) = API.mlirContextIsRegisteredOperation(context, opname) +is_registered_operation(opname) = API.mlirContextIsRegisteredOperation(context(), opname) ### Location @@ -134,9 +145,9 @@ struct Location end end -Location(context::Context) = Location(API.mlirLocationUnknownGet(context)) -Location(context::Context, filename, line, column) = - Location(API.mlirLocationFileLineColGet(context, filename, line, column)) +Location() = Location(API.mlirLocationUnknownGet(context())) +Location(filename, line, column) = + Location(API.mlirLocationFileLineColGet(context(), filename, line, column)) Base.convert(::Type{MlirLocation}, location::Location) = location.location @@ -160,43 +171,43 @@ struct MLIRType end MLIRType(t::MLIRType) = t -MLIRType(context::Context, T::Type{<:Signed}) = - MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) -MLIRType(context::Context, T::Type{<:Unsigned}) = - MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) -MLIRType(context::Context, ::Type{Bool}) = - MLIRType(API.mlirIntegerTypeGet(context, 1)) -MLIRType(context::Context, ::Type{Float32}) = - MLIRType(API.mlirF32TypeGet(context)) -MLIRType(context::Context, ::Type{Float64}) = - MLIRType(API.mlirF64TypeGet(context)) -MLIRType(context::Context, ft::Pair) = - MLIRType(API.mlirFunctionTypeGet(context, +MLIRType(T::Type{<:Signed}) = + MLIRType(API.mlirIntegerTypeGet(context(), sizeof(T) * 8)) +MLIRType(T::Type{<:Unsigned}) = + MLIRType(API.mlirIntegerTypeGet(context(), sizeof(T) * 8)) +MLIRType(::Type{Bool}) = + MLIRType(API.mlirIntegerTypeGet(context(), 1)) +MLIRType(::Type{Float32}) = + MLIRType(API.mlirF32TypeGet(context())) +MLIRType(::Type{Float64}) = + MLIRType(API.mlirF64TypeGet(context())) +MLIRType(ft::Pair) = + MLIRType(API.mlirFunctionTypeGet(context(), length(ft.first), [MLIRType(t) for t in ft.first], length(ft.second), [MLIRType(t) for t in ft.second])) -MLIRType(context, a::AbstractArray{T}) where {T} = MLIRType(context, MLIRType(context, T), size(a)) -MLIRType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = +MLIRType(a::AbstractArray{T}) where {T} = MLIRType(MLIRType(T), size(a)) +MLIRType(::Type{<:AbstractArray{T,N}}, dims) where {T,N} = MLIRType(API.mlirRankedTensorTypeGetChecked( - Location(context), + Location(), N, collect(dims), - MLIRType(context, T), + MLIRType(T), Attribute(), )) -MLIRType(context, element_type::MLIRType, dims) = +MLIRType(element_type::MLIRType, dims) = MLIRType(API.mlirRankedTensorTypeGetChecked( - Location(context), + Location(), length(dims), collect(dims), element_type, Attribute(), )) -MLIRType(context, ::T) where {T<:Real} = MLIRType(context, T) +MLIRType(::T) where {T<:Real} = MLIRType(T) MLIRType(_, type::MLIRType) = type -IndexType(context) = MLIRType(API.mlirIndexTypeGet(context)) +IndexType() = MLIRType(API.mlirIndexTypeGet(context())) Base.convert(::Type{MlirType}, mtype::MLIRType) = mtype.type -Base.parse(::Type{MLIRType}, context, s) = - MLIRType(API.mlirTypeParseGet(context, s)) +Base.parse(::Type{MLIRType}, s) = + MLIRType(API.mlirTypeParseGet(context(), s)) function Base.eltype(type::MLIRType) if API.mlirTypeIsAShaped(type) @@ -246,10 +257,10 @@ function Base.show(io::IO, type::MLIRType) end function inttype(size, issigned) - size == 1 && issigned && return Bool - ints = (Int8, Int16, Int32, Int64, Int128) - IT = ints[Int(log2(size)) - 2] - issigned ? IT : unsigned(IT) + size == 1 && issigned && return Bool + ints = (Int8, Int16, Int32, Int64, Int128) + IT = ints[Int(log2(size))-2] + issigned ? IT : unsigned(IT) end function julia_type(type::MLIRType) @@ -318,83 +329,83 @@ struct Attribute end Attribute() = Attribute(API.mlirAttributeGetNull()) -Attribute(context, s::AbstractString) = Attribute(API.mlirStringAttrGet(context, s)) +Attribute(s::AbstractString) = Attribute(API.mlirStringAttrGet(context(), s)) Attribute(type::MLIRType) = Attribute(API.mlirTypeAttrGet(type)) -Attribute(context, f::F, type=MLIRType(context, F)) where {F<:AbstractFloat} = Attribute( - API.mlirFloatAttrDoubleGet(context, type, Float64(f)) +Attribute(f::F, type=MLIRType(F)) where {F<:AbstractFloat} = Attribute( + API.mlirFloatAttrDoubleGet(context(), type, Float64(f)) ) -Attribute(context, i::T) where {T<:Integer} = Attribute( - API.mlirIntegerAttrGet(MLIRType(context, T), Int64(i)) +Attribute(i::T) where {T<:Integer} = Attribute( + API.mlirIntegerAttrGet(MLIRType(T), Int64(i)) ) -function Attribute(context, values::T) where {T<:AbstractArray{Int32}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Int32}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end -function Attribute(context, values::T) where {T<:AbstractArray{Int64}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Int64}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end -function Attribute(context, values::T) where {T<:AbstractArray{Float64}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Float64}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrDoubleGet(type, length(values), values) ) end -function Attribute(context, values::T) where {T<:AbstractArray{Float32}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Float32}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) end -function Attribute(context, values::AbstractArray{Int32}, type) +function Attribute(values::AbstractArray{Int32}, type) Attribute( API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end -function Attribute(context, values::AbstractArray{Int}, type) +function Attribute(values::AbstractArray{Int}, type) Attribute( API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end -function Attribute(context, values::AbstractArray{Float32}, type) +function Attribute(values::AbstractArray{Float32}, type) Attribute( API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) end -function ArrayAttribute(context, values::AbstractVector{Int}) - elements = Attribute.((context,), values) +function ArrayAttribute(values::AbstractVector{Int}) + elements = Attribute.(values) Attribute( - API.mlirArrayAttrGet(context, length(elements), elements) + API.mlirArrayAttrGet(context(), length(elements), elements) ) end -function ArrayAttribute(context, attributes::Vector{Attribute}) +function ArrayAttribute(attributes::Vector{Attribute}) Attribute( - API.mlirArrayAttrGet(context, length(attributes), attributes), + API.mlirArrayAttrGet(context(), length(attributes), attributes), ) end -function DenseArrayAttribute(context, values::AbstractVector{Int}) +function DenseArrayAttribute(values::AbstractVector{Int}) Attribute( - API.mlirDenseI64ArrayGet(context, length(values), collect(values)) + API.mlirDenseI64ArrayGet(context(), length(values), collect(values)) ) end -function Attribute(context, value::Int, type::MLIRType) +function Attribute(value::Int, type::MLIRType) Attribute( API.mlirIntegerAttrGet(type, value) ) end -function Attribute(context, value::Bool, ::MLIRType=nothing) +function Attribute(value::Bool, ::MLIRType=nothing) Attribute( - API.mlirBoolAttrGet(context, value) + API.mlirBoolAttrGet(context(), value) ) end Base.convert(::Type{MlirAttribute}, attribute::Attribute) = attribute.attribute -Base.parse(::Type{Attribute}, context, s) = - Attribute(API.mlirAttributeParseGet(context, s)) +Base.parse(::Type{Attribute}, s) = + Attribute(API.mlirAttributeParseGet(context(), s)) function get_type(attribute::Attribute) MLIRType(API.mlirAttributeGetType(attribute)) @@ -426,10 +437,10 @@ struct NamedAttribute named_attribute::MlirNamedAttribute end -function NamedAttribute(context, name, attribute) +function NamedAttribute(name, attribute) @assert !mlirIsNull(attribute.attribute) NamedAttribute(API.mlirNamedAttributeGet( - API.mlirIdentifierGet(context, name), + API.mlirIdentifierGet(context(), name), attribute )) end @@ -510,7 +521,7 @@ function create_operation( owned_regions=nothing, successors=nothing, attributes=nothing, - result_inference=isnothing(results), + result_inference=isnothing(results) ) GC.@preserve name loc begin state = Ref(API.mlirOperationStateGet(name, loc)) @@ -534,9 +545,9 @@ function create_operation( GC.@preserve successors begin mlir_blocks = Base.unsafe_convert.(MlirBlock, successors) API.mlirOperationStateAddSuccessors( - state, - length(mlir_blocks), - mlir_blocks, + state, + length(mlir_blocks), + mlir_blocks, ) end end @@ -761,27 +772,26 @@ Base.unsafe_convert(::Type{MlirRegion}, region::Region) = region.region mutable struct MModule module_::MlirModule - context::Context - MModule(module_, context) = begin + MModule(module_) = begin @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" - finalizer(API.mlirModuleDestroy, new(module_, context)) + finalizer(API.mlirModuleDestroy, new(module_)) end end -MModule(context::Context, loc=Location(context)) = - MModule(API.mlirModuleCreateEmpty(loc), context) +MModule(loc::Location=Location()) = + MModule(API.mlirModuleCreateEmpty(loc)) get_operation(module_) = Operation(API.mlirModuleGetOperation(module_), false) get_body(module_) = Block(API.mlirModuleGetBody(module_), false) get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ -Base.parse(::Type{MModule}, context, module_) = MModule(API.mlirModuleCreateParse(context, module_), context) +Base.parse(::Type{MModule}, module_) = MModule(API.mlirModuleCreateParse(context(), module_), context()) macro mlir_str(code) quote ctx = Context() - parse(MModule, ctx, code) + parse(MModule, code) end end @@ -801,30 +811,29 @@ Base.convert(::Type{API.MlirTypeID}, typeid::TypeID) = typeid.typeid @static if isdefined(API, :MlirTypeIDAllocator) -### TypeIDAllocator + ### TypeIDAllocator -mutable struct TypeIDAllocator - allocator::API.MlirTypeIDAllocator + mutable struct TypeIDAllocator + allocator::API.MlirTypeIDAllocator - function TypeIDAllocator() - ptr = API.mlirTypeIDAllocatorCreate() - @assert ptr != C_NULL "cannot create TypeIDAllocator" - finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) + function TypeIDAllocator() + ptr = API.mlirTypeIDAllocatorCreate() + @assert ptr != C_NULL "cannot create TypeIDAllocator" + finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) + end end -end -Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator -Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator + Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator + Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator -TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) + TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) else -struct TypeIDAllocator end + struct TypeIDAllocator end end include("./Support.jl") include("./Pass.jl") -end # module IR diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl index 7eef5b88..f4718dbe 100644 --- a/src/IR/Pass.jl +++ b/src/IR/Pass.jl @@ -9,13 +9,12 @@ end mutable struct PassManager pass::MlirPassManager - context::Context allocator::TypeIDAllocator passes::Dict{TypeID,ExternalPassHandle} - PassManager(pm::MlirPassManager, context) = begin + PassManager(pm::MlirPassManager) = begin @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" - finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm + finalizer(new(pm, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm API.mlirPassManagerDestroy(pm.pass) end end @@ -30,8 +29,8 @@ function enable_verifier!(pm, enable=true) pm end -PassManager(context) = - PassManager(API.mlirPassManagerCreate(context), context) +PassManager() = + PassManager(API.mlirPassManagerCreate(context())) function run!(pm::PassManager, module_) status = API.mlirPassManagerRun(pm, module_) @@ -96,7 +95,7 @@ function add_pipeline!(op_pass::OpPassManager, pipeline) end op_pass end - + function add_owned_pass!(pm::PassManager, pass) API.mlirPassManagerAddOwnedPass(pm, pass) pm @@ -110,67 +109,66 @@ end @static if isdefined(API, :mlirCreateExternalPass) -### Pass + ### Pass -# AbstractPass interface: -opname(::AbstractPass) = "" -function pass_run(::Context, ::P, op) where {P<:AbstractPass} - error("pass $P does not implement `MLIR.pass_run`") -end + # AbstractPass interface: + opname(::AbstractPass) = "" + function pass_run(::Context, ::P, op) where {P<:AbstractPass} + error("pass $P does not implement `MLIR.pass_run`") + end -function _pass_construct(ptr::ExternalPassHandle) - nothing -end + function _pass_construct(ptr::ExternalPassHandle) + nothing + end -function _pass_destruct(ptr::ExternalPassHandle) - nothing -end + function _pass_destruct(ptr::ExternalPassHandle) + nothing + end -function _pass_initialize(ctx, handle::ExternalPassHandle) - try - handle.ctx = Context(ctx) - mlirLogicalResultSuccess() - catch - mlirLogicalResultFailure() + function _pass_initialize(ctx, handle::ExternalPassHandle) + try + handle.ctx = Context(ctx) + mlirLogicalResultSuccess() + catch + mlirLogicalResultFailure() + end end -end -function _pass_clone(handle::ExternalPassHandle) - ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) -end + function _pass_clone(handle::ExternalPassHandle) + ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) + end -function _pass_run(rawop, external_pass, handle::ExternalPassHandle) - op = Operation(rawop, false) - try - pass_run(handle.ctx, handle.pass, op) - catch ex - @error "Something went wrong running pass" exception=(ex,catch_backtrace()) - API.mlirExternalPassSignalFailure(external_pass) + function _pass_run(rawop, external_pass, handle::ExternalPassHandle) + op = Operation(rawop, false) + try + pass_run(handle.ctx, handle.pass, op) + catch ex + @error "Something went wrong running pass" exception = (ex, catch_backtrace()) + API.mlirExternalPassSignalFailure(external_pass) + end + nothing end - nothing -end -function create_external_pass!(oppass::OpPassManager, args...) - create_external_pass!(oppass.pass, args...) -end -function create_external_pass!(manager, pass, name, argument, - description, opname=opname(pass), - dependent_dialects=MlirDialectHandle[]) - passid = TypeID(manager.allocator) - callbacks = API.MlirExternalPassCallbacks( + function create_external_pass!(oppass::OpPassManager, args...) + create_external_pass!(oppass.pass, args...) + end + function create_external_pass!(manager, pass, name, argument, + description, opname=opname(pass), + dependent_dialects=MlirDialectHandle[]) + passid = TypeID(manager.allocator) + callbacks = API.MlirExternalPassCallbacks( @cfunction(_pass_construct, Cvoid, (Any,)), @cfunction(_pass_destruct, Cvoid, (Any,)), @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), @cfunction(_pass_clone, Any, (Any,)), @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) - ) - pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) - userdata = Base.pointer_from_objref(pass_handle) - mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, - length(dependent_dialects), dependent_dialects, - callbacks, userdata) - mlir_pass -end - -end + ) + pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) + userdata = Base.pointer_from_objref(pass_handle) + mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, + length(dependent_dialects), dependent_dialects, + callbacks, userdata) + mlir_pass + end +end \ No newline at end of file diff --git a/src/IR/state.jl b/src/IR/state.jl new file mode 100644 index 00000000..072f65ee --- /dev/null +++ b/src/IR/state.jl @@ -0,0 +1,43 @@ +# Global state + +# to simplify the API, we maintain a stack of contexts in task local storage +# and pass them implicitly to MLIR API's that require them. + +export context, activate, deactivate, context! + +using ..IR + +_has_context() = haskey(task_local_storage(), :MLIRContext) && + !isempty(task_local_storage(:MLIRContext)) + +function context(; throw_error::Core.Bool=true) + if !_has_context() + throw_error && error("No MLIR context is active") + return nothing + end + last(task_local_storage(:MLIRContext)) +end + +function activate(ctx::Context) + stack = get!(task_local_storage(), :MLIRContext) do + Context[] + end + push!(stack, ctx) + return +end + +function deactivate(ctx::Context) + context() == ctx || error("Deactivating wrong context") + pop!(task_local_storage(:MLIRContext)) +end + +function context!(f, ctx::Context) + activate(ctx) + try + f() + finally + deactivate(ctx) + end +end + + diff --git a/src/MLIR.jl b/src/MLIR.jl index 36638296..4d200805 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -35,7 +35,14 @@ function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, return API.MlirStringRef(p, length(s)) end -include("./IR/IR.jl") +module IR + import ..API: API + + include("./IR/IR.jl") + include("./IR/state.jl") +end # module IR + include("./Dialects.jl") + end # module MLIR