Skip to content

Commit

Permalink
Merge pull request #69 from SciML/static
Browse files Browse the repository at this point in the history
Added Static integers.
  • Loading branch information
chriselrod authored Sep 9, 2020
2 parents c83328b + 581b18a commit 85e93de
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 38 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Requires = "0.5, 1.0"
julia = "1.2"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Expand All @@ -21,4 +22,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random"]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"]
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns
If `length` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

## Static(N::Int)

Creates a static integer with value known at compile time. It is a number,
supporting basic arithmetic. Many operations with two `Static` integers
will produce another `Static` integer. If one of the arguments to a
function call isn't static (e.g., `Static(4) + 3`) then the `Static`
number will promote to a dynamic value.

# List of things to add

- https://github.com/JuliaLang/julia/issues/22216
Expand Down
1 change: 1 addition & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ function __init__()
end
end

include("static.jl")
include("ranges.jl")

end
61 changes: 32 additions & 29 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
# add methods to support ArrayInterface

_get(x) = x
_get(::Val{V}) where {V} = V
_get(::Static{V}) where {V} = V
_get(::Type{Static{V}}) where {V} = V
_convert(::Type{T}, x) where {T} = convert(T, x)
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))

Expand All @@ -56,7 +57,7 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
from other valid indices. Therefore, users should not expect the same checks are used
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
"""
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
start::F
stop::L

Expand All @@ -79,28 +80,26 @@ struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}

function OptionallyStaticUnitRange(x::AbstractRange)
if step(x) == 1
fst = known_first(x)
fst = fst === nothing ? first(x) : Val(fst)
lst = known_last(x)
lst = lst === nothing ? last(x) : Val(lst)
fst = static_first(x)
lst = static_last(x)
return OptionallyStaticUnitRange(fst, lst)
else
throw(ArgumentError("step must be 1, got $(step(r))"))
end
end
end

Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start
Base.:(:)(L::Integer, ::Static{U}) where {U} = OptionallyStaticUnitRange(L, Static(U))
Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(L), U)
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))

Base.first(r::OptionallyStaticUnitRange) = r.start
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
Base.last(r::OptionallyStaticUnitRange) = r.stop

Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop

known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L

function Base.isempty(r::OptionallyStaticUnitRange)
if known_first(r) === oneunit(eltype(r))
Expand Down Expand Up @@ -141,10 +140,20 @@ end
return convert(eltype(r), val)
end

_try_static(x, y) = Val(x)
_try_static(::Nothing, y) = Val(y)
_try_static(x, ::Nothing) = Val(x)
_try_static(::Nothing, ::Nothing) = nothing
@inline _try_static(::Static{N}, ::Static{N}) where {N} = Static{N}()
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
function _try_static(::Static{N}, x) where {N}
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
Static{N}()
end
function _try_static(x, ::Static{N}) where {N}
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
Static{N}()
end
function _try_static(x, y)
@assert x == y "Unequal Indicess: x == $x != $y == y"
x
end

###
### length
Expand Down Expand Up @@ -193,7 +202,7 @@ specified then indices for visiting each index of `x` is returned.
"""
@inline function indices(x)
inds = eachindex(x)
if inds isa AbstractUnitRange{<:Integer}
if inds isa AbstractUnitRange && eltype(inds) <: Integer
return Base.Slice(OptionallyStaticUnitRange(inds))
else
return inds
Expand All @@ -202,30 +211,24 @@ end

function indices(x::Tuple)
inds = map(eachindex, x)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

indices(x, d) = indices(axes(x, d))
@inline indices(x, d) = indices(axes(x, d))

@inline function indices(x::NTuple{N,<:Any}, dim) where {N}
@inline function indices(x::Tuple{Vararg{Any,N}}, dim) where {N}
inds = map(x_i -> indices(x_i, dim), x)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N}
@inline function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N}
inds = map(indices, x, dim)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

@inline function _pick_range(x, y)
fst = _try_static(known_first(x), known_first(y))
fst = fst === nothing ? first(x) : fst

lst = _try_static(known_last(x), known_last(y))
lst = lst === nothing ? last(x) : lst
fst = _try_static(static_first(x), static_first(y))
lst = _try_static(static_last(x), static_last(y))
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
end

90 changes: 90 additions & 0 deletions src/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@

"""
A statically sized `Int`.
Use `Static(N)` instead of `Val(N)` when you want it to behave like a number.
"""
struct Static{N} <: Integer
Static{N}() where {N} = new{N::Int}()
end
Base.@pure Static(N::Int) = Static{N}()
Static(N::Integer) = Static(convert(Int, N))
Static(::Static{N}) where {N} = Static{N}()
Static(::Val{N}) where {N} = Static{N}()
Base.Val(::Static{N}) where {N} = Val{N}()
Base.convert(::Type{T}, ::Static{N}) where {T<:Number,N} = convert(T, N)
Base.convert(::Type{Static{N}}, ::Static{N}) where {N} = Static{N}()

