Skip to content

Commit 103e9d4

Browse files
authored
Merge pull request #695 from JuliaArrays/cjf/lu-api-updates
Some trait + LU api cleanup
2 parents 85812fe + 0fd60a5 commit 103e9d4

File tree

7 files changed

+43
-12
lines changed

7 files changed

+43
-12
lines changed

src/StaticArrays.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,13 @@ const StaticMatrixLike{T} = Union{
8686
Adjoint{T, <:StaticVecOrMat{T}},
8787
Symmetric{T, <:StaticMatrix{<:Any, <:Any, T}},
8888
Hermitian{T, <:StaticMatrix{<:Any, <:Any, T}},
89-
Diagonal{T, <:StaticVector{<:Any, T}}
89+
Diagonal{T, <:StaticVector{<:Any, T}},
90+
# We specifically list *Triangular here rather than using
91+
# AbstractTriangular to avoid ambiguities in size() etc.
92+
UpperTriangular{T, <:StaticMatrix{<:Any, <:Any, T}},
93+
LowerTriangular{T, <:StaticMatrix{<:Any, <:Any, T}},
94+
UnitUpperTriangular{T, <:StaticMatrix{<:Any, <:Any, T}},
95+
UnitLowerTriangular{T, <:StaticMatrix{<:Any, <:Any, T}}
9096
}
9197
const StaticVecOrMatLike{T} = Union{StaticVector{<:Any, T}, StaticMatrixLike{T}}
9298
const StaticArrayLike{T} = Union{StaticVecOrMatLike{T}, StaticArray{<:Tuple, T}}

src/abstractarray.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
length(a::SA) where {SA <: StaticArrayLike} = length(SA)
1+
length(a::StaticArrayLike) = prod(Size(a))
22
length(a::Type{SA}) where {SA <: StaticArrayLike} = prod(Size(SA))
33

4-
@pure size(::Type{SA}) where {SA <: StaticArrayLike} = get(Size(SA))
4+
@pure size(::Type{SA}) where {SA <: StaticArrayLike} = Tuple(Size(SA))
55
@inline function size(t::Type{<:StaticArrayLike}, d::Int)
66
S = size(t)
77
d > length(S) ? 1 : S[d]
88
end
9-
@inline size(a::StaticArrayLike) = size(typeof(a))
10-
@inline size(a::StaticArrayLike, d::Int) = size(typeof(a), d)
9+
@inline size(a::StaticArrayLike) = Tuple(Size(a))
1110

1211
Base.axes(s::StaticArray) = _axes(Size(s))
1312
@pure function _axes(::Size{sizes}) where {sizes}

src/lu.jl

+18-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,24 @@ Base.iterate(S::LU, ::Val{:U}) = (S.U, Val(:p))
1111
Base.iterate(S::LU, ::Val{:p}) = (S.p, Val(:done))
1212
Base.iterate(S::LU, ::Val{:done}) = nothing
1313

14+
@inline function Base.getproperty(F::LU, s::Symbol)
15+
if s === :P
16+
U = getfield(F, :U)
17+
p = getfield(F, :p)
18+
one(similar_type(p, Size(U)))[:,invperm(p)]
19+
else
20+
getfield(F, s)
21+
end
22+
end
23+
24+
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU)
25+
println(io, LU) # Don't show full type - this will be in the factors
26+
println(io, "L factor:")
27+
show(io, mime, F.L)
28+
println(io, "\nU factor:")
29+
show(io, mime, F.U)
30+
end
31+
1432
# LU decomposition
1533
function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true))
1634
L, U, p = _lu(A, pivot)
@@ -136,9 +154,6 @@ end
136154
:(SVector{$(M-1),Int}($(tuple(2:M...))))
137155
end
138156

139-
# Base.lufact() interface is fairly inherently type unstable. Punt on
140-
# implementing that, for now...
141-
142157
\(F::LU, v::AbstractVector) = F.U \ (F.L \ v[F.p])
143158
\(F::LU, B::AbstractMatrix) = F.U \ (F.L \ B[F.p,:])
144159

src/traits.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ Size(::Type{Transpose{T, A}}) where {T, A <: AbstractVecOrMat{T}} = Size(Size(A)
9393
Size(::Type{Symmetric{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A)
9494
Size(::Type{Hermitian{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A)
9595
Size(::Type{Diagonal{T, A}}) where {T, A <: AbstractVector{T}} = Size(Size(A)[1], Size(A)[1])
96+
Size(::Type{<:LinearAlgebra.AbstractTriangular{T, A}}) where {T,A} = Size(A)
9697

9798
@pure Size(::Type{<:AbstractArray{<:Any, N}}) where {N} = Size(ntuple(_ -> Dynamic(), N))
9899

@@ -117,7 +118,7 @@ Length(::Size{S}) where {S} = _Length(S...)
117118
@inline _Length(S...) = Length{Dynamic()}()
118119

119120
# Some @pure convenience functions for `Size`
120-
@pure get(::Size{S}) where {S} = S
121+
@pure (::Type{Tuple})(::Size{S}) where {S} = S
121122

122123
@pure getindex(::Size{S}, i::Int) where {S} = i <= length(S) ? S[i] : 1
123124

@@ -138,7 +139,7 @@ Base.LinearIndices(::Size{S}) where {S} = LinearIndices(S)
138139
@pure size_tuple(::Size{S}) where {S} = Tuple{S...}
139140

140141
# Some @pure convenience functions for `Length`
141-
@pure get(::Length{L}) where {L} = L
142+
@pure (::Type{Int})(::Length{L}) where {L} = L
142143

143144
@pure Base.:(==)(::Length{L}, l::Int) where {L} = L == l
144145
@pure Base.:(==)(l::Int, ::Length{L}) where {L} = l == L

src/triangular.jl

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
@inline Size(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = Size(A.data)
2-
31
@inline transpose(A::LinearAlgebra.LowerTriangular{<:Any,<:StaticMatrix}) =
42
LinearAlgebra.UpperTriangular(transpose(A.data))
53
@inline adjoint(A::LinearAlgebra.LowerTriangular{<:Any,<:StaticMatrix}) =

test/core.jl

+3
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@
155155
@test Size(Adjoint(zero(SMatrix{2, 3}))) == Size(3, 2)
156156
@test Size(Diagonal(SVector(1, 2, 3))) == Size(3, 3)
157157
@test Size(Transpose(Diagonal(SVector(1, 2, 3)))) == Size(3, 3)
158+
@test Size(UpperTriangular(zero(SMatrix{2, 2}))) == Size(2, 2)
159+
@test Size(LowerTriangular(zero(SMatrix{2, 2}))) == Size(2, 2)
160+
@test Size(LowerTriangular(Symmetric(zero(SMatrix{2, 2})))) == Size(2,2)
158161
end
159162

160163
@testset "dimmatch" begin

test/lu.jl

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
using StaticArrays, Test, LinearAlgebra
22

3+
@testset "LU utils" begin
4+
F = lu(SA[1 2; 3 4])
5+
6+
@test @inferred((F->F.p)(F)) === SA[2, 1]
7+
@test @inferred((F->F.P)(F)) === SA[0 1; 1 0]
8+
9+
@test occursin(r"^StaticArrays.LU.*L factor.*U factor"s, sprint(show, MIME("text/plain"), F))
10+
end
11+
312
@testset "LU decomposition ($m×$n, pivot=$pivot)" for pivot in (true, false), m in [0:4..., 15], n in [0:4..., 15]
413
a = SMatrix{m,n,Int}(1:(m*n))
514
l, u, p = @inferred(lu(a, Val{pivot}()))

0 commit comments

Comments
 (0)