Skip to content

Commit 53536a1

Browse files
committed
1 parent 3527e24 commit 53536a1

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

src/IR/Iterators.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ struct BlockIterator
77
region::Region
88
end
99

10+
Base.IteratorSize(::Core.Type{BlockIterator}) = Base.SizeUnknown()
11+
Base.eltype(::BlockIterator) = Block
12+
1013
function Base.iterate(it::BlockIterator)
1114
reg = it.region
1215
raw_block = API.mlirRegionGetFirstBlock(reg)
@@ -37,6 +40,9 @@ struct RegionIterator
3740
op::Operation
3841
end
3942

43+
Base.eltype(::RegionIterator) = Region
44+
Base.length(it::RegionIterator) = nregions(it.op)
45+
4046
function Base.iterate(it::RegionIterator)
4147
raw_region = API.mlirOperationGetFirstRegion(it.op)
4248
if mlirIsNull(raw_region)
@@ -66,6 +72,9 @@ struct OperationIterator
6672
block::Block
6773
end
6874

75+
Base.IteratorSize(::Core.Type{OperationIterator}) = Base.SizeUnknown()
76+
Base.eltype(::OperationIterator) = Operation
77+
6978
function Base.iterate(it::OperationIterator)
7079
raw_op = API.mlirBlockGetFirstOperation(it.block)
7180
if mlirIsNull(raw_op)

test/ir.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,42 @@ end
3131
@test_throws AssertionError IR.Module(arith.constant(; value=true, result=IR.Type(Bool)))
3232
end
3333
end
34+
35+
@testset "Iterators" begin
36+
IR.context!(IR.Context()) do
37+
mod = if LLVM.version() >= v"15"
38+
IR.load_all_available_dialects()
39+
IR.get_or_load_dialect!(IR.DialectHandle(:func))
40+
parse(IR.Module, """
41+
module {
42+
func.func @f() {
43+
return
44+
}
45+
}
46+
""")
47+
else
48+
IR.get_or_load_dialect!(IR.DialectHandle(:std))
49+
parse(IR.Module, """
50+
module {
51+
func @f() {
52+
std.return
53+
}
54+
}
55+
""")
56+
end
57+
b = IR.body(mod)
58+
ops = collect(IR.OperationIterator(b))
59+
@test ops isa Vector{Operation}
60+
@test length(ops) == 1
61+
62+
op = only(ops)
63+
regions = collect(IR.RegionIterator(op))
64+
@test regions isa Vector{Region}
65+
@test length(regions) == 1
66+
67+
region = only(regions)
68+
blocks = collect(IR.BlockIterator(region))
69+
@test blocks isa Vector{Block}
70+
@test length(blocks) == 1
71+
end
72+
end

0 commit comments

Comments
 (0)