Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evalpoly for matrix polynomials #1163

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ include("schur.jl")
include("structuredbroadcast.jl")
include("deprecated.jl")

include("evalpoly.jl")

const ⋅ = dot
const × = cross
export ⋅, ×
Expand Down
85 changes: 85 additions & 0 deletions src/evalpoly.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# matrix methods for evalpoly(X, p) = ∑ₖ Xᵏ⁻¹ p[k]

# non-inplace fallback for evalpoly(X, p)
function _evalpoly(X::AbstractMatrix, p)
Base.require_one_based_indexing(p)
p0 = isempty(p) ? Base.reduce_empty_iter(+, p) : p[end]
Xone = one(X)
S = Base.promote_op(*, typeof(Xone), typeof(Xone))(Xone) * p0
stevengj marked this conversation as resolved.
Show resolved Hide resolved
for i = length(p)-1:-1:1
S = X * S + @inbounds(p[i] isa AbstractMatrix ? p[i] : p[i] * I)
end
return S
end

_scalarval(x::Number) = x
_scalarval(x::UniformScaling) = x.λ

"""
evalpoly!(Y::AbstractMatrix, X::AbstractMatrix, p)

Evaluate the matrix polynomial ``Y = \\sum_k X^{k-1} p[k]``, storing the result
in-place in `Y`, for the coefficients `p[k]` (a vector or tuple). The coefficients
can be scalars, matrices, or [`UniformScaling`](@ref).

Similar to `evalpoly`, but may be more efficient by working more in-place. (Some
allocations may still be required, however.)
"""
function evalpoly!(Y::AbstractMatrix, X::AbstractMatrix, p::Union{AbstractVector,Tuple})
@boundscheck axes(Y,1) == axes(Y,2) == axes(X,1) == axes(X,2)
Base.require_one_based_indexing(p)

N = length(p)
pN = iszero(N) ? Base.reduce_empty_iter(+, p) : p[N]
if pN isa AbstractMatrix
Y .= pN
elseif N > 1 && p[N-1] isa Union{Number,UniformScaling}
# initialize Y to p[N-1] I + X p[N], in-place
Y .= X .* _scalarval(pN)
for i in axes(Y,1)
@inbounds Y[i,i] += p[N-1] * I
end
N -= 1
else
# initialize Y to one(Y) * pN in-place
for i in axes(Y,1)
for j in axes(Y,2)
@inbounds Y[i,j] = zero(Y[i,j])
end
@inbounds Y[i,i] += one(Y[i,i]) * pN
end

Check warning on line 50 in src/evalpoly.jl

View check run for this annotation

Codecov / codecov/patch

src/evalpoly.jl#L45-L50

Added lines #L45 - L50 were not covered by tests
end
if N > 1
Z = similar(Y) # workspace for mul!
for i = N-1:-1:1
mul!(Z, X, Y)
if p[i] isa AbstractMatrix
Y .= p[i] .+ Z
else
# Y = p[i] * I + Z, in-place
Y .= Z
for j in axes(Y,1)
@inbounds Y[j,j] += p[i] * I
end
end
end
end
return Y
end

# fallback cases: call out-of-place _evalpoly
Base.evalpoly(X::AbstractMatrix, p::Tuple) = _evalpoly(X, p)
Base.evalpoly(X::AbstractMatrix, ::Tuple{}) = zero(one(X)) # dimensionless zero, i.e. 0 * X^0
Base.evalpoly(X::AbstractMatrix, p::AbstractVector) = _evalpoly(X, p)

# optimized in-place cases, limited to types like homogeneous tuples with length > 1
# where we can reliably deduce the output type (= type of X * p[2]),
# and restricted to StridedMatrix (for now) so that we can be more confident that this is a performance win:
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{T, T, Vararg{T}}) where {T<:Union{Number, UniformScaling}} =
evalpoly!(similar(X, Base.promote_op(*, eltype(X), typeof(_scalarval(p[2])))), X, p)
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{AbstractMatrix{T}, AbstractMatrix{T}, Vararg{AbstractMatrix{T}}}) where {T<:Number} =
evalpoly!(similar(X, Base.promote_op(*, eltype(X), T)), X, p)
Base.evalpoly(X::StridedMatrix{<:Number}, p::AbstractVector{<:Union{Number, UniformScaling}}) =
length(p) < 2 ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, eltype(X), typeof(_scalarval(p[begin+1])))), X, p)
Base.evalpoly(X::StridedMatrix{<:Number}, p::AbstractVector{<:AbstractMatrix{<:Number}}) =
length(p) < 2 ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, eltype(X), eltype(p[begin+1]))), X, p)
33 changes: 33 additions & 0 deletions test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -837,4 +837,37 @@ end
end
end

using LinearAlgebra: _evalpoly # fallback routine, which we'll test explicitly

# naive sum, a little complicated since X^0 fails if eltype(X) is abstract:
naive_evalpoly(X, p) = length(p) == 1 ? one(X) * p[1] : one(X) * p[1] + sum(X^(i-1) * p[i] for i=2:length(p))

@testset "evalpoly" begin
for X in ([1 2 3;4 5 6;7 8 9], UpperTriangular([1 2 3;0 5 6;0 0 9]),
SymTridiagonal([1,2,3],[4,5]), Real[1 2 3;4 5 6;7 8 9])
@test @inferred(evalpoly(X, ())) == zero(X) == evalpoly(X, Int[])
@test @inferred(evalpoly(X, (17,))) == one(X) * 17
@test _evalpoly(X, [1,2,3,4]) == evalpoly(X, [1,2,3,4]) == @inferred(evalpoly(X, (1,2,3,4))) ==
naive_evalpoly(X, [1,2,3,4]) == 1*one(X) + 2*X + 3X^2 + 4X^3
@test typeof(evalpoly(X, [1,2,3])) == typeof(evalpoly(X, (1,2,3))) == typeof(_evalpoly(X, [1,2,3])) ==
typeof(X * X)

# _evalpoly is not type-stable if eltype(X) is abstract
# because one(Real[...]) returns a Matrix{Int}
if isconcretetype(eltype(X))
@inferred evalpoly(X, [1,2,3,4])
@inferred _evalpoly(X, [1,2,3,4])
end

for N in (1,2,4), p in (Real[1,2], rand(-10:10, N), UniformScaling.(rand(-10:10, N)), [rand(-5:5,3,3) for _ = 1:N])
@test _evalpoly(X, p) == evalpoly(X, p) == evalpoly(X, Tuple(p)) == naive_evalpoly(X, p)
end
for N in (1,2,4), p in ((5, 6.7), rand(N), UniformScaling.(rand(N)), [rand(3,3) for _ = 1:N])
@test _evalpoly(X, p) ≈ evalpoly(X, p) ≈ evalpoly(X, Tuple(p)) ≈ naive_evalpoly(X, p)
end

@test_throws MethodError evalpoly(X, [])
end
end

end # module TestGeneric
Loading