Skip to content

Commit

Permalink
add LLVM.jl-like context state handling. (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx authored Oct 2, 2023
1 parent 0a447e6 commit d239921
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 211 deletions.
63 changes: 32 additions & 31 deletions examples/brutus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode
const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64}

function cmpi_pred(predicate)
function(ctx, ops; loc=Location(ctx))
arith.cmpi(ctx, predicate, ops; loc)
function(ops; loc=Location())
arith.cmpi(predicate, ops; loc)
end
end

function single_op_wrapper(fop)
(ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc))
(block::Block, args::Vector{Value}; loc=Location()) -> push!(block, fop(args; loc))
end

const intrinsics_to_mlir = Dict([
Expand All @@ -24,15 +24,15 @@ const intrinsics_to_mlir = Dict([
Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)),
Base.mul_int => single_op_wrapper(arith.muli),
Base.mul_float => single_op_wrapper(arith.mulf),
Base.not_int => function(ctx, block, args; loc=Location(ctx))
Base.not_int => function(block, args; loc=Location())
arg = only(args)
ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result
push!(block, arith.xori(ctx, Value[arg, ones]; loc))
ones = push!(block, arith.constant(-1, IR.get_type(arg); loc)) |> IR.get_result
push!(block, arith.xori(Value[arg, ones]; loc))
end,
])

"Generates a block argument for each phi node present in the block."
function prepare_block(ctx, ir, bb)
function prepare_block(ir, bb)
b = Block()

for sidx in bb.stmts
Expand All @@ -41,7 +41,7 @@ function prepare_block(ctx, ir, bb)
inst isa Core.PhiNode || continue

type = stmt[:type]
IR.push_argument!(b, MLIRType(ctx, type), Location(ctx))
IR.push_argument!(b, MLIRType(type), Location())
end

return b
Expand All @@ -68,7 +68,7 @@ function collect_value_arguments(ir, from, to)
end

"""
code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation
code_mlir(f, types::Type{Tuple}) -> IR.Operation
Returns a `func.func` operation corresponding to the ircode of the provided method.
This only supports a few Julia Core primitives and scalar types of type $BrutusScalar.
Expand All @@ -78,25 +78,26 @@ This only supports a few Julia Core primitives and scalar types of type $BrutusS
handful of primitives. A better to perform this conversion would to create a dialect
representing Julia IR and progressively lower it to base MLIR dialects.
"""
function code_mlir(f, types; ctx=Context())
function code_mlir(f, types)
ctx = context()
ir, ret = Core.Compiler.code_ircode(f, types) |> only
@assert first(ir.argtypes) isa Core.Const

values = Vector{Value}(undef, length(ir.stmts))

for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",))
IR.get_or_load_dialect!(ctx, dialect)
IR.get_or_load_dialect!(dialect)
end

blocks = [
prepare_block(ctx, ir, bb)
prepare_block(ir, bb)
for bb in ir.cfg.blocks
]

current_block = entry_block = blocks[begin]

for argtype in types.parameters
IR.push_argument!(entry_block, MLIRType(ctx, argtype), Location(ctx))
IR.push_argument!(entry_block, MLIRType(argtype), Location())
end

function get_value(x)::Value
Expand All @@ -106,7 +107,7 @@ function code_mlir(f, types; ctx=Context())
elseif x isa Core.Argument
IR.get_argument(entry_block, x.n - 1)
elseif x isa BrutusScalar
IR.get_result(push!(current_block, arith.constant(ctx, x)))
IR.get_result(push!(current_block, arith.constant(x)))
else
error("could not use value $x inside MLIR")
end
Expand All @@ -126,7 +127,7 @@ function code_mlir(f, types; ctx=Context())
if !(val_type <: BrutusScalar)
error("type $val_type is not supported")
end
out_type = MLIRType(ctx, val_type)
out_type = MLIRType(val_type)

called_func = first(inst.args)
if called_func isa GlobalRef # TODO: should probably use something else here
Expand All @@ -136,8 +137,8 @@ function code_mlir(f, types; ctx=Context())
fop! = intrinsics_to_mlir[called_func]
args = get_value.(@view inst.args[begin+1:end])

loc = Location(ctx, string(line.file), line.line, 0)
res = IR.get_result(fop!(ctx, current_block, args; loc))
loc = Location(string(line.file), line.line, 0)
res = IR.get_result(fop!(current_block, args; loc))

values[sidx] = res
elseif inst isa PhiNode
Expand All @@ -147,9 +148,9 @@ function code_mlir(f, types; ctx=Context())
elseif inst isa GotoNode
args = get_value.(collect_value_arguments(ir, block_id, inst.label))
dest = blocks[inst.label]
loc = Location(ctx, string(line.file), line.line, 0)
loc = Location(string(line.file), line.line, 0)
brop = LLVM.version() >= v"15" ? cf.br : std.br
push!(current_block, brop(ctx, dest, args; loc))
push!(current_block, brop(dest, args; loc))
elseif inst isa GotoIfNot
false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest))
cond = get_value(inst.cond)
Expand All @@ -159,15 +160,15 @@ function code_mlir(f, types; ctx=Context())
other_dest = blocks[other_dest]
dest = blocks[inst.dest]

