Skip to content

Commit

Permalink
Generalize blockaxis.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 14, 2024
1 parent 9643287 commit c5c9e1d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 35 deletions.
38 changes: 19 additions & 19 deletions src/blockaxis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function _BlockedUnitRange end
"""
BlockedUnitRange
is an `AbstractUnitRange{Int}` that has been divided into blocks.
is an `AbstractUnitRange{<:Integer}` that has been divided into blocks.
Construction is typically via `blockedrange` which converts
a vector of block lengths to a `BlockedUnitRange`.
Expand Down Expand Up @@ -147,7 +147,7 @@ end
_throw_if_bool(_) = nothing
_throw_if_bool(::Type{Bool}) = throw(ArgumentError("a Bool collection is not allowed as blocklasts"))

const DefaultBlockAxis = BlockedOneTo{Int, Vector{Int}}
const DefaultBlockAxis{T<:Integer} = BlockedOneTo{T, Vector{T}}

first(b::BlockedOneTo) = oneunit(eltype(b))
@inline blocklasts(a::BlockedOneTo) = a.lasts
Expand All @@ -158,9 +158,9 @@ axes(b::BlockedOneTo) = (b,)

"""
blockedrange(blocklengths::Union{Tuple, AbstractVector})
blockedrange(first::Int, blocklengths::Union{Tuple, AbstractVector})
blockedrange(first::Integer, blocklengths::Union{Tuple, AbstractVector})
Return a blocked `AbstractUnitRange{Int}` with the block sizes being `blocklengths`.
Return a blocked `AbstractUnitRange{<:Integer}` with the block sizes being `blocklengths`.
If `first` is provided, this is used as the first value of the range.
Otherwise, if only the block lengths are provided, `first` is assumed to be `1`.
Expand Down Expand Up @@ -197,7 +197,7 @@ end
length(a::AbstractBlockedUnitRange) = isempty(blocklasts(a)) ? zero(eltype(a)) : Integer(last(blocklasts(a))-first(a)+oneunit(eltype(a)))