Base.promote_rule(::Type{<:Static}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T)
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T <: AbstractIrrational} = promote_rule(T, Int)
for (S,T) [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)]
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:Static}) where {T <: $T} = promote_rule($S{T}, Int)
end
Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:Static}) = Union{Nothing, Missing, Int}
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int)
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Nothing} = promote_rule(T, Int)
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Missing} = promote_rule(T, Int)
for T [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any]
# let S = :Any
@eval begin
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: Static} = promote_rule(Int, $T)
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: Static} = promote_rule($T, Int)
end
end
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N

Base.iszero(::Static{0}) = true
Base.iszero(::Static) = false
Base.isone(::Static{1}) = true
Base.isone(::Static) = false

for T = [:Real, :Rational, :Integer]
@eval begin
@inline Base.:(+)(i::$T, ::Static{0}) = i
@inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M
@inline Base.:(+)(::Static{0}, i::$T) = i
@inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i
@inline Base.:(-)(i::$T, ::Static{0}) = i
@inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M
@inline Base.:(*)(i::$T, ::Static{0}) = Static{0}()
@inline Base.:(*)(i::$T, ::Static{1}) = i
@inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M
@inline Base.:(*)(::Static{0}, i::$T) = Static{0}()
@inline Base.:(*)(::Static{1}, i::$T) = i
@inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i
end
end
@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}()
@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}()

@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}()

@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}()
@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}()
@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}()
@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}()
for f [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :()]
@eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N)))
end
for f [:(==), :(!=), :(<), :(), :(>), :()]
@eval begin
@inline Base.$f(::Static{M}, ::Static{N}) where {M,N} = $f(M, N)
@inline Base.$f(::Static{M}, x::Int) where {M} = $f(M, x)
@inline Base.$f(x::Int, ::Static{M}) where {M} = $f(x, M)
end
end

@inline function maybe_static(f::F, g::G, x) where {F, G}
L = f(x)
isnothing(L) ? g(x) : Static(L)
end
@inline static_length(x) = maybe_static(known_length, length, x)
@inline static_first(x) = maybe_static(known_first, first, x)
@inline static_last(x) = maybe_static(known_last, last, x)
@inline static_step(x) = maybe_static(known_step, step, x)

58 changes: 50 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using ArrayInterface, Test
using Base: setindex
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, Static
@test ArrayInterface.ismutable(rand(3))

using Aqua
Aqua.test_all(ArrayInterface)

using StaticArrays
x = @SVector [1,2,3]
@test ArrayInterface.ismutable(x) == false
Expand Down Expand Up @@ -220,12 +223,51 @@ end
end

@testset "indices" begin
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1)
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2))
A23 = ones(2,3); SA23 = @SMatrix ones(2,3);
A32 = ones(3,2); SA32 = @SMatrix ones(3,2);
@test @inferred(ArrayInterface.indices((A23, A32))) == 1:6
@test @inferred(ArrayInterface.indices((SA23, A32))) == 1:6
@test @inferred(ArrayInterface.indices((A23, SA32))) == 1:6
@test @inferred(ArrayInterface.indices((SA23, SA32))) == 1:6
@test @inferred(ArrayInterface.indices(A23)) == 1:6
@test @inferred(ArrayInterface.indices(SA23)) == 1:6
@test @inferred(ArrayInterface.indices(A23, 1)) == 1:2
@test @inferred(ArrayInterface.indices(SA23, Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, A32), (1, 2))) == 1:2
@test @inferred(ArrayInterface.indices((SA23, A32), (Static(1), 2))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, SA32), (1, Static(2)))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((SA23, SA32), (Static(1), Static(2)))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, A23), 1)) == 1:2
@test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((SA23, A23), Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
@test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), 1)
@test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), (1, 2))
@test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), Static(1))
@test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (Static(1), 2))
@test_throws AssertionError ArrayInterface.indices((SA23, SA23), (Static(1), Static(2)))
end

@testset "Static" begin
@test iszero(Static(0))
@test !iszero(Static(1))
# test for ambiguities and correctness
for i [Static(0), Static(1), Static(2), 3]
for j [Static(0), Static(1), Static(2), 3]
i === j === 3 && continue
for f [+, -, *, ÷, %, <<, >>, >>>, &, |, , ==, , ]
(iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error
@test convert(Int, @inferred(f(i,j))) == f(convert(Int, i), convert(Int, j))
end
end
i == 3 && break
for f [+, -, *, /, ÷, %, ==, , ]
x = f(convert(Int, i), 1.4)
y = f(1.4, convert(Int, i))
@test convert(typeof(x), @inferred(f(i, 1.4))) === x
@test convert(typeof(y), @inferred(f(1.4, i))) === y # if f is division and i === Static(0), returns `NaN`; hence use of ==== in check.
end
end
end

0 comments on commit 85e93de

Please sign in to comment.