diff --git a/docs/src/interface.md b/docs/src/interface.md index c7e757b7..f433804c 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -55,3 +55,26 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth affine as well, and all sorts of things. Thus affine operators have no matrix representation but they are still compatible with essentially any Krylov method which would otherwise be compatible with matrix-free representations, hence their support in the SciMLOperators interface. + +## Note about keyword arguments to `update_coefficients!` + +In rare cases, an operator may be used in a context where additional state is expected to be provided +to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional +state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator. +For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument (by default, none are passed). + +In the below example, we create an operator that gleefully ignores `u`, `p`, and `t` and uses its own special scaling. +```@example +using SciMLOperators + +γ = ScalarOperator(0.0; update_func=(a, u, p, t; my_special_scaling) -> my_special_scaling, + accepted_kwargs=(:my_special_scaling,)) + +# Update coefficients, then apply operator +update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling=7.0) +@show γ * [2.0] + +# Use operator application form +@show γ([2.0], nothing, nothing; my_special_scaling = 5.0) +nothing # hide +``` \ No newline at end of file diff --git a/src/batch.jl b/src/batch.jl index aea688cd..ec752bcf 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -1,12 +1,12 @@ # """ - BatchedDiagonalOperator(diag, [; update_func]) + BatchedDiagonalOperator(diag; update_func, update_func!, accepted_kwargs) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. Acts on `AbstractArray`s of the same size as `diag`. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(diag::AbstractVector,u,p,t) -> [modifies diag] + update_func(diag::AbstractArray, u, p, t; ) -> [modifies diag] """ struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T} diag::D @@ -14,6 +14,7 @@ struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T} update_func!::F! function BatchedDiagonalOperator(diag::AbstractArray, update_func, update_func!) + new{ eltype(diag), typeof(diag), @@ -25,15 +26,16 @@ struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T} end end -function BatchedDiagonalOperator(diag::AbstractArray; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC) - BatchedDiagonalOperator(diag, update_func, update_func!) -end +function DiagonalOperator(u::AbstractArray; + update_func = DEFAULT_UPDATE_FUNC, + update_func! = DEFAULT_UPDATE_FUNC, + accepted_kwargs = nothing + ) + + update_func = preprocess_update_func(update_func , accepted_kwargs) + update_func! = preprocess_update_func(update_func!, accepted_kwargs) -function DiagonalOperator(u::AbstractArray; update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC) - BatchedDiagonalOperator(u; update_func = update_func, update_func! = update_func!) + BatchedDiagonalOperator(u, update_func, update_func!) end # traits @@ -46,38 +48,39 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly update_func = if isreal(L) L.update_func else - (L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t)) + (L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...)) end BatchedDiagonalOperator(diag; update_func=update_func) end -function update_coefficients(L::BatchedDiagonalOperator,u,p,t) - @set! L.diag = L.update_func(L.diag,u,p,t) +LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true +function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator) + if isreal(L) + true + else + vec(L.diag) |> Diagonal |> ishermitian + end +end +LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag))) + +function update_coefficients(L::BatchedDiagonalOperator,u ,p, t; kwargs...) + @set! L.diag = L.update_func(L.diag, u, p, t; kwargs...) +end + +function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...) + L.update_func!(L.diag, u, p, t; kwargs...) end -update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func!(L.diag,u,p,t); L) getops(L::BatchedDiagonalOperator) = (L.diag,) function isconstant(L::BatchedDiagonalOperator) - L.update_func == L.update_func! == DEFAULT_UPDATE_FUNC + update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!) end islinear(::BatchedDiagonalOperator) = true has_adjoint(L::BatchedDiagonalOperator) = true has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag) has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L) -LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true -function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator) - if isreal(L) - true - else - d = vec(L.diag) - D = Diagonal(d) - ishermitian(d) - end -end -LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag))) - # operator application Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u Base.:\(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .\ u diff --git a/src/func.jl b/src/func.jl index 2e92b64f..58145224 100644 --- a/src/func.jl +++ b/src/func.jl @@ -84,6 +84,7 @@ function FunctionOperator(op, FunctionOperator(op, input, output; kwargs...) end +# TODO: document constructor and revisit design as needed (e.g. for "accepted_kwargs") function FunctionOperator(op, input::AbstractVecOrMat, output::AbstractVecOrMat = input; @@ -101,6 +102,7 @@ function FunctionOperator(op, p=nothing, t::Union{Number,Nothing}=nothing, + accepted_kwargs::NTuple{N,Symbol} = (), ifcache::Bool = true, @@ -111,7 +113,7 @@ function FunctionOperator(op, issymmetric::Bool = false, ishermitian::Bool = false, isposdef::Bool = false, - ) + ) where{N} # store eltype of input/output for caching with ComposedOperator. eltypes = eltype.((input, output)) @@ -181,6 +183,8 @@ function FunctionOperator(op, T = T, size = sz, eltypes = eltypes, + accepted_kwargs = accepted_kwargs, + kwargs = Dict{Symbol, Any}(), ) L = FunctionOperator( @@ -191,7 +195,7 @@ function FunctionOperator(op, traits, p, t, - cache, + cache ) if ifcache & isnothing(L.cache) @@ -201,36 +205,40 @@ function FunctionOperator(op, L end -function update_coefficients(L::FunctionOperator, u, p, t) - - if isconstant(L) - return L - end - - @set! L.op = update_coefficients(L.op, u, p, t) - @set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t) - @set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t) - @set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t) +function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) + # update p, t @set! L.p = p @set! L.t = t - L -end + # filter and update kwargs + filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) + @set! L.traits.kwargs = Dict{Symbol, Any}(filtered_kwargs) -function update_coefficients!(L::FunctionOperator, u, p, t) + isconstant(L) && return L - if isconstant(L) - return L - end + @set! L.op = update_coefficients(L.op, u, p, t; filtered_kwargs...) + @set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...) + @set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...) + @set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...) +end - for op in getops(L) - update_coefficients!(op, u, p, t) - end +function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) + # update p, t L.p = p L.t = t + # filter and update kwargs + filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) + L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs)) + + isconstant(L) && return + + for op in getops(L) + update_coefficients!(op, u, p, t; filtered_kwargs...) + end + L end @@ -267,9 +275,6 @@ function Base.adjoint(L::FunctionOperator) @set! traits.size = reverse(size(L)) @set! traits.eltypes = reverse(traits.eltypes) - p = L.p - t = L.t - cache = if iscached(L) cache = reverse(L.cache) else @@ -281,8 +286,8 @@ function Base.adjoint(L::FunctionOperator) op_inverse, op_adjoint_inverse, traits, - p, - t, + L.p, + L.t, cache, ) end @@ -310,9 +315,6 @@ function Base.inv(L::FunctionOperator) (p::Real) -> 1 / traits.opnorm(p) end - p = L.p - t = L.t - cache = if iscached(L) cache = reverse(L.cache) else @@ -324,8 +326,8 @@ function Base.inv(L::FunctionOperator) op_inverse, op_adjoint_inverse, traits, - p, - t, + L.p, + L.t, cache, ) end @@ -353,8 +355,8 @@ function LinearAlgebra.opnorm(L::FunctionOperator, p) argument. E.g., `(p::Real) -> p == Inf ? 100 : error("only Inf norm is defined")` """) - opn = L.opnorm - return opn isa Number ? opn : L.opnorm(p) + opn = L.traits.opnorm + return opn isa Number ? opn : L.traits.opnorm(p) end LinearAlgebra.issymmetric(L::FunctionOperator) = L.traits.issymmetric LinearAlgebra.ishermitian(L::FunctionOperator) = L.traits.ishermitian @@ -373,31 +375,36 @@ end islinear(L::FunctionOperator) = L.traits.islinear isconstant(L::FunctionOperator) = L.traits.isconstant has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing) -has_mul(L::FunctionOperator{iip}) where{iip} = true -has_mul!(L::FunctionOperator{iip}) where{iip} = iip +has_mul(::FunctionOperator{iip}) where{iip} = true +has_mul!(::FunctionOperator{iip}) where{iip} = iip has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing) has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing) # TODO - FunctionOperator, Base.conj, transpose # operator application -Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t) -Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t) +function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} + L.op(u, L.p, L.t; L.traits.kwargs...) +end + +function Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} + L.op_inverse(u, L.p, L.t; L.traits.kwargs...) +end function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat) _, co = L.cache du = zero(co) - L.op(du, u, L.p, L.t) + L.op(du, u, L.p, L.t; L.traits.kwargs...) end function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat) ci, _ = L.cache du = zero(ci) - L.op_inverse(du, u, L.p, L.t) + L.op_inverse(du, u, L.p, L.t; L.traits.kwargs...) end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat) - L.op(v, u, L.p, L.t) + L.op(v, u, L.p, L.t; L.traits.kwargs...) end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...) @@ -414,11 +421,11 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop} - L.op(v, u, L.p, L.t, α, β) + L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...) end function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat) - L.op_inverse(v, u, L.p, L.t) + L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...) end function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat) diff --git a/src/interface.jl b/src/interface.jl index 655bc33e..f050692d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -3,23 +3,56 @@ # Operator interface ### +""" +Function call and multiplication: + - L(du, u, p, t) for in-place operator evaluation, + - du = L(u, p, t) for out-of-place operator evaluation + +If the operator is not a constant, update it with (u,p,t). A mutating form, i.e. +update_coefficients!(A,u,p,t) that changes the internal coefficients, and a +out-of-place form B = update_coefficients(A,u,p,t). + +""" +function (::AbstractSciMLOperator) end + +### +# Utilities for update functions +### + DEFAULT_UPDATE_FUNC(A,u,p,t) = A -update_coefficients(L,u,p,t) = L -update_coefficients!(L,u,p,t) = L +struct NoKwargFilter end + +function preprocess_update_func(update_func, accepted_kwargs) + _update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func + _accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs + # accepted_kwargs can be passed as nothing to indicate that we should not filter + # (e.g. if the function already accepts all kwargs...). + return (_accepted_kwargs isa NoKwargFilter) ? _update_func : FilterKwargs(_update_func, _accepted_kwargs) +end +function update_func_isconstant(update_func) + if update_func isa FilterKwargs + return update_func.f == DEFAULT_UPDATE_FUNC + else + return update_func == DEFAULT_UPDATE_FUNC + end +end + +update_coefficients!(L,u,p,t; kwargs...) = nothing +update_coefficients(L,u,p,t; kwargs...) = L -function update_coefficients!(L::AbstractSciMLOperator, u, p, t) +function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...) for op in getops(L) - update_coefficients!(op, u, p, t) + update_coefficients!(op, u, p, t; kwargs...) end L end -(L::AbstractSciMLOperator)(u, p, t) = update_coefficients(L, u, p, t) * u -(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u)) -(L::AbstractSciMLOperator)(du, u, p, t, α, β) = (update_coefficients!(L, u, p, t); mul!(du, L, u, α, β)) +(L::AbstractSciMLOperator)(u, p, t; kwargs...) = update_coefficients(L, u, p, t; kwargs...) * u +(L::AbstractSciMLOperator)(du, u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u)) +(L::AbstractSciMLOperator)(du, u, p, t, α, β; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u, α, β)) -function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...) +function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs...) msg = """Nonallocating L(v, u, p, t) type methods are not available for subtypes of `Number`.""" throw(ArgumentError(msg)) diff --git a/src/matrix.jl b/src/matrix.jl index 2ef44763..15135686 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -1,12 +1,12 @@ # """ - MatrixOperator(A[; update_func]) + MatrixOperator(A; [update_func, update_func!, accepted_kwargs]) Represents a time-dependent linear operator given by an AbstractMatrix. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(A::AbstractMatrix,u,p,t) -> [modifies A] + update_func(A::AbstractMatrix,u,p,t; ) -> [modifies A] """ struct MatrixOperator{T,AT<:AbstractMatrix{T},F,F!} <: AbstractSciMLOperator{T} A::AT @@ -14,19 +14,26 @@ struct MatrixOperator{T,AT<:AbstractMatrix{T},F,F!} <: AbstractSciMLOperator{T} update_func!::F! function MatrixOperator(A, update_func, update_func!) + new{ eltype(A), typeof(A), typeof(update_func), typeof(update_func!), }( - A, update_func, update_func! + A, update_func, update_func!, ) end end -function MatrixOperator(A; update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC) +function MatrixOperator(A; + update_func = DEFAULT_UPDATE_FUNC, + update_func! = DEFAULT_UPDATE_FUNC, + accepted_kwargs = nothing,) + + update_func = preprocess_update_func(update_func , accepted_kwargs) + update_func! = preprocess_update_func(update_func!, accepted_kwargs) + MatrixOperator(A, update_func, update_func!) end @@ -48,42 +55,49 @@ end islinear(::MatrixOperator) = true Base.size(L::MatrixOperator) = size(L.A) +Base.iszero(L::MatrixOperator) = iszero(L.A) for op in ( :adjoint, :transpose, ) @eval function Base.$op(L::MatrixOperator) - if isconstant(L) - MatrixOperator($op(L.A)) - else - update_func = (A,u,p,t) -> $op(L.update_func($op(L.A), u, p, t)) - update_func! = (A,u,p,t) -> $op(L.update_func!($op(L.A), u, p, t)) - MatrixOperator($op(L.A); update_func = update_func, - update_func! = update_func!) - end + isconstant(L) && return MatrixOperator($op(L.A)) + + update_func = (A, u, p, t; kwargs...) -> $op(L.update_func( $op(L.A), u, p, t; kwargs...)) + update_func! = (A, u, p, t; kwargs...) -> $op(L.update_func!($op(L.A), u, p, t; kwargs...)) + + MatrixOperator($op(L.A); + update_func = update_func, + update_func! = update_func!, + accepted_kwargs = NoKwargFilter(), + ) end end function Base.conj(L::MatrixOperator) - update_func = (A, u, p, t) -> conj(L.update_func(conj(L.A), u, p, t)) - update_func! = (A, u, p, t) -> conj(L.update_func!(conj(L.A), u, p, t)) + isconstant(L) && return MatrixOperator(conj(L.A)) - MatrixOperator(conj(L.A); update_func = update_func, - update_func! = update_func!) + update_func = (A, u, p, t; kwargs...) -> conj(L.update_func(conj(L.A), u, p, t; kwargs...)) + update_func! = (A, u, p, t; kwargs...) -> conj(L.update_func!(conj(L.A), u, p, t; kwargs...)) + + MatrixOperator(conj(L.A); + update_func = update_func, + update_func! = update_func!, + accepted_kwargs = NoKwargFilter(), + ) end has_adjoint(A::MatrixOperator) = has_adjoint(A.A) +getops(L::MatrixOperator) = (L.A,) +isconstant(L::MatrixOperator) = update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!) -function update_coefficients(L::MatrixOperator, u, p, t) - @set! L.A = L.update_func(L.A, u, p, t) +function update_coefficients(L::MatrixOperator, u, p, t; kwargs...) + @set! L.A = L.update_func(L.A, u, p, t; kwargs...) end -update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func!(L.A, u, p, t); L) -getops(L::MatrixOperator) = (L.A,) -function isconstant(L::MatrixOperator) - L.update_func == L.update_func! == DEFAULT_UPDATE_FUNC +function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...) + L.update_func!(L.A, u, p, t; kwargs...) end -Base.iszero(L::MatrixOperator) = iszero(L.A) SparseArrays.sparse(L::MatrixOperator) = sparse(L.A) SparseArrays.issparse(L::MatrixOperator) = issparse(L.A) @@ -106,7 +120,7 @@ Base.copyto!(L::MatrixOperator, rhs::Base.Broadcast.Broadcasted{<:StaticArraysCo Base.Broadcast.broadcastable(L::MatrixOperator) = L Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType) ArrayInterface.issingular(L::MatrixOperator) = ArrayInterface.issingular(L.A) -Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func) +Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func, accepted_kwargs=NoKwargFilter()) # operator application Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u @@ -117,13 +131,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u) """ - DiagonalOperator(diag, [; update_func]) + DiagonalOperator(diag; [update_func, update_func!, accepted_kwargs]) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(diag::AbstractVector,u,p,t) -> [modifies diag] + update_func(diag::AbstractVector,u,p,t; ) -> [modifies diag] When `diag` is an `AbstractVector` of length N, `L=DiagonalOpeator(diag, ...)` can be applied to `AbstractArray`s with `size(u, 1) == N`. Each column of the `u` @@ -134,23 +148,23 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of `L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)` with leading length `size(u, 1) = N`. """ -function DiagonalOperator(diag::AbstractVector; update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC) - - diag_update_func = if update_func == DEFAULT_UPDATE_FUNC - DEFAULT_UPDATE_FUNC - else - (A, u, p, t) -> (d = update_func(A.diag, u, p, t); Diagonal(d)) - end - - diag_update_func! = if update_func! == DEFAULT_UPDATE_FUNC - DEFAULT_UPDATE_FUNC - else - (A, u, p, t) -> (update_func!(A.diag, u, p, t); A) - end - - MatrixOperator(Diagonal(diag); update_func = diag_update_func, - update_func! = diag_update_func!) +function DiagonalOperator(diag::AbstractVector; + update_func = DEFAULT_UPDATE_FUNC, + update_func! = DEFAULT_UPDATE_FUNC, + accepted_kwargs = nothing, + ) + + diag_update_func = update_func_isconstant(update_func) ? update_func : + (A, u, p, t; kwargs...) -> update_func(A.diag, u, p, t; kwargs...) |> Diagonal + + diag_update_func! = update_func_isconstant(update_func!) ? update_func! : + (A, u, p, t; kwargs...) -> update_func!(A.diag, u, p, t; kwargs...) + + MatrixOperator(Diagonal(diag); + update_func = diag_update_func, + update_func! = diag_update_func!, + accepted_kwargs = accepted_kwargs, + ) end LinearAlgebra.Diagonal(L::MatrixOperator) = MatrixOperator(Diagonal(L.A)) @@ -255,13 +269,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u) """ - L = AffineOperator(A, B, b[; update_func]) + L = AffineOperator(A, B, b; [update_func, update_func!, accepted_kwargs]) L(u) = A*u + B*b Represents a time-dependent affine operator. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(b::AbstractArray,u,p,t) -> [modifies b] + update_func(b::AbstractArray,u,p,t; ) -> [modifies b] """ struct AffineOperator{T,AT,BT,bT,C,F,F!} <: AbstractSciMLOperator{T} A::AT @@ -293,56 +307,85 @@ function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator}, b::AbstractArray; update_func = DEFAULT_UPDATE_FUNC, update_func! = DEFAULT_UPDATE_FUNC, + accepted_kwargs = nothing, ) + @assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors of same size" + update_func = preprocess_update_func(update_func , accepted_kwargs) + update_func! = preprocess_update_func(update_func!, accepted_kwargs) + A = A isa AbstractMatrix ? MatrixOperator(A) : A B = B isa AbstractMatrix ? MatrixOperator(B) : B + cache = B * b AffineOperator(A, B, b, cache, update_func, update_func!) end """ - L = AddVector(b[; update_func]) + L = AddVector(b; [update_func, update_func!, accepted_kwargs]) L(u) = u + b """ -function AddVector(b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC) +function AddVector(b::AbstractVecOrMat; + update_func = DEFAULT_UPDATE_FUNC, + update_func! = DEFAULT_UPDATE_FUNC, + accepted_kwargs = nothing + ) + N = size(b, 1) Id = IdentityOperator(N) - AffineOperator(Id, Id, b; update_func = update_func, update_func! = update_func!) + AffineOperator(Id, Id, b; + update_func = update_func, + update_func! = update_func!, + accepted_kwargs = accepted_kwargs, + ) end """ - L = AddVector(B, b[; update_func]) + L = AddVector(B, b; [update_func, accepted_kwargs]) L(u) = u + B*b """ -function AddVector(B, b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC) +function AddVector(B, b::AbstractVecOrMat; + update_func = DEFAULT_UPDATE_FUNC, + update_func! = DEFAULT_UPDATE_FUNC, + accepted_kwargs=nothing + ) + N = size(B, 1) Id = IdentityOperator(N) - AffineOperator(Id, B, b; update_func = update_func, update_func! = update_func!) + AffineOperator(Id, B, b; + update_func = update_func, + update_func! = update_func!, + accepted_kwargs = accepted_kwargs, + ) end -function update_coefficients(L::AffineOperator, u, p, t) - @set! L.A = update_coefficients(L.A, u, p, t) - @set! L.B = update_coefficients(L.B, u, p, t) - @set! L.b = L.update_func(L.b, u, p, t) - - L +function update_coefficients(L::AffineOperator, u, p, t; kwargs...) + @set! L.A = update_coefficients(L.A, u, p, t; kwargs...) + @set! L.B = update_coefficients(L.B, u, p, t; kwargs...) + @set! L.b = L.update_func(L.b, u, p, t; kwargs...) end -update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func!(L.b,u,p,t); L) +function update_coefficients!(L::AffineOperator, u, p, t; kwargs...) + L.update_func!(L.b, u, p, t; kwargs...) + for op in getops(L) + update_coefficients!(op, u, p, t; kwargs...) + end + nothing +end -getops(L::AffineOperator) = (L.A, L.B, L.b) function isconstant(L::AffineOperator) - (L.update_func == L.update_func! == DEFAULT_UPDATE_FUNC) & + update_func_isconstant(L.update_func) & + update_func_isconstant(L.update_func!) & all(isconstant, (L.A, L.B)) end + +getops(L::AffineOperator) = (L.A, L.B, L.b) + islinear(::AffineOperator) = false Base.size(L::AffineOperator) = size(L.A) @@ -356,7 +399,7 @@ function Base.resize!(L::AffineOperator, n::Integer) L end -has_adjoint(L::AffineOperator) = all(has_adjoint, L.ops) +has_adjoint(L::AffineOperator) = false has_mul(L::AffineOperator) = has_mul(L.A) has_mul!(L::AffineOperator) = has_mul!(L.A) has_ldiv(L::AffineOperator) = has_ldiv(L.A) diff --git a/src/scalar.jl b/src/scalar.jl index 45ec1ab3..41ed9325 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -3,8 +3,8 @@ # AbstractSciMLScalarOperator interface ### -function (L::AbstractSciMLScalarOperator)(u::Number, p, t) - L = update_coefficients(L, u, p, t) +function (L::AbstractSciMLScalarOperator)(u::Number, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) convert(Number, L) * u end @@ -96,7 +96,7 @@ end Base.:+(α::AbstractSciMLScalarOperator) = α """ - ScalarOperator(val[; update_func]) + ScalarOperator(val; [update_func, accepted_kwargs]) (α::ScalarOperator)(a::Number) = α * a @@ -104,14 +104,19 @@ Represents a time-dependent scalar/scaling operator. The update function is called by `update_coefficients`/ `update_coefficients!` and is assumed to have the following signature: - update_func(oldval,u,p,t) -> newval + update_func(oldval,u,p,t; ) -> newval """ mutable struct ScalarOperator{T<:Number,F} <: AbstractSciMLScalarOperator{T} val::T update_func::F end -function ScalarOperator(val::T; update_func=DEFAULT_UPDATE_FUNC) where{T} +function ScalarOperator(val; + update_func = DEFAULT_UPDATE_FUNC, + accepted_kwargs = nothing, + ) + + update_func = preprocess_update_func(update_func, accepted_kwargs) ScalarOperator(val, update_func) end @@ -125,8 +130,8 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ) # traits function Base.conj(α::ScalarOperator) # TODO - test val = conj(α.val) - update_func = (oldval,u,p,t) -> α.update_func(oldval |> conj,u,p,t) |> conj - ScalarOperator(val; update_func=update_func) + update_func = (oldval,u,p,t; kwargs...) -> α.update_func(oldval |> conj,u,p,t; kwargs...) |> conj + ScalarOperator(val; update_func=update_func, accepted_kwargs=NoKwargFilter()) end Base.one(::AbstractSciMLScalarOperator{T}) where{T} = ScalarOperator(one(T)) @@ -139,12 +144,17 @@ Base.abs(α::ScalarOperator) = abs(α.val) Base.iszero(α::ScalarOperator) = iszero(α.val) getops(α::ScalarOperator) = (α.val,) -isconstant(α::ScalarOperator) = α.update_func == DEFAULT_UPDATE_FUNC +isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func) has_ldiv(α::ScalarOperator) = !iszero(α.val) has_ldiv!(α::ScalarOperator) = has_ldiv(α) -update_coefficients(L::ScalarOperator, u, p, t) = @set! L.val = L.update_func(L.val, u, p, t) -update_coefficients!(L::ScalarOperator, u, p, t) = (L.val = L.update_func(L.val,u,p,t); L) +function update_coefficients!(L::ScalarOperator,u,p,t; kwargs...) + L.val = L.update_func(L.val, u, p, t; kwargs...) +end + +function update_coefficients(L::ScalarOperator, u, p, t; kwargs...) + @set! L.val = L.update_func(L.val, u, p, t; kwargs...) +end """ Lazy addition of Scalar Operators diff --git a/src/utils.jl b/src/utils.jl index 213de5ee..b08c18a5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,4 +12,22 @@ end dims(A) = length(size(A)) dims(::AbstractArray{<:Any,N}) where{N} = N dims(::AbstractSciMLOperator) = 2 + +# Keyword argument filtering +struct FilterKwargs{F,K} + f::F + accepted_kwargs::K +end + +# Filter keyword arguments to those accepted by function. +# Avoid throwing errors here if a keyword argument is not provided: defer +# this to the function call for a more readable error. +function get_filtered_kwargs(kwargs::AbstractDict, accepted_kwargs::NTuple{N,Symbol}) where{N} + (kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw)) +end + +function (f::FilterKwargs)(args...; kwargs...) + filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs) + f.f(args...; filtered_kwargs...) +end # diff --git a/test/func.jl b/test/func.jl index 89933347..fc56137b 100644 --- a/test/func.jl +++ b/test/func.jl @@ -104,19 +104,22 @@ end u = rand(N,K) p = rand(N) t = rand() + scale = rand() - f(u, p, t) = Diagonal(p * t) * u - f(du,u,p,t) = mul!(du, Diagonal(p*t), u) + # Accept a kwarg "scale" in operator action + f(du,u,p,t; scale = 1.0) = mul!(du, Diagonal(p*t*scale), u) + f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u - L = FunctionOperator(f, u, u; p=zero(p), t=zero(t)) + L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), + accepted_kwargs = (:scale,)) - ans = @. u * p * t - @test L(u,p,t) ≈ ans - v=copy(u); @test L(v,u,p,t) ≈ ans + ans = @. u * p * t * scale + @test L(u,p,t; scale) ≈ ans + v=copy(u); @test L(v,u,p,t; scale) ≈ ans # test that output isn't accidentally mutated by passing an internal cache. - A = Diagonal(p * t) + A = Diagonal(p * t * scale) u1 = rand(N, K) u2 = rand(N, K) diff --git a/test/scalar.jl b/test/scalar.jl index 977857c2..da3abaf3 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -148,9 +148,17 @@ end @test convert(Number, num) ≈ val - @test num(u, p, t) ≈ val * u - v=rand(N,K); @test num(v, u, p, t) ≈ val * u - v=rand(N,K); w=copy(v); @test num(v, u, p, t, a, b) ≈ a*val*u + b*w - + # Test scalar operator which expects keyword argument to update, + # modeled in the style of a DiffEq W-operator. + γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma, + accepted_kwargs = (:dtgamma,)) + + dtgamma = rand() + @test γ(u,p,t; dtgamma) ≈ dtgamma * u + @test γ(v,u,p,t; dtgamma) ≈ dtgamma * u + + γ_added = γ + α + @test γ_added(u,p,t; dtgamma) ≈ (dtgamma + p) * u + @test γ_added(v,u,p,t; dtgamma) ≈ (dtgamma + p) * u end # diff --git a/test/total.jl b/test/total.jl index 38289916..5a1e9004 100644 --- a/test/total.jl +++ b/test/total.jl @@ -79,20 +79,29 @@ end @testset "Operator Algebra" begin N2 = N*N + A = rand(N,N) - B = rand(N,N) + # Introduce update function for B + B = MatrixOperator(zeros(N,N); update_func! = (A, u, p, t) -> (A .= p)) C = rand(N,N) - D = rand(N,N) + # Introduce update function for D dependent on kwarg "matrix" + D = MatrixOperator(zeros(N,N); update_func! = (A, u, p, t; matrix) -> (A .= p*t*matrix), + accepted_kwargs = (:matrix,)) u = rand(N2,K) + p = rand() + t = rand() + matrix = rand(N, N) + diag = rand(N2) α = rand() β = rand() T1 = ⊗(A, B) T2 = ⊗(C, D) - D1 = DiagonalOperator(rand(N2)) - D2 = DiagonalOperator(rand(N2)) + D1 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t) -> d .= p) + D2 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t; diag) -> d .= p*t*diag, + accepted_kwargs = (:diag,)) TT = [T1, T2] DD = Diagonal([D1, D2]) @@ -100,8 +109,23 @@ end op = TT' * DD * TT op = cache_operator(op, u) + # Update operator + @test_nowarn update_coefficients!(op, u, p, t; diag, matrix) + # Form dense operator manually + dense_T1 = kron(A, p * ones(N, N)) + dense_T2 = kron(C, (p*t) .* matrix) + dense_DD = Diagonal(vcat(p * ones(N2), p*t*diag)) + dense_op = hcat(dense_T1', dense_T2') * dense_DD * vcat(dense_T1, dense_T2) + # Test correctness of op + @test op * u ≈ dense_op * u + @test convert(AbstractMatrix, op) ≈ dense_op + # Test consistency with three-arg mul! v=rand(N2,K); @test mul!(v, op, u) ≈ op * u + # Test consistency with in-place five-arg mul! v=rand(N2,K); w=copy(v); @test mul!(v, op, u, α, β) ≈ α*(op * u) + β * w + # Test consistency with operator application form + @test op(u, p, t; diag, matrix) ≈ op * u + v=rand(N2,K); @test op(v, u, p, t; diag, matrix) ≈ op * u end @testset "Resize! test" begin