From eed0ed0dd814f0ea881b02cfe907c64c7bc9f70e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 26 Mar 2024 11:00:15 +0530 Subject: [PATCH] Simplify state in SubBlockIterator --- src/blockbroadcast.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl index 7bf796b9..e1449afe 100644 --- a/src/blockbroadcast.jl +++ b/src/blockbroadcast.jl @@ -49,7 +49,7 @@ similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} = SubBlockIterator(subblock_lasts::Vector{Int}, block_lasts::Vector{Int}) SubBlockIterator(A::AbstractArray, bs::NTuple{N,AbstractUnitRange{Int}} where N, dim::Integer) -An iterator for iterating `BlockIndexRange` of the blocks specified by +Return an iterator over the `BlockIndexRange`s of the blocks specified by `subblock_lasts`. The `Block` index part of `BlockIndexRange` is determined by `subblock_lasts`. That is to say, the `Block` index first specifies one of the block represented by `subblock_lasts` and then the @@ -85,7 +85,7 @@ view(A, idx) = 2:3 view(A, idx) = 4:4 view(A, idx) = 5:6 -julia> [idx.block.n[1] for idx in SubBlockIterator(subblock_lasts, block_lasts)] +julia> [Int(idx.block) for idx in SubBlockIterator(subblock_lasts, block_lasts)] 4-element Vector{Int64}: 1 2 @@ -114,12 +114,8 @@ Base.length(it::SubBlockIterator) = length(it.block_lasts) SubBlockIterator(arr::AbstractArray, bs::NTuple{N,AbstractUnitRange{Int}}, dim::Integer) where N = SubBlockIterator(blocklasts(axes(arr, dim)), blocklasts(bs[dim])) -function Base.iterate(it::SubBlockIterator, state=nothing) - if state === nothing - i,j = 1,1 - else - i, j = state - end +function Base.iterate(it::SubBlockIterator, state=(1,1)) + i, j = state length(it.block_lasts)+1 == i && return nothing idx = i == 1 ? (1:it.block_lasts[i]) : (it.block_lasts[i-1]+1:it.block_lasts[i])