Skip to content

Commit

Permalink
Fix Brutus example on nightly. (#75)
Browse files Browse the repository at this point in the history
* add v17 to API.jl

* linetable is stored in debuginfo field now

* using internal IRShow method to get linenumber

* I forgot the dialects!

* small error

* Update examples/brutus.jl

Co-authored-by: Paul Berg <[email protected]>

* Update examples/brutus.jl

Co-authored-by: Paul Berg <[email protected]>

* MLIR17: fix brutus example location info, update Pass.jl, update executionengine example

---------

Co-authored-by: Paul Berg <[email protected]>
  • Loading branch information
jumerckx and Pangoraw authored Jun 4, 2024
1 parent eb4aed0 commit b8ab825
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 10 deletions.
12 changes: 10 additions & 2 deletions examples/brutus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,16 @@ function code_mlir(f, types)
for sidx in bb.stmts
stmt = ir.stmts[sidx]
inst = stmt[:inst]
line = ir.linetable[stmt[:line]+1]
line = @static if VERSION <= v"1.11"
ir.linetable[stmt[:line]+1]
else
lineinfonode = Base.IRShow.buildLineInfoNode(ir.debuginfo, :var"n/a", sidx)
if !isempty(lineinfonode)
last(lineinfonode)
else
(; ((:file, :line) .=> Base.IRShow.debuginfo_firstline(ir.debuginfo))...)
end
end

if Meta.isexpr(inst, :call)
val_type = stmt[:type]
Expand Down Expand Up @@ -186,7 +195,6 @@ function code_mlir(f, types)
cond_br = cf.cond_br(cond, true_args, false_args; trueDest=other_dest, falseDest=dest, location)
push!(current_block, cond_br)
elseif inst isa ReturnNode
line = ir.linetable[stmt[:line]+1]
location = Location(string(line.file), line.line, 0)
push!(current_block, func.return_([get_value(inst.val)]; location))
elseif Meta.isexpr(inst, :code_coverage_effect)
Expand Down
4 changes: 2 additions & 2 deletions src/API/API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ end

# generate version-less API functions
begin
local ops = mapreduce(, [v14, v15, v16]) do mod
local ops = mapreduce(, [v14, v15, v16, v17]) do mod
filter(names(mod; all=true)) do name
name [nameof(mod), :eval, :include] && !startswith(string(name), '#')
end
end

for op in ops
container_mods = filter([v14, v15, v16]) do mod
container_mods = filter([v14, v15, v16, v17]) do mod
op in names(mod; all=true)
end
container_mods = map(container_mods) do mod
Expand Down
6 changes: 3 additions & 3 deletions src/Dialects/Dialects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

begin
# list dialect operations
local dialectops = mapreduce(mergewith!(), [v14, v15, v16]) do mod
local dialectops = mapreduce(mergewith!(), [v14, v15, v16, v17]) do mod
dialects = filter(names(mod; all=true)) do dialect
dialect [nameof(mod), :eval, :include] && !startswith(string(dialect), '#')
end
Expand All @@ -33,11 +33,11 @@ begin
for (dialect, ops) in dialectops
mod = @eval module $dialect
using ...MLIR: MLIR_VERSION, MLIRException
using ..Dialects: v14, v15, v16
using ..Dialects: v14, v15, v16, v17
end

for op in ops
container_mods = filter([v14, v15, v16]) do mod
container_mods = filter([v14, v15, v16, v17]) do mod
dialect in names(mod; all=true) &&
op in names(getproperty(mod, dialect); all=true)
end
Expand Down
6 changes: 5 additions & 1 deletion src/IR/Pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ end
Run the provided `passManager` on the given `module`.
"""
function run!(pm::PassManager, mod::Module)
status = LogicalResult(API.mlirPassManagerRun(pm, mod))
status = if MLIR_VERSION[] >= v"17"
LogicalResult(API.mlirPassManagerRunOnOp(pm, Operation(mod)))
else
LogicalResult(API.mlirPassManagerRun(pm, mod))
end
if isfailure(status)
throw("failed to run pass manager on module")
end
Expand Down
13 changes: 11 additions & 2 deletions test/executionengine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ function lowerModuleToLLVM(ctx, mod)
op = "builtin.func"
end
opm = MLIR.API.mlirPassManagerGetNestedUnder(pm, op)
if LLVM.version() >= v"15"
if LLVM.version() >= v"17"
MLIR.API.mlirPassManagerAddOwnedPass(
pm, MLIR.API.mlirCreateConversionConvertFuncToLLVMPass()
)
elseif LLVM.version() >= v"15"
MLIR.API.mlirPassManagerAddOwnedPass(
pm, MLIR.API.mlirCreateConversionConvertFuncToLLVM()
)
Expand All @@ -43,7 +47,12 @@ function lowerModuleToLLVM(ctx, mod)
opm, MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()
)
end
status = MLIR.API.mlirPassManagerRun(pm, mod)
status = if LLVM.version() >= v"17"
op = MLIR.API.mlirModuleGetOperation(mod)
MLIR.API.mlirPassManagerRunOnOp(pm, op)
else
MLIR.API.mlirPassManagerRun(pm, mod)
end
# undefined symbol: mlirLogicalResultIsFailure
if status.value == 0
error("Unexpected failure running pass failure")
Expand Down

0 comments on commit b8ab825

Please sign in to comment.