Skip to content

Commit

Permalink
Consolidate convert methods for BlockedUnitRange
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Mar 22, 2024
1 parent f87ad0c commit 437829f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
21 changes: 14 additions & 7 deletions src/blockaxis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,20 @@ Return `all(blockisequal.(a,b))``
blockisequal(a::Tuple, b::Tuple) = all(blockisequal.(a, b))


Base.convert(::Type{BlockedUnitRange}, axis::BlockedUnitRange) = axis
Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int}) = _BlockedUnitRange(first(axis),[last(axis)])
Base.convert(::Type{BlockedUnitRange}, axis::Base.Slice) = _BlockedUnitRange(first(axis),[last(axis)])
Base.convert(::Type{BlockedUnitRange}, axis::Base.IdentityUnitRange) = _BlockedUnitRange(first(axis),[last(axis)])
Base.convert(::Type{BlockedUnitRange{CS}}, axis::BlockedUnitRange{CS}) where CS = axis
Base.convert(::Type{BlockedUnitRange{CS}}, axis::BlockedUnitRange) where CS = _BlockedUnitRange(first(axis), convert(CS, blocklasts(axis)))
Base.convert(::Type{BlockedUnitRange{CS}}, axis::AbstractUnitRange{Int}) where CS = convert(BlockedUnitRange{CS}, convert(BlockedUnitRange, axis))
_shift_blocklengths(::BlockedUnitRange, bl, f) = bl
_shift_blocklengths(::Any, bl, f) = bl .+ (f - 1)
const OneBasedRanges = Union{Base.OneTo, Base.Slice{<:Base.OneTo}, Base.IdentityUnitRange{<:Base.OneTo}}
_shift_blocklengths(::OneBasedRanges, bl, f) = bl
function Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int})
bl = blocklasts(axis)
f = first(axis)
_BlockedUnitRange(f, _shift_blocklengths(axis, bl, f))
end
function Base.convert(::Type{BlockedUnitRange{CS}}, axis::AbstractUnitRange{Int}) where CS
bl = blocklasts(axis)
f = first(axis)
_BlockedUnitRange(f, convert(CS, _shift_blocklengths(axis, bl, f)))
end

Base.unitrange(b::BlockedUnitRange) = first(b):last(b)

Expand Down
23 changes: 17 additions & 6 deletions test/test_blockindices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,23 @@ end
b = blockedrange(Fill(2,3))
c = blockedrange([2,2,2])
@test convert(BlockedUnitRange, b) === b
@test blockisequal(convert(BlockedUnitRange, Base.OneTo(5)), blockedrange([5]))
@test blockisequal(convert(BlockedUnitRange, Base.Slice(Base.OneTo(5))), blockedrange([5]))
@test blockisequal(convert(BlockedUnitRange, Base.IdentityUnitRange(-2:2)), BlockArrays._BlockedUnitRange(-2,[2]))
@test convert(BlockedUnitRange{Vector{Int}}, c) === c
@test blockisequal(convert(BlockedUnitRange{Vector{Int}}, b),b)
@test blockisequal(convert(BlockedUnitRange{Vector{Int}}, Base.OneTo(5)), blockedrange([5]))
@test convert(typeof(b), b) === b
@test convert(BlockedUnitRange, c) === c
@test convert(typeof(c), c) === c
function test_type_and_blocks(T, r, res)
s = convert(T, r)
@test s isa T
@test blockisequal(s, res)
end
for T in (BlockedUnitRange, BlockedUnitRange{Vector{Int}})
test_type_and_blocks(T, blockedrange(5:5), blockedrange(5:5))
test_type_and_blocks(T, Base.OneTo(5), blockedrange([5]))
test_type_and_blocks(T, Base.Slice(Base.OneTo(5)), blockedrange([5]))
test_type_and_blocks(T, -2:2, BlockArrays._BlockedUnitRange(-2,[2]))
test_type_and_blocks(T, Base.IdentityUnitRange(-2:2), BlockArrays._BlockedUnitRange(-2,[2]))
test_type_and_blocks(T, b, b)
test_type_and_blocks(T, Base.OneTo(5), blockedrange([5]))
end
end

@testset "findblock" begin
Expand Down

0 comments on commit 437829f

Please sign in to comment.