Skip to content

Commit

Permalink
Update toy brutus (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw authored Jan 16, 2024
1 parent beee789 commit c370d5c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 89 deletions.
118 changes: 69 additions & 49 deletions examples/brutus.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,55 @@
"""
Brutus is a toy implementation of a Julia typed IR to MLIR conversion, the name
is a reference to the [brutus](https://github.com/JuliaLabs/brutus) project from
the MIT JuliaLabs which performs a similar conversion (with a lot more language constructs supported)
but from C++.
"""
module Brutus

import LLVM
using MLIR.IR
using MLIR.Dialects: arith, func, cf, std
using MLIR.Dialects: arith, func, cf
using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode

const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64}

module Predicates
const eq = 0
const ne = 1
const slt = 2
const sle = 3
const sgt = 4
const sge = 5
const ult = 6
const ule = 7
const ugt = 8
const uge = 9
end

function cmpi_pred(predicate)
function(ops; loc=Location())
arith.cmpi(predicate, ops; loc)
function (ops...; location = Location())
arith.cmpi(ops...; result=IR.MLIRType(Bool), predicate, location)
end
end

function single_op_wrapper(fop)
(block::Block, args::Vector{Value}; loc=Location()) -> push!(block, fop(args; loc))
(block::Block, args::Vector{Value}; location=Location()) -> push!(block, fop(args...; location))
end

const intrinsics_to_mlir = Dict([
Base.add_int => single_op_wrapper(arith.addi),
Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)),
Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)),
Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)),
Base.sle_int => single_op_wrapper(cmpi_pred(Predicates.sle)),
Base.slt_int => single_op_wrapper(cmpi_pred(Predicates.slt)),
Base.:(===) => single_op_wrapper(cmpi_pred(Predicates.eq)),
Base.mul_int => single_op_wrapper(arith.muli),
Base.mul_float => single_op_wrapper(arith.mulf),
Base.not_int => function(block, args; loc=Location())
Base.not_int => function (block, args; location=Location())
arg = only(args)
ones = push!(block, arith.constant(-1, IR.get_type(arg); loc)) |> IR.get_result
push!(block, arith.xori(Value[arg, ones]; loc))
mT = IR.get_type(arg)
T = IR.julia_type(mT)
ones = push!(block, arith.constant(value=typemax(UInt64) % T;
result=mT, location)) |> IR.get_result
push!(block, arith.xori(arg, ones; location))
end,
])

Expand Down Expand Up @@ -85,7 +107,7 @@ function code_mlir(f, types)

values = Vector{Value}(undef, length(ir.stmts))

for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",))
for dialect in ("func", "cf")
IR.get_or_load_dialect!(dialect)
end

Expand All @@ -107,7 +129,7 @@ function code_mlir(f, types)
elseif x isa Core.Argument
IR.get_argument(entry_block, x.n - 1)
elseif x isa BrutusScalar
IR.get_result(push!(current_block, arith.constant(x)))
IR.get_result(push!(current_block, arith.constant(;value=x)))
else
error("could not use value $x inside MLIR")
end
Expand Down Expand Up @@ -137,8 +159,8 @@ function code_mlir(f, types)
fop! = intrinsics_to_mlir[called_func]
args = get_value.(@view inst.args[begin+1:end])

loc = Location(string(line.file), line.line, 0)
res = IR.get_result(fop!(current_block, args; loc))
location = Location(string(line.file), line.line, 0)
res = IR.get_result(fop!(current_block, args; location))

values[sidx] = res
elseif inst isa PhiNode
Expand All @@ -148,9 +170,8 @@ function code_mlir(f, types)
elseif inst isa GotoNode
args = get_value.(collect_value_arguments(ir, block_id, inst.label))
dest = blocks[inst.label]
loc = Location(string(line.file), line.line, 0)
brop = LLVM.version() >= v"15" ? cf.br : std.br
push!(current_block, brop(dest, args; loc))
location = Location(string(line.file), line.line, 0)
push!(current_block, cf.br(args; dest, location))
elseif inst isa GotoIfNot
false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest))
cond = get_value(inst.cond)
Expand All @@ -160,15 +181,13 @@ function code_mlir(f, types)
other_dest = blocks[other_dest]
dest = blocks[inst.dest]

loc = Location(string(line.file), line.line, 0)
cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br
cond_br = cond_brop(cond, other_dest, dest, true_args, false_args; loc)
location = Location(string(line.file), line.line, 0)
cond_br = cf.cond_br(cond, true_args, false_args; trueDest=other_dest, falseDest=dest, location)
push!(current_block, cond_br)
elseif inst isa ReturnNode
line = ir.linetable[stmt[:line]]
retop = LLVM.version() >= v"15" ? func.return_ : std.return_
loc = Location(string(line.file), line.line, 0)
push!(current_block, retop([get_value(inst.val)]; loc))
location = Location(string(line.file), line.line, 0)
push!(current_block, func.return_([get_value(inst.val)]; location))
elseif Meta.isexpr(inst, :code_coverage_effect)
# Skip
else
Expand All @@ -194,13 +213,13 @@ function code_mlir(f, types)

