Skip to content

Commit 150f55c

Browse files
authored
Consolidate convert methods for BlockedUnitRange (#349)
1 parent f87ad0c commit 150f55c

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

src/blockaxis.jl

+14-7
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,20 @@ Return `all(blockisequal.(a,b))``
117117
blockisequal(a::Tuple, b::Tuple) = all(blockisequal.(a, b))
118118

119119

120-
Base.convert(::Type{BlockedUnitRange}, axis::BlockedUnitRange) = axis
121-
Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int}) = _BlockedUnitRange(first(axis),[last(axis)])
122-
Base.convert(::Type{BlockedUnitRange}, axis::Base.Slice) = _BlockedUnitRange(first(axis),[last(axis)])
123-
Base.convert(::Type{BlockedUnitRange}, axis::Base.IdentityUnitRange) = _BlockedUnitRange(first(axis),[last(axis)])
124-
Base.convert(::Type{BlockedUnitRange{CS}}, axis::BlockedUnitRange{CS}) where CS = axis
125-
Base.convert(::Type{BlockedUnitRange{CS}}, axis::BlockedUnitRange) where CS = _BlockedUnitRange(first(axis), convert(CS, blocklasts(axis)))
126-
Base.convert(::Type{BlockedUnitRange{CS}}, axis::AbstractUnitRange{Int}) where CS = convert(BlockedUnitRange{CS}, convert(BlockedUnitRange, axis))
120+
_shift_blocklengths(::BlockedUnitRange, bl, f) = bl
121+
_shift_blocklengths(::Any, bl, f) = bl .+ (f - 1)
122+
const OneBasedRanges = Union{Base.OneTo, Base.Slice{<:Base.OneTo}, Base.IdentityUnitRange{<:Base.OneTo}}
123+
_shift_blocklengths(::OneBasedRanges, bl, f) = bl
124+
function Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int})
125+
bl = blocklasts(axis)
126+
f = first(axis)
127+
_BlockedUnitRange(f, _shift_blocklengths(axis, bl, f))
128+
end
129+
function Base.convert(::Type{BlockedUnitRange{CS}}, axis::AbstractUnitRange{Int}) where CS
130+
bl = blocklasts(axis)
131+
f = first(axis)
132+
_BlockedUnitRange(f, convert(CS, _shift_blocklengths(axis, bl, f)))
133+
end
127134

128135
Base.unitrange(b::BlockedUnitRange) = first(b):last(b)
129136

test/test_blockindices.jl

+17-6
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,23 @@ end
215215
b = blockedrange(Fill(2,3))
216216
c = blockedrange([2,2,2])
217217
@test convert(BlockedUnitRange, b) === b
218-
@test blockisequal(convert(BlockedUnitRange, Base.OneTo(5)), blockedrange([5]))
219-
@test blockisequal(convert(BlockedUnitRange, Base.Slice(Base.OneTo(5))), blockedrange([5]))
220-
@test blockisequal(convert(BlockedUnitRange, Base.IdentityUnitRange(-2:2)), BlockArrays._BlockedUnitRange(-2,[2]))
221-
@test convert(BlockedUnitRange{Vector{Int}}, c) === c
222-
@test blockisequal(convert(BlockedUnitRange{Vector{Int}}, b),b)
223-
@test blockisequal(convert(BlockedUnitRange{Vector{Int}}, Base.OneTo(5)), blockedrange([5]))
218+
@test convert(typeof(b), b) === b
219+
@test convert(BlockedUnitRange, c) === c
220+
@test convert(typeof(c), c) === c
221+
function test_type_and_blocks(T, r, res)
222+
s = convert(T, r)
223+
@test s isa T
224+
@test blockisequal(s, res)
225+
end
226+
for T in (BlockedUnitRange, BlockedUnitRange{Vector{Int}})
227+
test_type_and_blocks(T, blockedrange(5:5), blockedrange(5:5))
228+
test_type_and_blocks(T, Base.OneTo(5), blockedrange([5]))
229+
test_type_and_blocks(T, Base.Slice(Base.OneTo(5)), blockedrange([5]))
230+
test_type_and_blocks(T, -2:2, BlockArrays._BlockedUnitRange(-2,[2]))
231+
test_type_and_blocks(T, Base.IdentityUnitRange(-2:2), BlockArrays._BlockedUnitRange(-2,[2]))
232+
test_type_and_blocks(T, b, b)
233+
test_type_and_blocks(T, Base.OneTo(5), blockedrange([5]))
234+
end
224235
end
225236

226237
@testset "findblock" begin

0 commit comments

Comments
 (0)