Skip to content

Commit

Permalink
Call functions instead of field access for BlockedUnitRange (#338)
Browse files Browse the repository at this point in the history
* Call functions instead of field access for BlockedUnitRange

* Add tests for Zeros
  • Loading branch information
jishnub authored Mar 20, 2024
1 parent 7e31cf8 commit 238fed8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/blockaxis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,21 @@ BlockedUnitRange(::BlockedUnitRange) = throw(ArgumentError("Forbidden due to amb
_blocklengths2blocklasts(blocks) = cumsum(blocks) # extra level to allow changing default cumsum behaviour
@inline blockedrange(blocks::Union{Tuple,AbstractVector}) = _BlockedUnitRange(_blocklengths2blocklasts(blocks))

@inline blockfirsts(a::BlockedUnitRange) = [a.first; @views(a.lasts[1:end-1]) .+ 1]
@inline blockfirsts(a::BlockedUnitRange) = [first(a); @views(blocklasts(a)[1:end-1]) .+ 1]
# optimize common cases
@inline function blockfirsts(a::BlockedUnitRange{<:Union{Vector, RangeCumsum{<:Any, <:UnitRange}}})
v = Vector{eltype(a)}(undef, length(a.lasts))
v[1] = a.first
v[2:end] .= @views(a.lasts[oneto(end-1)]) .+ 1
v = Vector{eltype(a)}(undef, length(blocklasts(a)))
v[1] = first(a)
v[2:end] .= @views(blocklasts(a)[oneto(end-1)]) .+ 1
return v
end
@inline blocklasts(a::BlockedUnitRange) = a.lasts

_diff(a::AbstractVector) = diff(a)
_diff(a::Tuple) = diff(collect(a))
@inline blocklengths(a::BlockedUnitRange) = isempty(a.lasts) ? [_diff(a.lasts);] : [first(a.lasts)-a.first+1; _diff(a.lasts)]
@inline blocklengths(a::BlockedUnitRange) = isempty(blocklasts(a)) ? [_diff(blocklasts(a));] : [first(blocklasts(a))-first(a)+1; _diff(blocklasts(a))]

length(a::BlockedUnitRange) = isempty(a.lasts) ? 0 : Integer(last(a.lasts)-a.first+1)
length(a::BlockedUnitRange) = isempty(blocklasts(a)) ? 0 : Integer(last(blocklasts(a))-first(a)+1)

"""
blockisequal(a::AbstractUnitRange{Int}, b::AbstractUnitRange{Int})
Expand Down Expand Up @@ -440,24 +440,24 @@ Base.BroadcastStyle(::Type{BlockedUnitRange{R}}) where R = Base.BroadcastStyle(R

_blocklengths2blocklasts(blocks::AbstractRange) = RangeCumsum(blocks)
function blockfirsts(a::BlockedUnitRange{Base.OneTo{Int}})
a.first == 1 || error("Offset axes not supported")
Base.OneTo{Int}(length(a.lasts))
first(a) == 1 || error("Offset axes not supported")
Base.OneTo{Int}(length(blocklasts(a)))
end
function blocklengths(a::BlockedUnitRange{Base.OneTo{Int}})
a.first == 1 || error("Offset axes not supported")
Ones{Int}(length(a.lasts))
first(a) == 1 || error("Offset axes not supported")
Ones{Int}(length(blocklasts(a)))
end
function blockfirsts(a::BlockedUnitRange{<:AbstractRange})
st = step(a.lasts)
a.first == 1 || error("Offset axes not supported")
@assert first(a.lasts)-a.first+1 == st
range(1; step=st, length=length(a.lasts))
st = step(blocklasts(a))
first(a) == 1 || error("Offset axes not supported")
@assert first(blocklasts(a))-first(a)+1 == st
range(1; step=st, length=length(blocklasts(a)))
end
function blocklengths(a::BlockedUnitRange{<:AbstractRange})
st = step(a.lasts)
a.first == 1 || error("Offset axes not supported")
@assert first(a.lasts)-a.first+1 == st
Fill(st,length(a.lasts))
st = step(blocklasts(a))
first(a) == 1 || error("Offset axes not supported")
@assert first(blocklasts(a))-first(a)+1 == st
Fill(st,length(blocklasts(a)))
end


Expand Down
4 changes: 4 additions & 0 deletions test/test_blockindices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ end
@test blocklasts(f) StepRangeLen(2,2,5)
@test blocklengths(f) Fill(2,5)

f = blockedrange(Zeros{Int}(2))
@test blockfirsts(f) == [1,1]
@test blocklasts(f) == [0,0]

r = blockedrange(Base.OneTo(5))
@test (@inferred blocklengths(r)) == 1:5
@test blocklasts(r) ArrayLayouts.RangeCumsum(Base.OneTo(5))
Expand Down

0 comments on commit 238fed8

Please sign in to comment.