ftype = MLIRType(input_types => result_types)
op = IR.create_operation(
LLVM15 ? "func.func" : "builtin.func",
"func.func",
Location();
attributes = [
attributes=[
NamedAttribute("sym_name", IR.Attribute(string(func_name))),
NamedAttribute(LLVM15 ? "function_type" : "type", IR.Attribute(ftype)),
NamedAttribute("function_type", IR.Attribute(ftype)),
],
owned_regions = Region[region],
owned_regions=Region[region],
result_inference=false,
)

Expand All @@ -227,16 +246,18 @@ macro code_mlir(call)
end
end

export code_mlir, @code_mlir

end # module Brutus

# ---

function pow(x::F, n) where {F}
p = one(F)
for _ in 1:n
p *= x
end
p
p = one(F)
for _ in 1:n
p *= x
end
p
end

function f(x)
Expand All @@ -252,29 +273,28 @@ end
using Test
using MLIR.IR, MLIR

ctx = Context()
# IR.enable_multithreading!(ctx, false)
fptr = IR.context!(IR.Context()) do
op = Brutus.code_mlir(pow, Tuple{Int,Int})

op = Brutus.code_mlir(pow, Tuple{Int, Int})
mod = MModule(Location())
body = IR.get_body(mod)
push!(body, op)

mod = MModule(Location())
body = IR.get_body(mod)
push!(body, op)
pm = IR.PassManager()
opm = IR.OpPassManager(pm)

pm = IR.PassManager()
opm = IR.OpPassManager(pm)
# IR.enable_ir_printing!(pm)
IR.enable_verifier!(pm, true)

# IR.enable_ir_printing!(pm)
IR.enable_verifier!(pm, true)
MLIR.API.mlirRegisterAllPasses()
MLIR.API.mlirRegisterAllLLVMTranslations(IR.context())
IR.add_pipeline!(opm, "convert-arith-to-llvm,convert-func-to-llvm")

MLIR.API.mlirRegisterAllPasses()
MLIR.API.mlirRegisterAllLLVMTranslations(ctx)
IR.add_pipeline!(opm, Brutus.LLVM.version() >= v"15" ? "convert-arith-to-llvm,convert-func-to-llvm" : "convert-std-to-llvm")
IR.run!(pm, mod)

IR.run!(pm, mod)

jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL)
fptr = MLIR.API.mlirExecutionEngineLookup(jit, "pow")
jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL)
MLIR.API.mlirExecutionEngineLookup(jit, "pow")
end

x, y = 3, 4

Expand Down
8 changes: 1 addition & 7 deletions src/Dialects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,7 @@ function namedattribute(name, val::NamedAttribute)
return val
end

operandsegmentsizes(segments) = namedattribute(
"operand_segment_sizes",
Attribute(API.mlirDenseI32ArrayGet(
context().context,
length(segments),
Int32.(segments)
)))
operandsegmentsizes(segments) = namedattribute("operand_segment_sizes", Attribute(Int32.(segments)))

let
ver = string(LLVM.version().major)
Expand Down
31 changes: 0 additions & 31 deletions src/IR/IR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,37 +217,6 @@ function Base.eltype(type::MLIRType)
end
end

function show_inner(io::IO, type::MLIRType)
if API.mlirTypeIsAInteger(type)
is_signless = API.mlirIntegerTypeIsSignless(type)
is_signed = API.mlirIntegerTypeIsSigned(type)

width = API.mlirIntegerTypeGetWidth(type)
t = if is_signed
"si"
elseif is_signless
"i"
else
"u"
end
print(io, t, width)
elseif API.mlirTypeIsAF64(type)
print(io, "f64")
elseif API.mlirTypeIsAF32(type)
print(io, "f32")
elseif API.mlirTypeIsARankedTensor(type)
print(io, "tensor<")
s = size(type)
print(io, join(s, "x"), "x")
show_inner(io, eltype(type))
print(io, ">")
elseif API.mlirTypeIsAIndex(type)
print(io, "index")
else
print(io, "unknown")
end
end

function Base.show(io::IO, type::MLIRType)
print(io, "MLIRType(#= ")
c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any))
Expand Down
2 changes: 1 addition & 1 deletion src/MLIR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Base.cconvert(::Type{API.MlirStringRef}, s::AbstractString) =
# Directly create `MlirStringRef` instead of adding an extra ccall.
function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, AbstractVector{UInt8}})
p = Base.unsafe_convert(Ptr{Cchar}, s)
return API.MlirStringRef(p, length(s))
return API.MlirStringRef(p, sizeof(s))
end

module IR
Expand Down
4 changes: 3 additions & 1 deletion test/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ end

examples_dir = joinpath(@__DIR__, "..", "examples")
examples = find_sources(examples_dir)

filter!(file -> readline(file) != "# EXCLUDE FROM TESTING", examples)
filter!(file -> !occursin("Kaleidoscope", file), examples)
filter!(file -> VERSION >= v"1.10" || !contains(file, "brutus.jl"), examples)

cd(examples_dir) do
examples = relpath.(examples, Ref(examples_dir))
Expand All @@ -28,4 +30,4 @@ cd(examples_dir) do
end
end

end
end
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,12 @@ using Test

include("examples.jl")
include("executionengine.jl")

@testset "MlirStringRef conversion" begin
s = "mlir 😄 α γ 🍕"

ms = Base.unsafe_convert(MLIR.API.MlirStringRef, s)
reconstructed = unsafe_string(Ptr{Cchar}(ms.data), ms.length)

@test s == reconstructed
end

0 comments on commit c370d5c

Please sign in to comment.