"""
blockisequal(a::AbstractUnitRange{Int}, b::AbstractUnitRange{Int})
blockisequal(a::AbstractUnitRange{<:Integer}, b::AbstractUnitRange{<:Integer})
Check if `a` and `b` have the same block structure.
Expand Down Expand Up @@ -225,7 +225,7 @@ julia> blockisequal(b1, b2)
false
```
"""
blockisequal(a::AbstractUnitRange{Int}, b::AbstractUnitRange{Int}) = first(a) == first(b) && blocklasts(a) == blocklasts(b)
blockisequal(a::AbstractUnitRange{<:Integer}, b::AbstractUnitRange{<:Integer}) = first(a) == first(b) && blocklasts(a) == blocklasts(b)
blockisequal(a, b, c, d...) = blockisequal(a,b) && blockisequal(b,c,d...)
"""
blockisequal(a::Tuple, b::Tuple)
Expand All @@ -242,20 +242,20 @@ _shift_blocklengths(::AbstractBlockedUnitRange, bl, f) = bl
_shift_blocklengths(::Any, bl, f) = bl .+ (f - 1)
const OneBasedRanges = Union{Base.OneTo, Base.Slice{<:Base.OneTo}, Base.IdentityUnitRange{<:Base.OneTo}}
_shift_blocklengths(::OneBasedRanges, bl, f) = bl
function Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int})
function Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{<:Integer})
bl = blocklasts(axis)
f = first(axis)
_BlockedUnitRange(f, _shift_blocklengths(axis, bl, f))
end
function Base.convert(::Type{BlockedUnitRange{T,CS}}, axis::AbstractUnitRange{Int}) where {T,CS}
function Base.convert(::Type{BlockedUnitRange{T,CS}}, axis::AbstractUnitRange{<:Integer}) where {T,CS}
bl = blocklasts(axis)
f = first(axis)
_BlockedUnitRange(convert(T, f), convert(CS, _shift_blocklengths(axis, bl, f)))
end

Base.unitrange(b::AbstractBlockedUnitRange) = first(b):last(b)

Base.promote_rule(::Type{<:AbstractBlockedUnitRange{T}}, ::Type{Base.OneTo{Int}}) where {T} = UnitRange{promote_type(T, Int)}
Base.promote_rule(::Type{<:AbstractBlockedUnitRange{T}}, ::Type{Base.OneTo{S}}) where {T,S} = UnitRange{promote_type(T, S)}

function Base.convert(::Type{BlockedOneTo}, axis::AbstractUnitRange{<:Integer})
first(axis) == 1 || throw(ArgumentError("first element of range is not 1"))
Expand Down Expand Up @@ -383,7 +383,7 @@ end
_BlockedUnitRange(cs[k-1]+oneunit(eltype(b)),cs[k:j])
end

@propagate_inbounds function getindex(b::AbstractBlockedUnitRange, KR::BlockRange{1,Tuple{Base.OneTo{Int}}})
@propagate_inbounds function getindex(b::AbstractBlockedUnitRange, KR::BlockRange{1,<:Tuple{Base.OneTo{<:Integer}}})
cs = blocklasts(b)
_getindex(b, blocklengths) = _BlockedUnitRange(first(b), blocklengths)
_getindex(b::BlockedOneTo, blocklengths) = BlockedOneTo(blocklengths)
Expand Down Expand Up @@ -415,8 +415,8 @@ Base.dataids(b::AbstractBlockedUnitRange) = Base.dataids(blocklasts(b))
###
# BlockedUnitRange interface
###
Base.checkindex(::Type{Bool}, b::BlockRange, K::Int) = checkindex(Bool, Int.(b), K)
Base.checkindex(::Type{Bool}, b::AbstractUnitRange{Int}, K::Block{1}) = checkindex(Bool, blockaxes(b,1), Int(K))
Base.checkindex(::Type{Bool}, b::BlockRange, K::Integer) = checkindex(Bool, Integer.(b), K)
Base.checkindex(::Type{Bool}, b::AbstractUnitRange{<:Integer}, K::Block{1}) = checkindex(Bool, blockaxes(b,1), Integer(K))

function Base.checkindex(::Type{Bool}, axis::AbstractBlockedUnitRange, ind::BlockIndexRange{1})
checkindex(Bool, axis, first(ind)) && checkindex(Bool, axis, last(ind))
Expand All @@ -425,19 +425,19 @@ function Base.checkindex(::Type{Bool}, axis::AbstractBlockedUnitRange, ind::Bloc
checkindex(Bool, axis, block(ind)) && checkbounds(Bool, axis[block(ind)], blockindex(ind))
end

@propagate_inbounds function getindex(b::AbstractUnitRange{Int}, K::Block{1})
@propagate_inbounds function getindex(b::AbstractUnitRange{<:Integer}, K::Block{1})
@boundscheck K == Block(1) || throw(BlockBoundsError(b, K))
b
end

@propagate_inbounds function getindex(b::AbstractUnitRange{Int}, K::BlockRange)
@propagate_inbounds function getindex(b::AbstractUnitRange{<:Integer}, K::BlockRange)
@boundscheck K == Block.(1:1) || throw(BlockBoundsError(b, K))
b
end

blockaxes(b::AbstractUnitRange{Int}) = (Block.(Base.OneTo(1)),)
blockaxes(b::AbstractUnitRange{T}) where {T<:Integer} = (Block.(Base.OneTo(one(T))),)

function findblock(b::AbstractUnitRange{Int}, k::Integer)
function findblock(b::AbstractUnitRange{<:Integer}, k::Integer)
@boundscheck k in axes(b,1) || throw(BoundsError(b,k))
Block(1)
end
Expand Down Expand Up @@ -545,19 +545,19 @@ Base.BroadcastStyle(::Type{<:AbstractBlockedUnitRange{<:Any,R}}) where R = _broa
# We want to use lazy types when possible
###

const OneToCumsum = RangeCumsum{Int,Base.OneTo{Int}}
const OneToCumsum{T<:Integer} = RangeCumsum{T,Base.OneTo{T}}
sortedunion(a::OneToCumsum, ::OneToCumsum) = a
function sortedunion(a::RangeCumsum{<:Any,<:AbstractRange}, b::RangeCumsum{<:Any,<:AbstractRange})
@assert a == b
a
end

_blocklengths2blocklasts(blocks::AbstractRange) = RangeCumsum(blocks)
function blockfirsts(a::AbstractBlockedUnitRange{<:Any,Base.OneTo{Int}})
function blockfirsts(a::AbstractBlockedUnitRange{<:Any,Base.OneTo{<:Integer}})

Check warning on line 556 in src/blockaxis.jl

View check run for this annotation

Codecov / codecov/patch

src/blockaxis.jl#L556

Added line #L556 was not covered by tests
first(a) == 1 || error("Offset axes not supported")
Base.OneTo{eltype(a)}(length(blocklasts(a)))
end
function blocklengths(a::AbstractBlockedUnitRange{<:Any,Base.OneTo{Int}})
function blocklengths(a::AbstractBlockedUnitRange{<:Any,Base.OneTo{<:Integer}})

Check warning on line 560 in src/blockaxis.jl

View check run for this annotation

Codecov / codecov/patch

src/blockaxis.jl#L560

Added line #L560 was not covered by tests
first(a) == 1 || error("Offset axes not supported")
Ones{eltype(a)}(length(blocklasts(a)))
end
Expand Down
12 changes: 6 additions & 6 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ function _blockarray_print_matrix_row(io::IO,
end
end

function _show_typeof(io::IO, a::BlockVector{T,Vector{Vector{T}},Tuple{DefaultBlockAxis}}) where T
function _show_typeof(io::IO, a::BlockVector{T,Vector{Vector{T}},<:Tuple{DefaultBlockAxis}}) where T
print(io, "BlockVector{")
show(io, T)
print(io, '}')
end

function _show_typeof(io::IO, a::BlockMatrix{T,Matrix{Matrix{T}},NTuple{2,DefaultBlockAxis}}) where T
function _show_typeof(io::IO, a::BlockMatrix{T,Matrix{Matrix{T}},<:Tuple{Vararg{DefaultBlockAxis,2}}}) where T
print(io, "BlockMatrix{")
show(io, T)
print(io, '}')
end

function _show_typeof(io::IO, a::BlockArray{T,N,Array{Array{T,N},N},NTuple{N,DefaultBlockAxis}}) where {T,N}
function _show_typeof(io::IO, a::BlockArray{T,N,Array{Array{T,N},N},<:Tuple{Vararg{DefaultBlockAxis,N}}}) where {T,N}
Base.show_type_name(io, typeof(a).name)
print(io, '{')
show(io, T)
Expand All @@ -122,19 +122,19 @@ axes_print_matrix_row(::Tuple{AbstractBlockedUnitRange,AbstractUnitRange}, io, X
Base.print_matrix_row(io::IO, X::AbstractBlockedUnitRange, A::Vector, i::Integer, cols::AbstractVector, sep::AbstractString, idxlast::Integer=last(axes(X, 2))) =
_blockarray_print_matrix_row(io, X, A, i, cols, sep)

function _show_typeof(io::IO, a::BlockedVector{T,Vector{T},Tuple{DefaultBlockAxis}}) where T
function _show_typeof(io::IO, a::BlockedVector{T,Vector{T},<:Tuple{DefaultBlockAxis}}) where T
print(io, "BlockedVector{")
show(io, T)
print(io, '}')
end

function _show_typeof(io::IO, a::BlockedMatrix{T,Matrix{T},NTuple{2,DefaultBlockAxis}}) where T
function _show_typeof(io::IO, a::BlockedMatrix{T,Matrix{T},<:Tuple{Vararg{DefaultBlockAxis,2}}}) where T
print(io, "BlockedMatrix{")
show(io, T)
print(io, '}')
end

function _show_typeof(io::IO, a::BlockedArray{T,N,Array{T,N},NTuple{N,DefaultBlockAxis}}) where {T,N}
function _show_typeof(io::IO, a::BlockedArray{T,N,Array{T,N},<:Tuple{Vararg{DefaultBlockAxis,N}}}) where {T,N}
Base.show_type_name(io, typeof(a).name)
print(io, '{')
show(io, T)
Expand Down
22 changes: 12 additions & 10 deletions test/test_blockindices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,18 @@ end
end

@testset "convert" begin
b = blockedrange(1, Fill(2,3))
c = blockedrange(1, [2,2,2])
@test oftype(b, b) === b
@test blockisequal(convert(BlockedUnitRange, Base.OneTo(5)), blockedrange(1, [5]))
@test blockisequal(convert(BlockedUnitRange, Base.Slice(Base.OneTo(5))), blockedrange(1, [5]))
@test blockisequal(convert(BlockedUnitRange, Base.IdentityUnitRange(-2:2)), BlockArrays._BlockedUnitRange(-2,[2]))
@test convert(BlockedUnitRange{Int,Vector{Int}}, c) === c
@test blockisequal(convert(BlockedUnitRange{Int,Vector{Int}}, b),b)
@test blockisequal(convert(BlockedUnitRange{Int,Vector{Int}}, Base.OneTo(5)), blockedrange(1, [5]))
@test blockisequal(convert(BlockedUnitRange, BlockedOneTo(1:3)), blockedrange(1, [1,1,1]))
for elt in (Int, UInt)
b = blockedrange(elt(1), Fill(elt(2),3))
c = blockedrange(elt(1), elt[2,2,2])
@test oftype(b, b) === b
@test blockisequal(convert(BlockedUnitRange, Base.OneTo(5)), blockedrange(1, [5]))
@test blockisequal(convert(BlockedUnitRange, Base.Slice(Base.OneTo(5))), blockedrange(1, [5]))
@test blockisequal(convert(BlockedUnitRange, Base.IdentityUnitRange(-2:2)), BlockArrays._BlockedUnitRange(-2,[2]))
@test convert(BlockedUnitRange{elt,Vector{elt}}, c) === c
@test blockisequal(convert(BlockedUnitRange{Int,Vector{Int}}, b),b)
@test blockisequal(convert(BlockedUnitRange{Int,Vector{Int}}, Base.OneTo(5)), blockedrange(1, [5]))
@test blockisequal(convert(BlockedUnitRange, BlockedOneTo(1:3)), blockedrange(1, [1,1,1]))
end
end

@testset "findblock" begin
Expand Down

0 comments on commit c5c9e1d

Please sign in to comment.