Skip to content

Commit

Permalink
Complete Iterator interface implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Mar 13, 2024
1 parent 3527e24 commit 53536a1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/IR/Iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ struct BlockIterator
region::Region
end

Base.IteratorSize(::Core.Type{BlockIterator}) = Base.SizeUnknown()
Base.eltype(::BlockIterator) = Block

function Base.iterate(it::BlockIterator)
reg = it.region
raw_block = API.mlirRegionGetFirstBlock(reg)
Expand Down Expand Up @@ -37,6 +40,9 @@ struct RegionIterator
op::Operation
end

Base.eltype(::RegionIterator) = Region
Base.length(it::RegionIterator) = nregions(it.op)

function Base.iterate(it::RegionIterator)
raw_region = API.mlirOperationGetFirstRegion(it.op)
if mlirIsNull(raw_region)
Expand Down Expand Up @@ -66,6 +72,9 @@ struct OperationIterator
block::Block
end

Base.IteratorSize(::Core.Type{OperationIterator}) = Base.SizeUnknown()
Base.eltype(::OperationIterator) = Operation

function Base.iterate(it::OperationIterator)
raw_op = API.mlirBlockGetFirstOperation(it.block)
if mlirIsNull(raw_op)
Expand Down
39 changes: 39 additions & 0 deletions test/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,42 @@ end
@test_throws AssertionError IR.Module(arith.constant(; value=true, result=IR.Type(Bool)))
end
end

@testset "Iterators" begin
IR.context!(IR.Context()) do
mod = if LLVM.version() >= v"15"
IR.load_all_available_dialects()
IR.get_or_load_dialect!(IR.DialectHandle(:func))
parse(IR.Module, """
module {
func.func @f() {
return
}
}
""")
else
IR.get_or_load_dialect!(IR.DialectHandle(:std))
parse(IR.Module, """
module {
func @f() {
std.return
}
}
""")
end
b = IR.body(mod)
ops = collect(IR.OperationIterator(b))
@test ops isa Vector{Operation}
@test length(ops) == 1

op = only(ops)
regions = collect(IR.RegionIterator(op))
@test regions isa Vector{Region}
@test length(regions) == 1

region = only(regions)
blocks = collect(IR.BlockIterator(region))
@test blocks isa Vector{Block}
@test length(blocks) == 1
end
end

0 comments on commit 53536a1

Please sign in to comment.