From 53536a1e06b9db0614aeeeebb7e08d608ac77ae8 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 13 Mar 2024 21:50:10 +0100 Subject: [PATCH] Complete Iterator interface implementation Ref: https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration --- src/IR/Iterators.jl | 9 +++++++++ test/ir.jl | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/IR/Iterators.jl b/src/IR/Iterators.jl index 5b17184d..342d666f 100644 --- a/src/IR/Iterators.jl +++ b/src/IR/Iterators.jl @@ -7,6 +7,9 @@ struct BlockIterator region::Region end +Base.IteratorSize(::Core.Type{BlockIterator}) = Base.SizeUnknown() +Base.eltype(::BlockIterator) = Block + function Base.iterate(it::BlockIterator) reg = it.region raw_block = API.mlirRegionGetFirstBlock(reg) @@ -37,6 +40,9 @@ struct RegionIterator op::Operation end +Base.eltype(::RegionIterator) = Region +Base.length(it::RegionIterator) = nregions(it.op) + function Base.iterate(it::RegionIterator) raw_region = API.mlirOperationGetFirstRegion(it.op) if mlirIsNull(raw_region) @@ -66,6 +72,9 @@ struct OperationIterator block::Block end +Base.IteratorSize(::Core.Type{OperationIterator}) = Base.SizeUnknown() +Base.eltype(::OperationIterator) = Operation + function Base.iterate(it::OperationIterator) raw_op = API.mlirBlockGetFirstOperation(it.block) if mlirIsNull(raw_op) diff --git a/test/ir.jl b/test/ir.jl index 3aac1ac5..c47ff4b3 100644 --- a/test/ir.jl +++ b/test/ir.jl @@ -31,3 +31,42 @@ end @test_throws AssertionError IR.Module(arith.constant(; value=true, result=IR.Type(Bool))) end end + +@testset "Iterators" begin + IR.context!(IR.Context()) do + mod = if LLVM.version() >= v"15" + IR.load_all_available_dialects() + IR.get_or_load_dialect!(IR.DialectHandle(:func)) + parse(IR.Module, """ + module { + func.func @f() { + return + } + } + """) + else + IR.get_or_load_dialect!(IR.DialectHandle(:std)) + parse(IR.Module, """ + module { + func @f() { + std.return + } + } + """) + end + b = IR.body(mod) + ops = collect(IR.OperationIterator(b)) + @test ops isa Vector{Operation} + @test length(ops) == 1 + + op = only(ops) + regions = collect(IR.RegionIterator(op)) + @test regions isa Vector{Region} + @test length(regions) == 1 + + region = only(regions) + blocks = collect(IR.BlockIterator(region)) + @test blocks isa Vector{Block} + @test length(blocks) == 1 + end +end