loc = Location(ctx, string(line.file), line.line, 0)
loc = Location(string(line.file), line.line, 0)
cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br
cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc)
cond_br = cond_brop(cond, other_dest, dest, true_args, false_args; loc)
push!(current_block, cond_br)
elseif inst isa ReturnNode
line = ir.linetable[stmt[:line]]
retop = LLVM.version() >= v"15" ? func.return_ : std.return_
loc = Location(ctx, string(line.file), line.line, 0)
push!(current_block, retop(ctx, [get_value(inst.val)]; loc))
loc = Location(string(line.file), line.line, 0)
push!(current_block, retop([get_value(inst.val)]; loc))
elseif Meta.isexpr(inst, :code_coverage_effect)
# Skip
else
Expand All @@ -189,15 +190,15 @@ function code_mlir(f, types; ctx=Context())
IR.get_type(IR.get_argument(entry_block, i))
for i in 1:IR.num_arguments(entry_block)
]
result_types = [MLIRType(ctx, ret)]
result_types = [MLIRType(ret)]

ftype = MLIRType(ctx, input_types => result_types)
ftype = MLIRType(input_types => result_types)
op = IR.create_operation(
LLVM15 ? "func.func" : "builtin.func",
Location(ctx);
Location();
attributes = [
NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))),
NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)),
NamedAttribute("sym_name", IR.Attribute(string(func_name))),
NamedAttribute(LLVM15 ? "function_type" : "type", IR.Attribute(ftype)),
],
owned_regions = Region[region],
result_inference=false,
Expand Down Expand Up @@ -254,13 +255,13 @@ using MLIR.IR, MLIR
ctx = Context()
# IR.enable_multithreading!(ctx, false)

op = Brutus.code_mlir(pow, Tuple{Int, Int}; ctx)
op = Brutus.code_mlir(pow, Tuple{Int, Int})

mod = MModule(ctx, Location(ctx))
mod = MModule(Location())
body = IR.get_body(mod)
push!(body, op)

pm = IR.PassManager(ctx)
pm = IR.PassManager()
opm = IR.OpPassManager(pm)

# IR.enable_ir_printing!(pm)
Expand Down
50 changes: 25 additions & 25 deletions src/Dialects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ for (f, t) in Iterators.product(
(:i, :f),
)
fname = Symbol(f, t)
@eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context))
@eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location())
IR.create_operation($(string("arith.", fname)), loc; operands, results=[type])
end
end

for fname in (:xori, :andi, :ori)
@eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context))
@eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location())
IR.create_operation($(string("arith.", fname)), loc; operands, results=[type])
end
end
Expand All @@ -25,37 +25,37 @@ for (f, t) in Iterators.product(
(:si, :ui, :f),
)
fname = Symbol(f, t)
@eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context))
@eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location())
IR.create_operation($(string("arith.", fname)), loc; operands, results=[type])
end
end

# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop
for f in (:index_cast, :index_castui)
@eval function $f(context, operand; loc=Location(context))
@eval function $f(operand; loc=Location())
IR.create_operation(
$(string("arith.", f)),
loc;
operands=[operand],
results=[IR.IndexType(context)],
results=[IR.IndexType()],
)
end
end

# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop
function extf(context, operand, type; loc=Location(context))
function extf(operand, type; loc=Location())
IR.create_operation("arith.exf", loc; operands=[operand], results=[type])
end

# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop
function constant(context, value, type=MLIRType(context, typeof(value)); loc=Location(context))
function constant(value, type=MLIRType(typeof(value)); loc=Location())
IR.create_operation(
"arith.constant",
loc;
results=[type],
attributes=[
IR.NamedAttribute(context, "value",
Attribute(context, value, type)),
IR.NamedAttribute("value",
Attribute(value, type)),
],
)
end
Expand All @@ -73,15 +73,15 @@ module Predicates
const uge = 9
end

function cmpi(context, predicate, operands; loc=Location(context))
function cmpi(predicate, operands; loc=Location())
IR.create_operation(
"arith.cmpi",
loc;
operands,
results=[MLIRType(context, Bool)],
results=[MLIRType(Bool)],
attributes=[
IR.NamedAttribute(context, "predicate",
Attribute(context, predicate))
IR.NamedAttribute("predicate",
Attribute(predicate))
],
)
end
Expand All @@ -93,29 +93,29 @@ module std

using ...IR

function return_(context, operands; loc=Location(context))
function return_(operands; loc=Location())
IR.create_operation("std.return", loc; operands, result_inference=false)
end

function br(context, dest, operands; loc=Location(context))
function br(dest, operands; loc=Location())
IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false)
end

function cond_br(
context, cond,
cond,
true_dest, false_dest,
true_dest_operands,
false_dest_operands;
loc=Location(context),
loc=Location(),
)
IR.create_operation(
"std.cond_br",
loc;
successors=[true_dest, false_dest],
operands=[cond, true_dest_operands..., false_dest_operands...],
attributes=[
IR.NamedAttribute(context, "operand_segment_sizes",
IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)]))
IR.NamedAttribute("operand_segment_sizes",
IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)]))
],
result_inference=false,
)
Expand All @@ -128,7 +128,7 @@ module func

using ...IR

function return_(context, operands; loc=Location(context))
function return_(operands; loc=Location())
IR.create_operation("func.return", loc; operands, result_inference=false)
end

Expand All @@ -138,24 +138,24 @@ module cf

using ...IR

function br(context, dest, operands; loc=Location(context))
function br(dest, operands; loc=Location())
IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false)
end

function cond_br(
context, cond,
cond,
true_dest, false_dest,
true_dest_operands,
false_dest_operands;
loc=Location(context),
loc=Location(),
)
IR.create_operation(
"cf.cond_br", loc;
operands=[cond, true_dest_operands..., false_dest_operands...],
successors=[true_dest, false_dest],
attributes=[
IR.NamedAttribute(context, "operand_segment_sizes",
IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)]))
IR.NamedAttribute("operand_segment_sizes",
IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)]))
],
result_inference=false,
)
Expand Down
Loading

0 comments on commit d239921

Please sign in to comment.