Skip to content

Commit

Permalink
Refactor and clean IR module (#62)
Browse files Browse the repository at this point in the history
* Reimplement `Context`

* Reimplement `Dialect`

* Reimplement `Location`

* Reimplement `Value`

* Reimplement `Block`

* Move `BlockIterator`

* Refactor `first_op`

* Refactor `mlirIsNull` calls in `Block`

* Reimplement `Region`

* Reimplement `Module`

* Fix namespaces in `Location`

* Refactor `Type` to `Core.Type` to avoid import

* Remove unused arg name from `iterator(::BlockIterator)`

* Fix namespace calls under `API`

* Refactor constructors

* Refactor `Context` code

* Refactor `Dialect` code

* Fix typo in exported symbol

* Fix typo in returned value in `enable_multithreading!`

* Document some code

* Refactor code

* Reimplement `Operation`

* Remove commented code

* Export symbols from `IR`

* Reimplement `TypeID`

* Move iterators to new file

* Fix missing module name in function calls

* Remove `MLIRType` from tblgen generator

* Refactor `MLIRType` from brutus example

* Remove duplicate `get_or_load_dialect!` method

* Fix missing module name in symbol refs

* Refactor `MLIRType` in tests

* Add `Identifier`

* Add `SymbolTable`

* Implement `IntegerSet`

* Add `AffineExpr` type

* Throw message on assert fail in `AffineExpr` constructor

* Add `AffineMap` type

* Fix `Context` kwargs

* Import `@affinemap` macro from #35

Co-authored-by: Paul Berg <[email protected]>

* Fix `@affinemap`

* Implement `LogicalResult`

* Fix typo in `move_after!`

* Add `rmfromparent!` for `Operation`

* Reimplement `Type`

* Reorder includes

* Fix `Type` clash in `LogicalResult`

* Reimplement `Attribute`

* Refactor `MLIRType` to `IR.Type` in dialect bindings

* Fix `mlir*IsNull` calls

`mlir*IsNull` are only defined in headers (no symbol in libs), so binding is generated but fails to look for the symbol

* Apply suggestion from @Pangoraw

Co-authored-by: Paul Berg <[email protected]>

* Apply suggestion from @Pangoraw

Co-authored-by: Paul Berg <[email protected]>

* Apply suggestion by @Pangoraw

Co-authored-by: Paul Berg <[email protected]>

* Fix typos

* Fix ambiguity

* Fix typos

* Remove redundant exports

* Fix type value retrieval of `Attribute` for integers

* Apply suggestion by @Pangoraw

* Refactor `@affinemap`

* Remove invalid assert in `@affinemap`

* Fix extra namespace inside macro

* Fix `mlir_str` macro

* Fix `verifyall`

* Rename `next_in_region` to `next`

* Add `OpOperand` type (LLVM 16)

* Fix `mlirOpPrintingFlagsEnableDebugInfo` on LLVM 15,16

* Add `@llvmversioned` macro utility

* Add Float8 types

* Add support for `DenseArray` attributes

* Remove `isopaquelements` on LLVM 16

* Fix typos in `OpOperand`

* Fix namespace in `DenseArrayAttribute`

* Refactor `operandsegmentsizes`

* Fix import in Julia 1.10

* Fix max versioning in `@llvmversioned`

* Fix `DenseElementsAttribute` docstrings

* Fix docstring warning in methods with `@llvmversioned`

* Fix docstring of `type`

* Fix `mlirIsNull` calls

* Fix `LogicalResult` methods

`mlirLogicalResult*` functions in the C-API are header-only

* Fix `LogicalResult` calls in pass infrastructure

* Update Brutus example

* Comment `mlirOperationWriteBytecode`

* Use signless integer types in MLIR for `<:Signed` types in Julia

* Fix `Bool` to MLIR conversion

* Fix tests

* Update docs

* Fix 0-indexing of dims and symbol expressions in `@affinemap`

* Refactor `PassManager`,`OpPassManager`

* Add `ExecutionEngine`

* Fix docstring of `Type(Complex{T})`

* Refactor `TensorType` constructor

* Refactor `push_argument!`

* Apply suggestion by @Pangoraw

Co-authored-by: Paul Berg <[email protected]>

* Refactor `create_operation`

* Fix `push_arguent!`

* Try fix `load_all_available_dialects` call in Julia 1.9

---------

Co-authored-by: Paul Berg <[email protected]>
  • Loading branch information
mofeing and Pangoraw authored Mar 8, 2024
1 parent 8267af4 commit 3527e24
Show file tree
Hide file tree
Showing 149 changed files with 18,904 additions and 16,195 deletions.
8 changes: 4 additions & 4 deletions deps/tblgen/jl-generators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper,
const char *moduleTemplate;
if (disableModuleWrap)
{
moduleTemplate = R"(import ...IR: NamedAttribute, MLIRType, Value, Location, Block, Region, Attribute, create_operation, context, IndexType
moduleTemplate = R"(import ...IR: IR, NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType
import ..Dialects: namedattribute, operandsegmentsizes
import ...API
Expand All @@ -181,7 +181,7 @@ import ...API
{
moduleTemplate = R"(module {0}
import ...IR: NamedAttribute, MLIRType, Value, Location, Block, Region, Attribute, create_operation, context, IndexType
import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType
import ..Dialects: namedattribute, operandsegmentsizes
import ...API
Expand All @@ -196,7 +196,7 @@ function {0}({1}location=Location())
{2}
end
)"; // 0: functionname, 1: functionarguments, 2: functionbody
const char *functionbodytemplate = R"(results = MLIRType[{0}]
const char *functionbodytemplate = R"(results = IR.Type[{0}]
operands = Value[{1}]
owned_regions = Region[{2}]
successors = Block[{3}]
Expand Down Expand Up @@ -321,7 +321,7 @@ end
resultname = "result_" + std::to_string(i);
}
resultname = sanitizeName(resultname);
std::string type = "MLIRType";
std::string type = "IR.Type";

bool optional = named_result.isOptional() || inferrable;
bool variadic = named_result.isVariadic();
Expand Down
22 changes: 22 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ makedocs(;
pages = [
"Home" => "index.md",
"Examples" => examples,
"IR" => [
"AffineExpr" => "IR/affineexpr.md",
"AffineMap" => "IR/affinemap.md",
"Attribute" => "IR/attribute.md",
"Block" => "IR/block.md",
"Context" => "IR/context.md",
"Dialect" => "IR/dialect.md",
"Identifier" => "IR/identifier.md",
"IntegerSet" => "IR/integerset.md",
"Iterators" => "IR/iterators.md",
"Location" => "IR/location.md",
"LogicalResult" => "IR/logicalresult.md",
"Module" => "IR/module.md",
"Operation" => "IR/operation.md",
# TODO `OpOperand` for LLVM >=16
"Pass Infrastucture" => "IR/pass.md",
"Region" => "IR/region.md",
"SymbolTable" => "IR/symboltable.md",
"Type" => "IR/type.md",
"TypeID" => "IR/typeid.md",
"Value" => "IR/value.md",
],
"Dialects" => [
"15" => [
"`affine`" => "dialects/affine.md",
Expand Down
6 changes: 6 additions & 0 deletions docs/src/IR/affineexpr.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Affine Expressions

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/AffineExpr.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/affinemap.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Affine Map

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/AffineMap.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/attribute.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Attribute

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Attribute.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/block.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Block

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Block.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/context.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Context

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Context.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/dialect.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Dialect

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Dialect.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/identifier.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Identifier

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Identifier.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/integerset.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Integer Set

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/IntegerSet.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/iterators.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Iterators

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Iterators.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/location.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Location

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Location.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/logicalresult.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Logical Result

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/LogicalResult.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/module.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Module

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Module.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/operation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Operation

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Operation.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/pass.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Pass Infrastructure

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Pass.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/region.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Region

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Region.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/symboltable.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Symbol Table

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/SymbolTable.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/type.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Type

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Type.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/typeid.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# TypeID

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/TypeID.jl"]
```
6 changes: 6 additions & 0 deletions docs/src/IR/value.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Value

```@autodocs
Modules = [MLIR.IR]
Pages = ["IR/Value.jl"]
```
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ Order = [:macro, :function]
```@autodocs
Modules = [MLIR.API]
Order = [:module, :type, :constant, :macro, :function]
```
```
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Design

### String and MlirStringRef.
### String and MlirStringRef

`MlirStringRef` is a non-owning pointer, the caller is in charge of performing necessary
copies or ensuring that the pointee outlives all uses of `MlirStringRef`.
Since Julia is a GC'd language special care must be taken around the live-time of Julia
Expand Down
49 changes: 24 additions & 25 deletions examples/brutus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ const uge = 9
end

function cmpi_pred(predicate)
function (ops...; location = Location())
arith.cmpi(ops...; result=IR.MLIRType(Bool), predicate, location)
function (ops...; location=Location())
arith.cmpi(ops...; result=IR.Type(Bool), predicate, location)
end
end

Expand All @@ -45,10 +45,10 @@ const intrinsics_to_mlir = Dict([
Base.mul_float => single_op_wrapper(arith.mulf),
Base.not_int => function (block, args; location=Location())
arg = only(args)
mT = IR.get_type(arg)
mT = IR.type(arg)
T = IR.julia_type(mT)
ones = push!(block, arith.constant(value=typemax(UInt64) % T;
result=mT, location)) |> IR.get_result
result=mT, location)) |> IR.result
push!(block, arith.xori(arg, ones; location))
end,
])
Expand All @@ -63,7 +63,7 @@ function prepare_block(ir, bb)
inst isa Core.PhiNode || continue

type = stmt[:type]
IR.push_argument!(b, MLIRType(type), Location())
IR.push_argument!(b, IR.Type(type))
end

return b
Expand Down Expand Up @@ -107,9 +107,10 @@ function code_mlir(f, types)

values = Vector{Value}(undef, length(ir.stmts))

for dialect in ("func", "cf")
IR.get_or_load_dialect!(dialect)
for dialect in (:func, :cf)
IR.register_dialect!(IR.DialectHandle(dialect))
end
IR.load_all_available_dialects()

blocks = [
prepare_block(ir, bb)
Expand All @@ -119,17 +120,17 @@ function code_mlir(f, types)
current_block = entry_block = blocks[begin]

for argtype in types.parameters
IR.push_argument!(entry_block, MLIRType(argtype), Location())
IR.push_argument!(entry_block, IR.Type(argtype))
end

function get_value(x)::Value
if x isa Core.SSAValue
@assert isassigned(values, x.id) "value $x was not assigned"
values[x.id]
elseif x isa Core.Argument
IR.get_argument(entry_block, x.n - 1)
IR.argument(entry_block, x.n - 1)
elseif x isa BrutusScalar
IR.get_result(push!(current_block, arith.constant(;value=x)))
IR.result(push!(current_block, arith.constant(; value=x)))
else
error("could not use value $x inside MLIR")
end
Expand All @@ -142,14 +143,14 @@ function code_mlir(f, types)
for sidx in bb.stmts
stmt = ir.stmts[sidx]
inst = stmt[:inst]
line = ir.linetable[stmt[:line]]
line = ir.linetable[stmt[:line] + 1]

if Meta.isexpr(inst, :call)
val_type = stmt[:type]
if !(val_type <: BrutusScalar)
error("type $val_type is not supported")
end
out_type = MLIRType(val_type)
out_type = IR.Type(val_type)

called_func = first(inst.args)
if called_func isa GlobalRef # TODO: should probably use something else here
Expand All @@ -160,11 +161,11 @@ function code_mlir(f, types)
args = get_value.(@view inst.args[begin+1:end])

location = Location(string(line.file), line.line, 0)
res = IR.get_result(fop!(current_block, args; location))
res = IR.result(fop!(current_block, args; location))

values[sidx] = res
elseif inst isa PhiNode
values[sidx] = IR.get_argument(current_block, n_phi_nodes += 1)
values[sidx] = IR.argument(current_block, n_phi_nodes += 1)
elseif inst isa PiNode
values[sidx] = get_value(inst.val)
elseif inst isa GotoNode
Expand All @@ -185,7 +186,7 @@ function code_mlir(f, types)
cond_br = cf.cond_br(cond, true_args, false_args; trueDest=other_dest, falseDest=dest, location)
push!(current_block, cond_br)
elseif inst isa ReturnNode
line = ir.linetable[stmt[:line]]
line = ir.linetable[stmt[:line]+1]
location = Location(string(line.file), line.line, 0)
push!(current_block, func.return_([get_value(inst.val)]; location))
elseif Meta.isexpr(inst, :code_coverage_effect)
Expand All @@ -203,21 +204,19 @@ function code_mlir(f, types)
push!(region, b)
end

LLVM15 = LLVM.version() >= v"15"

input_types = MLIRType[
IR.get_type(IR.get_argument(entry_block, i))
for i in 1:IR.num_arguments(entry_block)
input_types = IR.Type[
IR.type(IR.argument(entry_block, i))
for i in 1:IR.nargs(entry_block)
]
result_types = [MLIRType(ret)]
result_types = [IR.Type(ret)]

ftype = MLIRType(input_types => result_types)
ftype = IR.FunctionType(input_types, result_types)
op = IR.create_operation(
"func.func",
Location();
attributes=[
NamedAttribute("sym_name", IR.Attribute(string(func_name))),
NamedAttribute("function_type", IR.Attribute(ftype)),
IR.NamedAttribute("sym_name", IR.Attribute(string(func_name))),
IR.NamedAttribute("function_type", IR.Attribute(ftype)),
],
owned_regions=Region[region],
result_inference=false,
Expand Down Expand Up @@ -277,7 +276,7 @@ fptr = IR.context!(IR.Context()) do
op = Brutus.code_mlir(pow, Tuple{Int,Int})

mod = IR.Module(Location())
body = IR.get_body(mod)
body = IR.body(mod)
push!(body, op)

pm = IR.PassManager()
Expand Down
10 changes: 2 additions & 8 deletions src/Dialects.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Dialects

import LLVM
import ..IR: Attribute, NamedAttribute, DenseArrayAttribute, context
import ..IR: Attribute, NamedAttribute, context
import ..API

namedattribute(name, val) = namedattribute(name, Attribute(val))
Expand All @@ -11,13 +11,7 @@ function namedattribute(name, val::NamedAttribute)
return val
end

operandsegmentsizes(segments) =
namedattribute("operand_segment_sizes",
LLVM.version() >= v"16" ?
DenseArrayAttribute(Int32.(segments)) :
Attribute(Int32.(segments))
)

operandsegmentsizes(segments) = namedattribute("operand_segment_sizes", Attribute(Int32.(segments)))

let
ver = string(LLVM.version().major)
Expand Down
Loading

0 comments on commit 3527e24

Please sign in to comment.