diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index cb6aab1fa..170da63bb 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -4,7 +4,7 @@ using Requires using LinearAlgebra using SparseArrays -using Base: OneTo +using Base: OneTo, @propagate_inbounds Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) @@ -543,6 +543,103 @@ function restructure(x::Array,y) reshape(convert(Array,y),size(x)...) end +""" + insert(collection, index, item) + +Return a new instance of `collection` with `item` inserted into at the given `index`. +""" +Base.@propagate_inbounds function insert(collection, index, item) + @boundscheck checkbounds(collection, index) + ret = similar(collection, length(collection) + 1) + @inbounds for i in firstindex(ret):(index - 1) + ret[i] = collection[i] + end + @inbounds ret[index] = item + @inbounds for i in (index + 1):lastindex(ret) + ret[i] = collection[i - 1] + end + return ret +end + +function insert(x::Tuple, index::Integer, item) + @boundscheck if !checkindex(Bool, static_first(x):static_last(x), index) + throw(BoundsError(x, index)) + end + return unsafe_insert(x, Int(index), item) +end + +@inline function unsafe_insert(x::Tuple, i::Int, item) + if i === 1 + return (item, x...) + else + return (first(x), unsafe_insert(Base.tail(x), i - 1, item)...) + end +end + +""" + deleteat(collection, index) + +Return a new instance of `collection` with the item at the given `index` removed. +""" +@propagate_inbounds function deleteat(collection::AbstractVector, index) + @boundscheck if !checkindex(Bool, eachindex(collection), index) + throw(BoundsError(collection, index)) + end + return unsafe_deleteat(collection, index) +end +@propagate_inbounds function deleteat(collection::Tuple, index) + @boundscheck if !checkindex(Bool, static_first(collection):static_last(collection), index) + throw(BoundsError(collection, index)) + end + return unsafe_deleteat(collection, index) +end + +function unsafe_deleteat(src::AbstractVector, index::Integer) + dst = similar(src, length(src) - 1) + @inbounds for i in indices(dst) + if i < index + dst[i] = src[i] + else + dst[i] = src[i + 1] + end + end + return dst +end + +@inline function unsafe_deleteat(src::AbstractVector, inds::AbstractVector) + dst = similar(src, length(src) - length(inds)) + dst_index = firstindex(dst) + @inbounds for src_index in indices(src) + if !in(src_index, inds) + dst[dst_index] = src[src_index] + dst_index += one(dst_index) + end + end + return dst +end + +@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector) + dst = Vector{eltype(src)}(undef, length(src) - length(inds)) + dst_index = firstindex(dst) + @inbounds for src_index in OneTo(length(src)) + if !in(src_index, inds) + dst[dst_index] = src[src_index] + dst_index += one(dst_index) + end + end + return Tuple(dst) +end + +@inline function unsafe_deleteat(x::Tuple, i::Integer) + if i === one(i) + return Base.tail(x) + elseif i == length(x) + return Base.front(x) + else + return (first(x), unsafe_deleteat(Base.tail(x), i - one(i))...) + end +end + function __init__() @require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin diff --git a/src/ranges.jl b/src/ranges.jl index 601c8455e..84ce7167c 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -42,12 +42,6 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T) # add methods to support ArrayInterface -_get(x) = x -_get(::Static{V}) where {V} = V -_get(::Type{Static{V}}) where {V} = V -_convert(::Type{T}, x) where {T} = convert(T, x) -_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V)) - """ OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T} @@ -57,28 +51,23 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in from other valid indices. Therefore, users should not expect the same checks are used to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`. """ -struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T} +struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRange{Int} start::F stop::L - function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real} - if _get(start) isa T - if _get(stop) isa T - return new{T,typeof(start),typeof(stop)}(start, stop) + function OptionallyStaticUnitRange(start, stop) + if eltype(start) <: Int + if eltype(stop) <: Int + return new{typeof(start),typeof(stop)}(start, stop) else - return OptionallyStaticUnitRange{T}(start, _convert(T, stop)) + return OptionallyStaticUnitRange(start, Int(stop)) end else - return OptionallyStaticUnitRange{T}(_convert(T, start), stop) + return OptionallyStaticUnitRange(Int(start), stop) end end - function OptionallyStaticUnitRange(start, stop) - T = promote_type(typeof(_get(start)), typeof(_get(stop))) - return OptionallyStaticUnitRange{T}(start, stop) - end - - function OptionallyStaticUnitRange(x::AbstractRange) + function OptionallyStaticUnitRange(x::AbstractRange) if step(x) == 1 fst = static_first(x) lst = static_last(x) @@ -94,12 +83,12 @@ Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static( Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U)) Base.first(r::OptionallyStaticUnitRange) = r.start -Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T) +Base.step(::OptionallyStaticUnitRange) = Static(1) Base.last(r::OptionallyStaticUnitRange) = r.stop -known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F -known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T) -known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L +known_first(::Type{<:OptionallyStaticUnitRange{Static{F}}}) where {F} = F +known_step(::Type{<:OptionallyStaticUnitRange}) = 1 +known_last(::Type{<:OptionallyStaticUnitRange{<:Any,Static{L}}}) where {L} = L function Base.isempty(r::OptionallyStaticUnitRange) if known_first(r) === oneunit(eltype(r)) @@ -112,10 +101,8 @@ end unsafe_isempty_one_to(lst) = lst <= zero(lst) unsafe_isempty_unit_range(fst, lst) = fst > lst -unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T)) - -unsafe_length_one_to(lst::T) where {T<:Int} = T(lst) -unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst)) +unsafe_length_one_to(lst::Int) = lst +unsafe_length_one_to(::Static{L}) where {L} = lst Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer) if known_first(r) === oneunit(r) @@ -144,15 +131,15 @@ end @inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()" function _try_static(::Static{N}, x) where {N} @assert N == x "Unequal Indices: Static{$N}() != x == $x" - Static{N}() + return Static{N}() end function _try_static(x, ::Static{N}) where {N} @assert N == x "Unequal Indices: x == $x != Static{$N}()" - Static{N}() + return Static{N}() end function _try_static(x, y) @assert x == y "Unequal Indicess: x == $x != $y == y" - x + return x end ### @@ -172,11 +159,11 @@ end end end -function Base.length(r::OptionallyStaticUnitRange{T}) where {T} +function Base.length(r::OptionallyStaticUnitRange) if isempty(r) - return zero(T) + return 0 else - if known_one(r) === one(T) + if known_first(r) === 0 return unsafe_length_one_to(last(r)) else return unsafe_length_unit_range(first(r), last(r)) @@ -184,12 +171,7 @@ function Base.length(r::OptionallyStaticUnitRange{T}) where {T} end end -function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}} - return Base.checked_add(Base.checked_sub(lst, fst), one(T)) -end -function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}} - return Base.checked_add(lst - fst, one(T)) -end +unsafe_length_unit_range(start::Integer, stop::Integer) = Int(start - stop + 1) """ indices(x[, d]) @@ -231,4 +213,3 @@ end lst = _try_static(static_last(x), static_last(y)) return Base.Slice(OptionallyStaticUnitRange(fst, lst)) end - diff --git a/src/static.jl b/src/static.jl index 6a90cecd5..013dbcdb9 100644 --- a/src/static.jl +++ b/src/static.jl @@ -6,6 +6,10 @@ Use `Static(N)` instead of `Val(N)` when you want it to behave like a number. struct Static{N} <: Integer Static{N}() where {N} = new{N::Int}() end + +const Zero = Static{0} +const One = Static{1} + Base.@pure Static(N::Int) = Static{N}() Static(N::Integer) = Static(convert(Int, N)) Static(::Static{N}) where {N} = Static{N}() @@ -33,41 +37,44 @@ end Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N -Base.iszero(::Static{0}) = true +Base.eltype(::Type{T}) where {T<:Static} = Int +Base.iszero(::Zero) = true Base.iszero(::Static) = false -Base.isone(::Static{1}) = true +Base.isone(::One) = true Base.isone(::Static) = false +Base.zero(::Type{T}) where {T<:Static} = Zero() +Base.one(::Type{T}) where {T<:Static} = One() for T = [:Real, :Rational, :Integer] @eval begin - @inline Base.:(+)(i::$T, ::Static{0}) = i + @inline Base.:(+)(i::$T, ::Zero) = i @inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M - @inline Base.:(+)(::Static{0}, i::$T) = i + @inline Base.:(+)(::Zero, i::$T) = i @inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i - @inline Base.:(-)(i::$T, ::Static{0}) = i + @inline Base.:(-)(i::$T, ::Zero) = i @inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M - @inline Base.:(*)(i::$T, ::Static{0}) = Static{0}() - @inline Base.:(*)(i::$T, ::Static{1}) = i + @inline Base.:(*)(i::$T, ::Zero) = Zero() + @inline Base.:(*)(i::$T, ::One) = i @inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M - @inline Base.:(*)(::Static{0}, i::$T) = Static{0}() - @inline Base.:(*)(::Static{1}, i::$T) = i + @inline Base.:(*)(::Zero, i::$T) = Zero() + @inline Base.:(*)(::One, i::$T) = i @inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i end end -@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}() -@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}() -@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}() +@inline Base.:(+)(::Zero, ::Zero) = Zero() +@inline Base.:(+)(::Zero, ::Static{M}) where {M} = Static{M}() +@inline Base.:(+)(::Static{M}, ::Zero) where {M} = Static{M}() -@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}() +@inline Base.:(-)(::Static{M}, ::Zero) where {M} = Static{M}() -@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}() -@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}() -@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}() -@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}() -@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}() -@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}() -@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}() -@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}() +@inline Base.:(*)(::Zero, ::Zero) = Zero() +@inline Base.:(*)(::One, ::Zero) = Zero() +@inline Base.:(*)(::Zero, ::One) = Zero() +@inline Base.:(*)(::One, ::One) = One() +@inline Base.:(*)(::Static{M}, ::Zero) where {M} = Zero() +@inline Base.:(*)(::Zero, ::Static{M}) where {M} = Zero() +@inline Base.:(*)(::Static{M}, ::One) where {M} = Static{M}() +@inline Base.:(*)(::One, ::Static{M}) where {M} = Static{M}() for f ∈ [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)] @eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N))) end diff --git a/test/runtests.jl b/test/runtests.jl index 7e3e839f0..d61ba30d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -252,6 +252,9 @@ end @testset "Static" begin @test iszero(Static(0)) @test !iszero(Static(1)) + @test @inferred(one(Static)) === Static(1) + @test @inferred(zero(Static)) === Static(0) + @test eltype(one(Static)) <: Int # test for ambiguities and correctness for i ∈ [Static(0), Static(1), Static(2), 3] for j ∈ [Static(0), Static(1), Static(2), 3] @@ -271,3 +274,22 @@ end end end +@testset "insert/deleteat" begin + @test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3] + @test @inferred(ArrayInterface.deleteat([1, 2, 3], 2)) == [1, 3] + + @test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 2])) == [3] + @test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 3])) == [2] + @test @inferred(ArrayInterface.deleteat([1, 2, 3], [2, 3])) == [1] + + + @test @inferred(ArrayInterface.insert((1,2,3), 1, -2)) == (-2, 1, 2, 3) + @test @inferred(ArrayInterface.insert((1,2,3), 2, -2)) == (1, -2, 2, 3) + @test @inferred(ArrayInterface.insert((1,2,3), 3, -2)) == (1, 2, -2, 3) + + @test @inferred(ArrayInterface.deleteat((1, 2, 3), 1)) == (2, 3) + @test @inferred(ArrayInterface.deleteat((1, 2, 3), 2)) == (1, 3) + @test @inferred(ArrayInterface.deleteat((1, 2, 3), 3)) == (1, 2) + @test ArrayInterface.deleteat((1, 2, 3), [1, 2]) == (3,) +end +