From 1ed74b62ac9c209c094963f7e5f49cd7eab477d1 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 31 Jan 2023 05:02:54 +0800 Subject: [PATCH 01/23] Recursively propagate kwargs through update_coefficients! --- docs/src/interface.md | 7 +++++ src/batch.jl | 22 ++++++++------- src/func.jl | 4 +-- src/interface.jl | 24 ++++++++++++----- src/matrix.jl | 63 +++++++++++++++++++++++++------------------ src/scalar.jl | 16 ++++++----- src/utils.jl | 10 +++++++ test/scalar.jl | 11 ++++++++ 8 files changed, 106 insertions(+), 51 deletions(-) diff --git a/docs/src/interface.md b/docs/src/interface.md index c7e757b7..001175e8 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -55,3 +55,10 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth affine as well, and all sorts of things. Thus affine operators have no matrix representation but they are still compatible with essentially any Krylov method which would otherwise be compatible with matrix-free representations, hence their support in the SciMLOperators interface. + +## Note about keyword arguments to `update_coefficients!` + +In rare cases, an operator may be used in a context where additional state is expected to be provided +to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional +state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator. +For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwarg_fields` argument that defaults to an empty tuple. 
diff --git a/src/batch.jl b/src/batch.jl index c8504e7f..680d2226 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -1,12 +1,12 @@ # """ - BatchedDiagonalOperator(diag, [; update_func]) + BatchedDiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=()) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. Acts on `AbstractArray`s of the same size as `diag`. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(diag::AbstractVector,u,p,t) -> [modifies diag] + update_func(diag::AbstractVector,u,p,t; ) -> [modifies diag] """ struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T} diag::D @@ -14,20 +14,22 @@ struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T} function BatchedDiagonalOperator( diag::AbstractArray; - update_func=DEFAULT_UPDATE_FUNC + update_func=nothing, + accepted_kwarg_fields=() ) + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) new{ eltype(diag), typeof(diag), - typeof(update_func) + typeof(_update_func) }( - diag, update_func, + diag, _update_func, ) end end -function DiagonalOperator(u::AbstractArray; update_func=DEFAULT_UPDATE_FUNC) - BatchedDiagonalOperator(u; update_func=update_func) +function DiagonalOperator(u::AbstractArray; update_func=nothing, accepted_kwarg_fields=()) + BatchedDiagonalOperator(u; update_func, accepted_kwarg_fields) end # traits @@ -40,7 +42,7 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly update_func = if isreal(L) L.update_func else - (L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t)) + (L,u,p,t; kwargs...) 
-> conj(L.update_func(conj(L.diag),u,p,t; kwargs...)) end BatchedDiagonalOperator(diag; update_func=update_func) end @@ -57,7 +59,7 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator) end LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag))) -isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC +isconstant(L::BatchedDiagonalOperator) = update_func_isconstant(L.update_func) islinear(::BatchedDiagonalOperator) = true has_adjoint(L::BatchedDiagonalOperator) = true has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag) @@ -65,7 +67,7 @@ has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L) getops(L::BatchedDiagonalOperator) = (L.diag,) -update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func(L.diag,u,p,t); nothing) +update_coefficients!(L::BatchedDiagonalOperator,u,p,t; kwargs...) = (L.update_func(L.diag,u,p,t; kwargs...); nothing) # operator application Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u diff --git a/src/func.jl b/src/func.jl index 91faff49..9a08bfbc 100644 --- a/src/func.jl +++ b/src/func.jl @@ -192,10 +192,10 @@ function update_coefficients(L::FunctionOperator, u, p, t) ) end -function update_coefficients!(L::FunctionOperator, u, p, t) +function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) ops = getops(L) for op in ops - update_coefficients!(op, u, p, t) + update_coefficients!(op, u, p, t; kwargs...) end L.p = p diff --git a/src/interface.jl b/src/interface.jl index 606aa725..73d6dd9a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -15,19 +15,31 @@ out-of-place form B = update_coefficients(A,u,p,t). """ function (::AbstractSciMLOperator) end +# Utilities for update functions DEFAULT_UPDATE_FUNC(A,u,p,t) = A +function preprocess_update_func(update_func, accepted_kwarg_fields) + update_func = (update_func === nothing) ? 
DEFAULT_UPDATE_FUNC : update_func + return FilterKwargs(update_func, accepted_kwarg_fields) +end +function update_func_isconstant(update_func) + if update_func isa FilterKwargs + return update_func.f == DEFAULT_UPDATE_FUNC + else + return update_func == DEFAULT_UPDATE_FUNC + end +end -update_coefficients!(L,u,p,t) = nothing -update_coefficients(L,u,p,t) = L -function update_coefficients!(L::AbstractSciMLOperator, u, p, t) +update_coefficients!(L,u,p,t; kwargs...) = nothing +update_coefficients(L,u,p,t; kwargs...) = L +function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...) for op in getops(L) - update_coefficients!(op, u, p, t) + update_coefficients!(op, u, p, t; kwargs...) end nothing end -(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u) -(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u)) +(L::AbstractSciMLOperator)(u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); L * u) +(L::AbstractSciMLOperator)(du, u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u)) ### # caching interface diff --git a/src/matrix.jl b/src/matrix.jl index 4c35b5cf..8a50f505 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -1,18 +1,20 @@ # """ - MatrixOperator(A[; update_func]) + MatrixOperator(A; update_func=nothing, accepted_kwarg_fields=()) Represents a time-dependent linear operator given by an AbstractMatrix. 
The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(A::AbstractMatrix,u,p,t) -> [modifies A] + update_func(A::AbstractMatrix,u,p,t; ) -> [modifies A] """ struct MatrixOperator{T,AType<:AbstractMatrix{T},F} <: AbstractSciMLOperator{T} A::AType update_func::F - MatrixOperator(A::AType; update_func=DEFAULT_UPDATE_FUNC) where{AType} = - new{eltype(A),AType,typeof(update_func)}(A, update_func) + function MatrixOperator(A::AType; update_func=nothing, accepted_kwarg_fields=()) where {AType} + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + new{eltype(A),AType,typeof(_update_func)}(A, _update_func) + end end # constructors @@ -39,21 +41,21 @@ for op in ( if isconstant(L) MatrixOperator($op(L.A)) else - update_func = (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t)) + update_func = (A,u,p,t; kwargs...) -> $op(L.update_func($op(L.A),u,p,t; kwargs...)) MatrixOperator($op(L.A); update_func = update_func) end end end Base.conj(L::MatrixOperator) = MatrixOperator( conj(L.A); - update_func= (A,u,b,t) -> conj(L.update_func(conj(L.A),u,p,t)) + update_func= (A,u,p,t; kwargs...) -> conj(L.update_func(conj(L.A),u,p,t; kwargs...)) ) has_adjoint(A::MatrixOperator) = has_adjoint(A.A) -update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); nothing) +update_coefficients!(L::MatrixOperator,u,p,t; kwargs...) 
= (L.update_func(L.A,u,p,t; kwargs...); nothing) getops(L::MatrixOperator) = (L.A) -isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC +isconstant(L::MatrixOperator) = update_func_isconstant(L.update_func) Base.iszero(L::MatrixOperator) = iszero(L.A) SparseArrays.sparse(L::MatrixOperator) = sparse(L.A) @@ -88,13 +90,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u) """ - DiagonalOperator(diag, [; update_func]) + DiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=()) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(diag::AbstractVector,u,p,t) -> [modifies diag] + update_func(diag::AbstractVector,u,p,t; ) -> [modifies diag] When `diag` is an `AbstractVector` of length N, `L=DiagonalOpeator(diag, ...)` can be applied to `AbstractArray`s with `size(u, 1) == N`. Each column of the `u` @@ -105,11 +107,12 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of `L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)` with leading length `size(u, 1) = N`. """ -function DiagonalOperator(diag::AbstractVector; update_func = DEFAULT_UPDATE_FUNC) - diag_update_func = if update_func == DEFAULT_UPDATE_FUNC - DEFAULT_UPDATE_FUNC +function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwarg_fields=()) + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + diag_update_func = if update_func_isconstant(_update_func) + _update_func else - (A, u, p, t) -> (update_func(A.diag, u, p, t); A) + (A, u, p, t; kwargs...) 
-> (_update_func(A.diag, u, p, t; kwargs...); A) end MatrixOperator(Diagonal(diag); update_func=diag_update_func) end @@ -202,13 +205,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u) """ - L = AffineOperator(A, B, b[; update_func]) + L = AffineOperator(A, B, b; update_func=nothing, accepted_kwarg_fields=()) L(u) = A*u + B*b Represents a time-dependent affine operator. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(b::AbstractArray,u,p,t) -> [modifies b] + update_func(b::AbstractArray,u,p,t; ) -> [modifies b] """ struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T} A::AType @@ -236,44 +239,52 @@ end function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator}, B::Union{AbstractMatrix,AbstractSciMLOperator}, b::AbstractArray; - update_func = DEFAULT_UPDATE_FUNC, + update_func=nothing, + accepted_kwarg_fields=() ) @assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors of same size" + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + A = A isa AbstractMatrix ? MatrixOperator(A) : A B = B isa AbstractMatrix ? 
MatrixOperator(B) : B cache = B * b - AffineOperator(A, B, b, cache, update_func) + AffineOperator(A, B, b, cache, _update_func) end """ - L = AddVector(b[; update_func]) + L = AddVector(b; update_func=nothing, accepted_kwarg_fields=()) L(u) = u + b """ -function AddVector(b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC) +function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=()) + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + N = size(b, 1) Id = IdentityOperator(N) - AffineOperator(Id, Id, b; update_func=update_func) + AffineOperator(Id, Id, b; update_func=_update_func) end """ - L = AddVector(B, b[; update_func]) + L = AddVector(B, b; update_func=nothing, accepted_kwarg_fields=()) L(u) = u + B*b """ -function AddVector(B, b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC) +function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=()) + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + N = size(B, 1) Id = IdentityOperator(N) - AffineOperator(Id, B, b; update_func=update_func) + AffineOperator(Id, B, b; update_func=_update_func) end getops(L::AffineOperator) = (L.A, L.B, L.b) -update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); nothing) -isconstant(L::AffineOperator) = (L.update_func == DEFAULT_UPDATE_FUNC) & all(isconstant, (L.A, L.B)) +update_coefficients!(L::AffineOperator,u,p,t; kwargs...) 
= (L.update_func(L.b,u,p,t; kwargs...); nothing) +isconstant(L::AffineOperator) = update_func_isconstant(L.update_func) & all(isconstant, (L.A, L.B)) + islinear(::AffineOperator) = false Base.size(L::AffineOperator) = size(L.A) diff --git a/src/scalar.jl b/src/scalar.jl index c4f7989b..f53dd8bf 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -90,7 +90,7 @@ end Base.:+(α::AbstractSciMLScalarOperator) = α """ - ScalarOperator(val[; update_func]) + ScalarOperator(val; update_func=nothing, accepted_kwarg_fields=()) (α::ScalarOperator)(a::Number) = α * a @@ -98,14 +98,16 @@ Represents a time-dependent scalar/scaling operator. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(oldval,u,p,t) -> newval + update_func(oldval,u,p,t; ) -> newval """ mutable struct ScalarOperator{T<:Number,F} <: AbstractSciMLScalarOperator{T} val::T update_func::F - ScalarOperator(val::T; update_func=DEFAULT_UPDATE_FUNC) where{T} = - new{T,typeof(update_func)}(val, update_func) + function ScalarOperator(val::T; update_func=nothing, accepted_kwarg_fields=()) where {T} + _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + new{T,typeof(_update_func)}(val, _update_func) + end end # constructors @@ -118,7 +120,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ) # traits function Base.conj(α::ScalarOperator) # TODO - test val = conj(α.val) - update_func = (oldval,u,p,t) -> α.update_func(oldval |> conj,u,p,t) |> conj + update_func = (oldval,u,p,t; kwargs...) -> α.update_func(oldval |> conj,u,p,t; kwargs...) 
|> conj ScalarOperator(val; update_func=update_func) end @@ -132,11 +134,11 @@ Base.abs(α::ScalarOperator) = abs(α.val) Base.iszero(α::ScalarOperator) = iszero(α.val) getops(α::ScalarOperator) = (α.val,) -isconstant(α::ScalarOperator) = α.update_func == DEFAULT_UPDATE_FUNC +isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func) has_ldiv(α::ScalarOperator) = !iszero(α.val) has_ldiv!(α::ScalarOperator) = has_ldiv(α) -update_coefficients!(L::ScalarOperator,u,p,t) = (L.val = L.update_func(L.val,u,p,t); nothing) +update_coefficients!(L::ScalarOperator,u,p,t; kwargs...) = (L.val = L.update_func(L.val,u,p,t; kwargs...); nothing) """ Lazy addition of Scalar Operators diff --git a/src/utils.jl b/src/utils.jl index 213de5ee..fc5e961d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,4 +12,14 @@ end dims(A) = length(size(A)) dims(::AbstractArray{<:Any,N}) where{N} = N dims(::AbstractSciMLOperator) = 2 + +# Keyword argument filtering +struct FilterKwargs{F,K} + f::F + accepted_kwarg_fields::K +end +function (f_filter::FilterKwargs)(args...; kwargs...) + filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwarg_fields) + f_filter.f(args...; filtered_kwargs...) +end # diff --git a/test/scalar.jl b/test/scalar.jl index 425c231c..84e9ad44 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -82,5 +82,16 @@ end @test num(v,u,p,t) ≈ val * u @test convert(Number, num) ≈ val + + # Test scalar operator which expects keyword argument to update, modeled in the style of a DiffEq W-operator. 
+ γ = ScalarOperator(0.0; update_func=(args...; dtgamma) -> dtgamma, accepted_kwarg_fields=(:dtgamma,)) + + dtgamma = rand() + @test γ(u,p,t; dtgamma) ≈ dtgamma * u + @test γ(v,u,p,t; dtgamma) ≈ dtgamma * u + + γ_added = γ + α + @test γ_added(u,p,t; dtgamma) ≈ (dtgamma + p) * u + @test γ_added(v,u,p,t; dtgamma) ≈ (dtgamma + p) * u end # From 293d5eb0c4925dff9825bb2ab29daebcc60b8aac Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 12:24:39 -0400 Subject: [PATCH 02/23] Rename accepted_kwarg_fields -> accepted_kwargs --- docs/src/interface.md | 2 +- src/batch.jl | 10 +++++----- src/interface.jl | 4 ++-- src/matrix.jl | 30 +++++++++++++++--------------- src/scalar.jl | 6 +++--- src/utils.jl | 4 ++-- test/scalar.jl | 2 +- 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/src/interface.md b/docs/src/interface.md index 001175e8..cd0f2374 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -61,4 +61,4 @@ matrix-free representations, hence their support in the SciMLOperators interface In rare cases, an operator may be used in a context where additional state is expected to be provided to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator. -For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwarg_fields` argument that defaults to an empty tuple. +For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument that defaults to an empty tuple. 
diff --git a/src/batch.jl b/src/batch.jl index 680d2226..3942928d 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -1,6 +1,6 @@ # """ - BatchedDiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=()) + BatchedDiagonalOperator(diag; update_func=nothing, accepted_kwargs=()) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. Acts on `AbstractArray`s of the same size as `diag`. The update function is called @@ -15,9 +15,9 @@ struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T} function BatchedDiagonalOperator( diag::AbstractArray; update_func=nothing, - accepted_kwarg_fields=() + accepted_kwargs=() ) - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + _update_func = preprocess_update_func(update_func, accepted_kwargs) new{ eltype(diag), typeof(diag), @@ -28,8 +28,8 @@ struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T} end end -function DiagonalOperator(u::AbstractArray; update_func=nothing, accepted_kwarg_fields=()) - BatchedDiagonalOperator(u; update_func, accepted_kwarg_fields) +function DiagonalOperator(u::AbstractArray; update_func=nothing, accepted_kwargs=()) + BatchedDiagonalOperator(u; update_func, accepted_kwargs) end # traits diff --git a/src/interface.jl b/src/interface.jl index 73d6dd9a..321b1fdb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -17,9 +17,9 @@ function (::AbstractSciMLOperator) end # Utilities for update functions DEFAULT_UPDATE_FUNC(A,u,p,t) = A -function preprocess_update_func(update_func, accepted_kwarg_fields) +function preprocess_update_func(update_func, accepted_kwargs) update_func = (update_func === nothing) ? 
DEFAULT_UPDATE_FUNC : update_func - return FilterKwargs(update_func, accepted_kwarg_fields) + return FilterKwargs(update_func, accepted_kwargs) end function update_func_isconstant(update_func) if update_func isa FilterKwargs diff --git a/src/matrix.jl b/src/matrix.jl index 8a50f505..7359e67d 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -1,6 +1,6 @@ # """ - MatrixOperator(A; update_func=nothing, accepted_kwarg_fields=()) + MatrixOperator(A; update_func=nothing, accepted_kwargs=()) Represents a time-dependent linear operator given by an AbstractMatrix. The update function is called by `update_coefficients!` and is assumed to have @@ -11,8 +11,8 @@ the following signature: struct MatrixOperator{T,AType<:AbstractMatrix{T},F} <: AbstractSciMLOperator{T} A::AType update_func::F - function MatrixOperator(A::AType; update_func=nothing, accepted_kwarg_fields=()) where {AType} - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + function MatrixOperator(A::AType; update_func=nothing, accepted_kwargs=()) where {AType} + _update_func = preprocess_update_func(update_func, accepted_kwargs) new{eltype(A),AType,typeof(_update_func)}(A, _update_func) end end @@ -90,7 +90,7 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u) """ - DiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=()) + DiagonalOperator(diag; update_func=nothing, accepted_kwargs=()) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. The update function is called by `update_coefficients!` and is assumed to have @@ -107,8 +107,8 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of `L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)` with leading length `size(u, 1) = N`. 
""" -function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwarg_fields=()) - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) +function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwargs=()) + _update_func = preprocess_update_func(update_func, accepted_kwargs) diag_update_func = if update_func_isconstant(_update_func) _update_func else @@ -205,7 +205,7 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u) """ - L = AffineOperator(A, B, b; update_func=nothing, accepted_kwarg_fields=()) + L = AffineOperator(A, B, b; update_func=nothing, accepted_kwargs=()) L(u) = A*u + B*b Represents a time-dependent affine operator. The update function is called @@ -240,12 +240,12 @@ function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator}, B::Union{AbstractMatrix,AbstractSciMLOperator}, b::AbstractArray; update_func=nothing, - accepted_kwarg_fields=() + accepted_kwargs=() ) @assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors of same size" - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + _update_func = preprocess_update_func(update_func, accepted_kwargs) A = A isa AbstractMatrix ? MatrixOperator(A) : A B = B isa AbstractMatrix ? 
MatrixOperator(B) : B @@ -255,11 +255,11 @@ function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator}, end """ - L = AddVector(b; update_func=nothing, accepted_kwarg_fields=()) + L = AddVector(b; update_func=nothing, accepted_kwargs=()) L(u) = u + b """ -function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=()) - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) +function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=()) + _update_func = preprocess_update_func(update_func, accepted_kwargs) N = size(b, 1) Id = IdentityOperator(N) @@ -268,11 +268,11 @@ function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fiel end """ - L = AddVector(B, b; update_func=nothing, accepted_kwarg_fields=()) + L = AddVector(B, b; update_func=nothing, accepted_kwargs=()) L(u) = u + B*b """ -function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=()) - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) +function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=()) + _update_func = preprocess_update_func(update_func, accepted_kwargs) N = size(B, 1) Id = IdentityOperator(N) diff --git a/src/scalar.jl b/src/scalar.jl index f53dd8bf..22532449 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -90,7 +90,7 @@ end Base.:+(α::AbstractSciMLScalarOperator) = α """ - ScalarOperator(val; update_func=nothing, accepted_kwarg_fields=()) + ScalarOperator(val; update_func=nothing, accepted_kwargs=()) (α::ScalarOperator)(a::Number) = α * a @@ -104,8 +104,8 @@ mutable struct ScalarOperator{T<:Number,F} <: AbstractSciMLScalarOperator{T} val::T update_func::F - function ScalarOperator(val::T; update_func=nothing, accepted_kwarg_fields=()) where {T} - _update_func = preprocess_update_func(update_func, accepted_kwarg_fields) + function ScalarOperator(val::T; update_func=nothing, accepted_kwargs=()) where {T} + _update_func = 
preprocess_update_func(update_func, accepted_kwargs) new{T,typeof(_update_func)}(val, _update_func) end end diff --git a/src/utils.jl b/src/utils.jl index fc5e961d..6f972d86 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,10 +16,10 @@ dims(::AbstractSciMLOperator) = 2 # Keyword argument filtering struct FilterKwargs{F,K} f::F - accepted_kwarg_fields::K + accepted_kwargs::K end function (f_filter::FilterKwargs)(args...; kwargs...) - filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwarg_fields) + filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwargs) f_filter.f(args...; filtered_kwargs...) end # diff --git a/test/scalar.jl b/test/scalar.jl index 84e9ad44..e71288a0 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -84,7 +84,7 @@ end @test convert(Number, num) ≈ val # Test scalar operator which expects keyword argument to update, modeled in the style of a DiffEq W-operator. - γ = ScalarOperator(0.0; update_func=(args...; dtgamma) -> dtgamma, accepted_kwarg_fields=(:dtgamma,)) + γ = ScalarOperator(0.0; update_func=(args...; dtgamma) -> dtgamma, accepted_kwargs=(:dtgamma,)) dtgamma = rand() @test γ(u,p,t; dtgamma) ≈ dtgamma * u From 1712594d554e6780f368ce47d4dcbfa88f4e0f81 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 13:20:27 -0400 Subject: [PATCH 03/23] Allow accepted_kwargs=nothing to indicate no wrapping --- src/interface.jl | 4 +++- src/matrix.jl | 10 +++++----- src/scalar.jl | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 321b1fdb..85ca894b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -19,7 +19,9 @@ function (::AbstractSciMLOperator) end DEFAULT_UPDATE_FUNC(A,u,p,t) = A function preprocess_update_func(update_func, accepted_kwargs) update_func = (update_func === nothing) ? 
DEFAULT_UPDATE_FUNC : update_func - return FilterKwargs(update_func, accepted_kwargs) + # accepted_kwargs can be passed as nothing to indicate that we should not filter + # (e.g. if the function already accepts all kwargs...). + return (accepted_kwargs === nothing) ? update_func : FilterKwargs(update_func, accepted_kwargs) end function update_func_isconstant(update_func) if update_func isa FilterKwargs diff --git a/src/matrix.jl b/src/matrix.jl index 7359e67d..1fdf3ce9 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -42,7 +42,7 @@ for op in ( MatrixOperator($op(L.A)) else update_func = (A,u,p,t; kwargs...) -> $op(L.update_func($op(L.A),u,p,t; kwargs...)) - MatrixOperator($op(L.A); update_func = update_func) + MatrixOperator($op(L.A); update_func = update_func, accepted_kwargs=nothing) end end end @@ -79,7 +79,7 @@ Base.copyto!(L::MatrixOperator, rhs::Base.Broadcast.Broadcasted{<:StaticArraysCo Base.Broadcast.broadcastable(L::MatrixOperator) = L Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType) ArrayInterface.issingular(L::MatrixOperator) = ArrayInterface.issingular(L.A) -Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func) +Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func, accepted_kwargs=nothing) # operator application Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u @@ -114,7 +114,7 @@ function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kw else (A, u, p, t; kwargs...) 
-> (_update_func(A.diag, u, p, t; kwargs...); A) end - MatrixOperator(Diagonal(diag); update_func=diag_update_func) + MatrixOperator(Diagonal(diag); update_func=diag_update_func, accepted_kwargs=nothing) end LinearAlgebra.Diagonal(L::MatrixOperator) = MatrixOperator(Diagonal(L.A)) @@ -264,7 +264,7 @@ function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=()) N = size(b, 1) Id = IdentityOperator(N) - AffineOperator(Id, Id, b; update_func=_update_func) + AffineOperator(Id, Id, b; update_func=_update_func, accepted_kwargs=nothing) end """ @@ -277,7 +277,7 @@ function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwargs= N = size(B, 1) Id = IdentityOperator(N) - AffineOperator(Id, B, b; update_func=_update_func) + AffineOperator(Id, B, b; update_func=_update_func, accepted_kwargs=nothing) end getops(L::AffineOperator) = (L.A, L.B, L.b) diff --git a/src/scalar.jl b/src/scalar.jl index 22532449..830a455c 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -121,7 +121,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ) function Base.conj(α::ScalarOperator) # TODO - test val = conj(α.val) update_func = (oldval,u,p,t; kwargs...) -> α.update_func(oldval |> conj,u,p,t; kwargs...) |> conj - ScalarOperator(val; update_func=update_func) + ScalarOperator(val; update_func=update_func, accepted_kwargs=nothing) end Base.one(::AbstractSciMLScalarOperator{T}) where{T} = ScalarOperator(one(T)) From cd705035a765e547c2451b5be1c7d6165ea612f5 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 13:25:27 -0400 Subject: [PATCH 04/23] Tweak keyword filtering logic --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 6f972d86..f5a3314d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -19,7 +19,9 @@ struct FilterKwargs{F,K} accepted_kwargs::K end function (f_filter::FilterKwargs)(args...; kwargs...) 
- filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwargs) + # Filter keyword arguments to those accepted by function. + # Avoid throwing errors here if a keyword argument is not provided: defer this to the function call for a more readable error. + filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwargs if haskey(kwargs, kwarg)) f_filter.f(args...; filtered_kwargs...) end # From 17298fb6b5a151b48a7bd0a472ffd5ecedbc2e95 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 13:02:30 -0400 Subject: [PATCH 05/23] Test operator update (including kwarg update) in operator algebra test --- test/total.jl | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/test/total.jl b/test/total.jl index 4fb44a92..f0591664 100644 --- a/test/total.jl +++ b/test/total.jl @@ -73,20 +73,31 @@ end @testset "Operator Algebra" begin N2 = N*N + A = rand(N,N) - B = rand(N,N) + # Introduce update function for B + B = MatrixOperator(zeros(N,N); update_func=(A, u, p, t) -> (A .= p)) C = rand(N,N) - D = rand(N,N) + # Introduce update function for D dependent on kwarg "matrix" + D = MatrixOperator(zeros(N,N); update_func=(A, u, p, t; matrix) -> (A .= p*t*matrix), + accepted_kwargs=(:matrix,)) u = rand(N2,K) + p = rand() + t = rand() + matrix = rand(N, N) + diag = rand(N2) α = rand() β = rand() T1 = ⊗(A, B) T2 = ⊗(C, D) - D1 = DiagonalOperator(rand(N2)) - D2 = DiagonalOperator(rand(N2)) + # Introduce update function for D1 + D1 = DiagonalOperator(p * ones(N2); update_func=(A, u, p, t) -> (A .= p)) + # Introduce update funcion for D2 dependent on kwarg "diag" + D2 = DiagonalOperator(p*t * diag; update_func=(A, u, p, t; diag) -> (A .= p*t*diag), + accepted_kwargs=(:diag,)) TT = [T1, T2] DD = Diagonal([D1, D2]) @@ -94,7 +105,19 @@ end op = TT' * DD * TT op = cache_operator(op, u) + # Update operator + @test_nowarn update_coefficients!(op, u, p, t; diag, matrix) + # Form dense operator manually + 
dense_T1 = kron(A, p * ones(N, N)) + dense_T2 = kron(C, (p*t) .* matrix) + dense_DD = Diagonal(vcat(p * ones(N2), p*t*diag)) + dense_op = hcat(dense_T1', dense_T2') * dense_DD * vcat(dense_T1, dense_T2) + # Test correctness of op + @test op * u ≈ dense_op * u + @test convert(AbstractMatrix, op) ≈ dense_op + # Test consistency with three-arg mul! v=rand(N2,K); @test mul!(v, op, u) ≈ op * u + # Test consistency with in-place five-arg mul! v=rand(N2,K); w=copy(v); @test mul!(v, op, u, α, β) ≈ α*(op * u) + β * w end -# +# \ No newline at end of file From dc3e31d870ff4edf14b0bcf6de7062ca028f9225 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 15:06:21 -0400 Subject: [PATCH 06/23] Support kwargs in function operator --- src/func.jl | 24 +++++++++++++++++------- test/func.jl | 14 ++++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/func.jl b/src/func.jl index 9a08bfbc..7f6402ca 100644 --- a/src/func.jl +++ b/src/func.jl @@ -2,7 +2,7 @@ """ Matrix free operators (given by a function) """ -mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T} +mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: AbstractSciMLOperator{T} """ Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """ op::F """ Adjoint operator""" @@ -17,6 +17,8 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst p::P """ Time """ t::Tt + """ Keyword arguments """ + kwargs::K """ Cache """ cache::C @@ -28,6 +30,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst traits, p, t, + kwargs_for_op, cache ) @@ -46,6 +49,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst typeof(traits), typeof(p), typeof(t), + typeof(kwargs_for_op), typeof(cache), }( op, @@ -55,6 +59,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst traits, p, t, + kwargs_for_op, 
cache, ) end @@ -82,6 +87,7 @@ function FunctionOperator(op, FunctionOperator(op, input, output; kwargs...) end +# TODO: document constructor and revisit design as needed (e.g. for "kwargs_for_op") function FunctionOperator(op, input::AbstractVecOrMat, output::AbstractVecOrMat = input; @@ -96,6 +102,7 @@ function FunctionOperator(op, p=nothing, t::Union{Number,Nothing}=nothing, + kwargs_for_op=(), ifcache::Bool = true, @@ -169,7 +176,9 @@ function FunctionOperator(op, traits, p, t, - cache, + # automatically convert NamedTuple's to pairs + pairs(kwargs_for_op), + cache ) ifcache ? cache_operator(L, input, output) : L @@ -200,6 +209,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) L.p = p L.t = t + L.kwargs = kwargs nothing end @@ -326,22 +336,22 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin # operator application Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t) -Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t) +Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t; L.kwargs...) function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat) _, co = L.cache du = zero(co) - L.op(du, u, L.p, L.t) + L.op(du, u, L.p, L.t; L.kwargs...) end function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat) ci, _ = L.cache du = zero(ci) - L.op_inverse(du, u, L.p, L.t) + L.op_inverse(du, u, L.p, L.t; L.kwargs...) end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat) - L.op(v, u, L.p, L.t) + L.op(v, u, L.p, L.t; L.kwargs...) end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...) 
@@ -358,7 +368,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::A end function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat) - L.op_inverse(v, u, L.p, L.t) + L.op_inverse(v, u, L.p, L.t; L.kwargs...) end function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat) diff --git a/test/func.jl b/test/func.jl index 93187538..d1326a62 100644 --- a/test/func.jl +++ b/test/func.jl @@ -94,18 +94,20 @@ end u = rand(N,K) p = rand(N) t = rand() + scale = rand() - f(du,u,p,t) = mul!(du, Diagonal(p*t), u) + # Accept a kwarg "scale" in operator action + f(du,u,p,t; scale) = begin @show scale; mul!(du, Diagonal(p*t*scale), u) end - L = FunctionOperator(f, u, u; p=zero(p), t=zero(t)) + L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), kwargs_for_op=(;scale=zero(scale))) - ans = @. u * p * t - @test L(u,p,t) ≈ ans - v=copy(u); @test L(v,u,p,t) ≈ ans + ans = @. u * p * t * scale + @test L(u,p,t; scale) ≈ ans + v=copy(u); @test L(v,u,p,t; scale) ≈ ans # test that output isn't accidentally mutated by passing an internal cache. - A = Diagonal(p * t) + A = Diagonal(p * t * scale) u1 = rand(N, K) u2 = rand(N, K) From 5b47ef97a5a38d9ee319e75a6812bfba8e32d76d Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 15:36:48 -0400 Subject: [PATCH 07/23] Propagate kwargs for out-of-place function operator update_coefficients too --- src/func.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/func.jl b/src/func.jl index 7f6402ca..d3835c63 100644 --- a/src/func.jl +++ b/src/func.jl @@ -184,11 +184,10 @@ function FunctionOperator(op, ifcache ? 
cache_operator(L, input, output) : L end -function update_coefficients(L::FunctionOperator, u, p, t) - op = update_coefficients(L.op, u, p, t) - op_adjoint = update_coefficients(L.op_adjoint, u, p, t) - op_inverse = update_coefficients(L.op_inverse, u, p, t) - op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t) +function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) + for op in getops(L) + op = update_coefficients(op, u, p, t; kwargs...) + end FunctionOperator(op, op_adjoint, @@ -197,6 +196,7 @@ function update_coefficients(L::FunctionOperator, u, p, t) L.traits, p, t, + kwargs_for_op=kwargs, L.cache ) end From 88d5050b754e3e022e1b2bed2558422053444aad Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 15:49:03 -0400 Subject: [PATCH 08/23] Catch function operator error for empty kwargs --- src/func.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/func.jl b/src/func.jl index d3835c63..a5add964 100644 --- a/src/func.jl +++ b/src/func.jl @@ -177,7 +177,7 @@ function FunctionOperator(op, p, t, # automatically convert NamedTuple's to pairs - pairs(kwargs_for_op), + Base.Pairs{Symbol}(kwargs_for_op, keys(kwargs_for_op)), cache ) From c7fcd515f681d20ea5f1bccb99c54a5a10c4f2d1 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 15:51:10 -0400 Subject: [PATCH 09/23] Address code review suggestions on diag op construction --- test/total.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/total.jl b/test/total.jl index f0591664..84fa4a99 100644 --- a/test/total.jl +++ b/test/total.jl @@ -94,9 +94,9 @@ end T2 = ⊗(C, D) # Introduce update function for D1 - D1 = DiagonalOperator(p * ones(N2); update_func=(A, u, p, t) -> (A .= p)) + D1 = DiagonalOperator(zeros(N2); update_func=(d, u, p, t) -> (d .= p)) # Introduce update funcion for D2 dependent on kwarg "diag" - D2 = DiagonalOperator(p*t * diag; update_func=(A, u, p, t; diag) -> (A .= p*t*diag), + D2 = 
DiagonalOperator(zeros(N2); update_func=(d, u, p, t; diag) -> (d .= p*t*diag), accepted_kwargs=(:diag,)) TT = [T1, T2] From 24aef00d5e2883905aa068d2564bb92cef1aa73a Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 15:59:57 -0400 Subject: [PATCH 10/23] Improve logic for normalizing kwargs --- src/func.jl | 5 ++--- src/utils.jl | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/func.jl b/src/func.jl index a5add964..61ae8757 100644 --- a/src/func.jl +++ b/src/func.jl @@ -176,8 +176,7 @@ function FunctionOperator(op, traits, p, t, - # automatically convert NamedTuple's to pairs - Base.Pairs{Symbol}(kwargs_for_op, keys(kwargs_for_op)), + normalize_kwargs(kwargs_for_op), cache ) @@ -209,7 +208,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) L.p = p L.t = t - L.kwargs = kwargs + L.kwargs = normalize_kwargs(kwargs) nothing end diff --git a/src/utils.jl b/src/utils.jl index f5a3314d..ac934ca7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,4 +24,7 @@ function (f_filter::FilterKwargs)(args...; kwargs...) filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwargs if haskey(kwargs, kwarg)) f_filter.f(args...; filtered_kwargs...) end +# automatically convert NamedTuple's, etc. to a normalized kwargs representation (i.e. Base.Pairs) +normalize_kwargs(; kwargs...) = kwargs +normalize_kwargs(kwargs) = normalize_kwargs(; kwargs...) # From 44d66bb871dd7d0a28c6613520e37d71135b8c8f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 16:03:22 -0400 Subject: [PATCH 11/23] Test operator application form in operator algebra test set --- test/total.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/total.jl b/test/total.jl index 84fa4a99..50544012 100644 --- a/test/total.jl +++ b/test/total.jl @@ -119,5 +119,8 @@ end v=rand(N2,K); @test mul!(v, op, u) ≈ op * u # Test consistency with in-place five-arg mul! 
v=rand(N2,K); w=copy(v); @test mul!(v, op, u, α, β) ≈ α*(op * u) + β * w + # Test consistency with operator application form + @test op(u, p, t; diag, matrix) ≈ op * u + v=rand(N2,K); @test op(v, u, p, t; diag, matrix) ≈ op * u end # \ No newline at end of file From 9d8fdff31145e6f99bdfbe5b312b274325a4e602 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 16:13:39 -0400 Subject: [PATCH 12/23] Support kwargs in function operator functionals --- src/func.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/func.jl b/src/func.jl index 61ae8757..80eb22c3 100644 --- a/src/func.jl +++ b/src/func.jl @@ -238,9 +238,6 @@ function Base.adjoint(L::FunctionOperator) traits = L.traits @set! traits.size = reverse(size(L)) - p = L.p - t = L.t - cache = if iscached(L) cache = reverse(L.cache) else @@ -252,8 +249,9 @@ function Base.adjoint(L::FunctionOperator) op_inverse, op_adjoint_inverse, traits, - p, - t, + L.p, + L.t, + L.kwargs, cache, ) end @@ -280,9 +278,6 @@ function Base.inv(L::FunctionOperator) (p::Real) -> 1 / traits.opnorm(p) end - p = L.p - t = L.t - cache = if iscached(L) cache = reverse(L.cache) else @@ -294,8 +289,9 @@ function Base.inv(L::FunctionOperator) op_inverse, op_adjoint_inverse, traits, - p, - t, + L.p, + L.t, + L.kwargs, cache, ) end From f793c60b0af4567673279301852880eeae6f4bee Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 16:25:49 -0400 Subject: [PATCH 13/23] Add example --- docs/src/interface.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/src/interface.md b/docs/src/interface.md index cd0f2374..272a7a34 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -62,3 +62,19 @@ In rare cases, an operator may be used in a context where additional state is ex to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional state through arbitrary keyword arguments to `update_coefficients!`. 
When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator. For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument that defaults to an empty tuple. + +In the below example, we create an operator that gleefully ignores `u`, `p`, and `t` and uses its own special scaling. +```@example +using SciMLOperators + +γ = ScalarOperator(0.0; update_func=(a, u, p, t; my_special_scaling) -> my_special_scaling, + accepted_kwargs=(:my_special_scaling,)) + +# Update coefficients, then apply operator +update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling=7.0) +@show γ * [2.0] + +# Use operator application form +@show γ([2.0], nothing, nothing; my_special_scaling = 5.0) +nothing # hide +``` \ No newline at end of file From 331296742180096e6dbad834d33e2d41be694aef Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 16:53:08 -0400 Subject: [PATCH 14/23] Remove unnecessary function call --- src/func.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/func.jl b/src/func.jl index 80eb22c3..086004a3 100644 --- a/src/func.jl +++ b/src/func.jl @@ -208,7 +208,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
L.p = p L.t = t - L.kwargs = normalize_kwargs(kwargs) + L.kwargs = kwargs nothing end From 69f0eccbe52617830db3e7fe504ac6270c48067c Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 17:19:29 -0400 Subject: [PATCH 15/23] Rename kwargs_for_op -> accepted_kwargs --- src/func.jl | 14 +++++++------- test/func.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/func.jl b/src/func.jl index 086004a3..07207c6e 100644 --- a/src/func.jl +++ b/src/func.jl @@ -30,7 +30,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: Ab traits, p, t, - kwargs_for_op, + accepted_kwargs, cache ) @@ -49,7 +49,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: Ab typeof(traits), typeof(p), typeof(t), - typeof(kwargs_for_op), + typeof(accepted_kwargs), typeof(cache), }( op, @@ -59,7 +59,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: Ab traits, p, t, - kwargs_for_op, + accepted_kwargs, cache, ) end @@ -87,7 +87,7 @@ function FunctionOperator(op, FunctionOperator(op, input, output; kwargs...) end -# TODO: document constructor and revisit design as needed (e.g. for "kwargs_for_op") +# TODO: document constructor and revisit design as needed (e.g. for "accepted_kwargs") function FunctionOperator(op, input::AbstractVecOrMat, output::AbstractVecOrMat = input; @@ -102,7 +102,7 @@ function FunctionOperator(op, p=nothing, t::Union{Number,Nothing}=nothing, - kwargs_for_op=(), + accepted_kwargs=(), ifcache::Bool = true, @@ -176,7 +176,7 @@ function FunctionOperator(op, traits, p, t, - normalize_kwargs(kwargs_for_op), + normalize_kwargs(accepted_kwargs), cache ) @@ -195,7 +195,7 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) 
L.traits, p, t, - kwargs_for_op=kwargs, + accepted_kwargs=kwargs, L.cache ) end diff --git a/test/func.jl b/test/func.jl index d1326a62..2b7dc861 100644 --- a/test/func.jl +++ b/test/func.jl @@ -99,7 +99,7 @@ end # Accept a kwarg "scale" in operator action f(du,u,p,t; scale) = begin @show scale; mul!(du, Diagonal(p*t*scale), u) end - L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), kwargs_for_op=(;scale=zero(scale))) + L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), accepted_kwargs=(;scale=zero(scale))) ans = @. u * p * t * scale @test L(u,p,t; scale) ≈ ans From fef7618fa74c905ed110fea4468ce6af992bd2a6 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 17:22:06 -0400 Subject: [PATCH 16/23] Fix function operator out-of-place update coefficients --- src/func.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/func.jl b/src/func.jl index 07207c6e..5aed7108 100644 --- a/src/func.jl +++ b/src/func.jl @@ -184,9 +184,10 @@ function FunctionOperator(op, end function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) - for op in getops(L) - op = update_coefficients(op, u, p, t; kwargs...) - end + op = update_coefficients(L.op, u, p, t; kwargs...) + op_adjoint = update_coefficients(L.op_adjoint, u, p, t; kwargs...) + op_inverse = update_coefficients(L.op_inverse, u, p, t; kwargs...) + op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t; kwargs...) 
FunctionOperator(op, op_adjoint, From ceca67ccd59250f8979c165b9fe2f36a7e7ad2d3 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 17:32:12 -0400 Subject: [PATCH 17/23] Use NoKwargFilter() to bypass keyword filtering (rather than nothing) --- docs/src/interface.md | 2 +- src/interface.jl | 11 +++++++++-- src/matrix.jl | 36 ++++++++++++++++++------------------ src/scalar.jl | 6 +++--- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/docs/src/interface.md b/docs/src/interface.md index 272a7a34..f433804c 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -61,7 +61,7 @@ matrix-free representations, hence their support in the SciMLOperators interface In rare cases, an operator may be used in a context where additional state is expected to be provided to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator. -For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument that defaults to an empty tuple. +For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument (by default, none are passed). In the below example, we create an operator that gleefully ignores `u`, `p`, and `t` and uses its own special scaling. ```@example diff --git a/src/interface.jl b/src/interface.jl index 85ca894b..67471245 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -15,13 +15,20 @@ out-of-place form B = update_coefficients(A,u,p,t). 
""" function (::AbstractSciMLOperator) end +### # Utilities for update functions +### + DEFAULT_UPDATE_FUNC(A,u,p,t) = A + +struct NoKwargFilter end + function preprocess_update_func(update_func, accepted_kwargs) - update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func + _update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func + _accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs # accepted_kwargs can be passed as nothing to indicate that we should not filter # (e.g. if the function already accepts all kwargs...). - return (accepted_kwargs === nothing) ? update_func : FilterKwargs(update_func, accepted_kwargs) + return (_accepted_kwargs isa NoKwargFilter) ? _update_func : FilterKwargs(_update_func, _accepted_kwargs) end function update_func_isconstant(update_func) if update_func isa FilterKwargs diff --git a/src/matrix.jl b/src/matrix.jl index 1fdf3ce9..b0757992 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -1,17 +1,17 @@ # """ - MatrixOperator(A; update_func=nothing, accepted_kwargs=()) + MatrixOperator(A; [update_func, accepted_kwargs]) Represents a time-dependent linear operator given by an AbstractMatrix. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(A::AbstractMatrix,u,p,t; ) -> [modifies A] + update_func(A::AbstractMatrix,u,p,t; ) -> [modifies A] """ struct MatrixOperator{T,AType<:AbstractMatrix{T},F} <: AbstractSciMLOperator{T} A::AType update_func::F - function MatrixOperator(A::AType; update_func=nothing, accepted_kwargs=()) where {AType} + function MatrixOperator(A::AType; update_func=nothing, accepted_kwargs=nothing) where {AType} _update_func = preprocess_update_func(update_func, accepted_kwargs) new{eltype(A),AType,typeof(_update_func)}(A, _update_func) end @@ -42,7 +42,7 @@ for op in ( MatrixOperator($op(L.A)) else update_func = (A,u,p,t; kwargs...) 
-> $op(L.update_func($op(L.A),u,p,t; kwargs...)) - MatrixOperator($op(L.A); update_func = update_func, accepted_kwargs=nothing) + MatrixOperator($op(L.A); update_func = update_func, accepted_kwargs=NoKwargFilter()) end end end @@ -79,7 +79,7 @@ Base.copyto!(L::MatrixOperator, rhs::Base.Broadcast.Broadcasted{<:StaticArraysCo Base.Broadcast.broadcastable(L::MatrixOperator) = L Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType) ArrayInterface.issingular(L::MatrixOperator) = ArrayInterface.issingular(L.A) -Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func, accepted_kwargs=nothing) +Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func, accepted_kwargs=NoKwargFilter()) # operator application Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u @@ -90,13 +90,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u) """ - DiagonalOperator(diag; update_func=nothing, accepted_kwargs=()) + DiagonalOperator(diag; [update_func, accepted_kwargs]) Represents a time-dependent elementwise scaling (diagonal-scaling) operation. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(diag::AbstractVector,u,p,t; ) -> [modifies diag] + update_func(diag::AbstractVector,u,p,t; ) -> [modifies diag] When `diag` is an `AbstractVector` of length N, `L=DiagonalOpeator(diag, ...)` can be applied to `AbstractArray`s with `size(u, 1) == N`. Each column of the `u` @@ -107,14 +107,14 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of `L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)` with leading length `size(u, 1) = N`. 
""" -function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwargs=()) +function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwargs=nothing) _update_func = preprocess_update_func(update_func, accepted_kwargs) diag_update_func = if update_func_isconstant(_update_func) _update_func else (A, u, p, t; kwargs...) -> (_update_func(A.diag, u, p, t; kwargs...); A) end - MatrixOperator(Diagonal(diag); update_func=diag_update_func, accepted_kwargs=nothing) + MatrixOperator(Diagonal(diag); update_func=diag_update_func, accepted_kwargs=NoKwargFilter()) end LinearAlgebra.Diagonal(L::MatrixOperator) = MatrixOperator(Diagonal(L.A)) @@ -205,13 +205,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u) """ - L = AffineOperator(A, B, b; update_func=nothing, accepted_kwargs=()) + L = AffineOperator(A, B, b; [update_func, accepted_kwargs]) L(u) = A*u + B*b Represents a time-dependent affine operator. 
The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(b::AbstractArray,u,p,t; ) -> [modifies b] + update_func(b::AbstractArray,u,p,t; ) -> [modifies b] """ struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T} A::AType @@ -240,7 +240,7 @@ function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator}, B::Union{AbstractMatrix,AbstractSciMLOperator}, b::AbstractArray; update_func=nothing, - accepted_kwargs=() + accepted_kwargs=nothing ) @assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors of same size" @@ -255,29 +255,29 @@ function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator}, end """ - L = AddVector(b; update_func=nothing, accepted_kwargs=()) + L = AddVector(b; [update_func, accepted_kwargs]) L(u) = u + b """ -function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=()) +function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=nothing) _update_func = preprocess_update_func(update_func, accepted_kwargs) N = size(b, 1) Id = IdentityOperator(N) - AffineOperator(Id, Id, b; update_func=_update_func, accepted_kwargs=nothing) + AffineOperator(Id, Id, b; update_func=_update_func, accepted_kwargs=NoKwargFilter()) end """ - L = AddVector(B, b; update_func=nothing, accepted_kwargs=()) + L = AddVector(B, b; [update_func, accepted_kwargs]) L(u) = u + B*b """ -function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=()) +function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwargs=nothing) _update_func = preprocess_update_func(update_func, accepted_kwargs) N = size(B, 1) Id = IdentityOperator(N) - AffineOperator(Id, B, b; update_func=_update_func, accepted_kwargs=nothing) + AffineOperator(Id, B, b; update_func=_update_func, accepted_kwargs=NoKwargFilter()) end getops(L::AffineOperator) = (L.A, L.B, L.b) diff --git a/src/scalar.jl b/src/scalar.jl index 
830a455c..57e6523a 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -90,7 +90,7 @@ end Base.:+(α::AbstractSciMLScalarOperator) = α """ - ScalarOperator(val; update_func=nothing, accepted_kwargs=()) + ScalarOperator(val; [update_func, accepted_kwargs]) (α::ScalarOperator)(a::Number) = α * a @@ -98,7 +98,7 @@ Represents a time-dependent scalar/scaling operator. The update function is called by `update_coefficients!` and is assumed to have the following signature: - update_func(oldval,u,p,t; ) -> newval + update_func(oldval,u,p,t; ) -> newval """ mutable struct ScalarOperator{T<:Number,F} <: AbstractSciMLScalarOperator{T} val::T @@ -121,7 +121,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ) function Base.conj(α::ScalarOperator) # TODO - test val = conj(α.val) update_func = (oldval,u,p,t; kwargs...) -> α.update_func(oldval |> conj,u,p,t; kwargs...) |> conj - ScalarOperator(val; update_func=update_func, accepted_kwargs=nothing) + ScalarOperator(val; update_func=update_func, accepted_kwargs=NoKwargFilter()) end Base.one(::AbstractSciMLScalarOperator{T}) where{T} = ScalarOperator(one(T)) From 0efa2ce095803b0822b1562060d95fa72ab216c1 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 12 Mar 2023 17:33:34 -0400 Subject: [PATCH 18/23] Remove debug line --- test/func.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/func.jl b/test/func.jl index 2b7dc861..8f021074 100644 --- a/test/func.jl +++ b/test/func.jl @@ -97,7 +97,7 @@ end scale = rand() # Accept a kwarg "scale" in operator action - f(du,u,p,t; scale) = begin @show scale; mul!(du, Diagonal(p*t*scale), u) end + f(du,u,p,t; scale) = mul!(du, Diagonal(p*t*scale), u) L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), accepted_kwargs=(;scale=zero(scale))) From e63e4452933a975828eaca9ea536e948f1798b6b Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 27 May 2023 13:25:55 -0400 Subject: [PATCH 19/23] fix diagonaloperator update --- src/matrix.jl | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/matrix.jl b/src/matrix.jl index 2af128f6..15135686 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -155,10 +155,10 @@ function DiagonalOperator(diag::AbstractVector; ) diag_update_func = update_func_isconstant(update_func) ? update_func : - (A, u, p, t; kwargs...) -> (update_func(A.diag, u, p, t; kwargs...); A) + (A, u, p, t; kwargs...) -> update_func(A.diag, u, p, t; kwargs...) |> Diagonal diag_update_func! = update_func_isconstant(update_func!) ? update_func! : - (A, u, p, t; kwargs...) -> (update_func!(A.diag, u, p, t; kwargs...); A) + (A, u, p, t; kwargs...) -> update_func!(A.diag, u, p, t; kwargs...) MatrixOperator(Diagonal(diag); update_func = diag_update_func, From e89d5a1ad32624be78c483e6d71b8922a2416c68 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 29 May 2023 13:22:41 -0400 Subject: [PATCH 20/23] function op working --- src/func.jl | 44 ++++++++++++++++++++++++++------------------ src/utils.jl | 19 +++++++++++-------- test/func.jl | 7 ++++--- test/scalar.jl | 2 +- test/total.jl | 10 ++++------ 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/func.jl b/src/func.jl index eae3c95c..7ca6e1c9 100644 --- a/src/func.jl +++ b/src/func.jl @@ -2,7 +2,7 @@ """ Matrix free operators (given by a function) """ -mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: AbstractSciMLOperator{T} +mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T} """ Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """ op::F """ Adjoint operator""" @@ -17,8 +17,8 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} p::P """ Time """ t::Tt - """ Keyword arguments """ - kwargs::K + """ kwargs """ + kwargs::Dict{Symbol,Any} # TODO move inside traits later """ Cache """ cache::C @@ -30,7 +30,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} traits, p, t, 
- accepted_kwargs, + kwargs, cache ) @@ -51,7 +51,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} typeof(traits), typeof(p), typeof(t), - typeof(accepted_kwargs), typeof(cache), }( op, @@ -61,7 +60,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} traits, p, t, - accepted_kwargs, + kwargs, cache, ) end @@ -107,7 +106,7 @@ function FunctionOperator(op, p=nothing, t::Union{Number,Nothing}=nothing, - accepted_kwargs = (), + accepted_kwargs::NTuple{N,Symbol} = (), ifcache::Bool = true, @@ -118,13 +117,14 @@ function FunctionOperator(op, issymmetric::Bool = false, ishermitian::Bool = false, isposdef::Bool = false, - ) + ) where{N} # store eltype of input/output for caching with ComposedOperator. eltypes = eltype.((input, output)) sz = (size(output, 1), size(input, 1)) T = isnothing(T) ? promote_type(eltypes...) : T t = isnothing(t) ? zero(real(T)) : t + kwargs = Dict{Symbol, Any}() isinplace = if isnothing(isinplace) static_hasmethod(op, typeof((output, input, p, t))) @@ -188,6 +188,7 @@ function FunctionOperator(op, T = T, size = sz, eltypes = eltypes, + accepted_kwargs = accepted_kwargs, ) L = FunctionOperator( @@ -198,7 +199,7 @@ function FunctionOperator(op, traits, p, t, - normalize_kwargs(accepted_kwargs), + kwargs, cache ) @@ -214,9 +215,10 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) @set! L.p = p @set! L.t = t - isconstant(L) && return L + filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) + @set! L.kwargs = Dict(filtered_kwargs) - filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L.kwargs if haskey(kwargs, kwarg)) + isconstant(L) && return L @set! L.op = update_coefficients(L.op, u, p, t; filtered_kwargs...) @set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...) @@ -226,17 +228,18 @@ end function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) 
- isconstant(L) && return + L.p = p + L.t = t + + filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) + L.kwargs = Dict(filtered_kwargs) - filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L.kwargs if haskey(kwargs, kwarg)) + isconstant(L) && return for op in getops(L) update_coefficients!(op, u, p, t; filtered_kwargs...) end - L.p = p - L.t = t - L end @@ -383,8 +386,13 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin # TODO - FunctionOperator, Base.conj, transpose # operator application -Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t) -Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t; L.kwargs...) +function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} + L.op(u, L.p, L.t; L.kwargs...) +end + +function Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} + L.op_inverse(u, L.p, L.t; L.kwargs...) +end function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat) _, co = L.cache diff --git a/src/utils.jl b/src/utils.jl index ac934ca7..92f66552 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,13 +18,16 @@ struct FilterKwargs{F,K} f::F accepted_kwargs::K end -function (f_filter::FilterKwargs)(args...; kwargs...) - # Filter keyword arguments to those accepted by function. - # Avoid throwing errors here if a keyword argument is not provided: defer this to the function call for a more readable error. - filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwargs if haskey(kwargs, kwarg)) - f_filter.f(args...; filtered_kwargs...) + +# Filter keyword arguments to those accepted by function. +# Avoid throwing errors here if a keyword argument is not provided: defer +# this to the function call for a more readable error. 
+function get_filtered_kwargs(kwargs::Base.Pairs, accepted_kwargs::NTuple{N,Symbol}) where{N} + (kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw)) +end + +function (f::FilterKwargs)(args...; kwargs...) + filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs) + f.f(args...; filtered_kwargs...) end -# automatically convert NamedTuple's, etc. to a normalized kwargs representation (i.e. Base.Pairs) -normalize_kwargs(; kwargs...) = kwargs -normalize_kwargs(kwargs) = normalize_kwargs(; kwargs...) # diff --git a/test/func.jl b/test/func.jl index ba7572fa..fc56137b 100644 --- a/test/func.jl +++ b/test/func.jl @@ -107,10 +107,11 @@ end scale = rand() # Accept a kwarg "scale" in operator action - f(du,u,p,t; scale) = mul!(du, Diagonal(p*t*scale), u) - f(u, p, t; scale) = Diagonal(p * t * scale) * u + f(du,u,p,t; scale = 1.0) = mul!(du, Diagonal(p*t*scale), u) + f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u - L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), accepted_kwargs=(;scale=zero(scale))) + L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), + accepted_kwargs = (:scale,)) ans = @. u * p * t * scale @test L(u,p,t; scale) ≈ ans diff --git a/test/scalar.jl b/test/scalar.jl index 1c4747c0..da3abaf3 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -151,7 +151,7 @@ end # Test scalar operator which expects keyword argument to update, # modeled in the style of a DiffEq W-operator. 
γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma, - accepted_kwargs=(:dtgamma,)) + accepted_kwargs = (:dtgamma,)) dtgamma = rand() @test γ(u,p,t; dtgamma) ≈ dtgamma * u diff --git a/test/total.jl b/test/total.jl index fb1a2c04..6dd2f645 100644 --- a/test/total.jl +++ b/test/total.jl @@ -86,7 +86,7 @@ end C = rand(N,N) # Introduce update function for D dependent on kwarg "matrix" D = MatrixOperator(zeros(N,N); update_func=(A, u, p, t; matrix) -> (A .= p*t*matrix), - accepted_kwargs=(:matrix,)) + accepted_kwargs = (:matrix,)) u = rand(N2,K) p = rand() @@ -99,11 +99,9 @@ end T1 = ⊗(A, B) T2 = ⊗(C, D) - # Introduce update function for D1 - D1 = DiagonalOperator(zeros(N2); update_func=(d, u, p, t) -> (d .= p)) - # Introduce update funcion for D2 dependent on kwarg "diag" - D2 = DiagonalOperator(zeros(N2); update_func=(d, u, p, t; diag) -> (d .= p*t*diag), - accepted_kwargs=(:diag,)) + D1 = DiagonalOperator(zeros(N2); update_func = (d, u, p, t) -> p) + D2 = DiagonalOperator(zeros(N2); update_func = (d, u, p, t; diag) -> p*t*diag, + accepted_kwargs = (:diag,)) TT = [T1, T2] DD = Diagonal([D1, D2]) From 4818a66052e472318fc9d99a861b722f5c0d4723 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 29 May 2023 13:43:12 -0400 Subject: [PATCH 21/23] moved kwargs to FunctionOp.traits --- src/func.jl | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/src/func.jl b/src/func.jl index 7ca6e1c9..58145224 100644 --- a/src/func.jl +++ b/src/func.jl @@ -17,8 +17,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: p::P """ Time """ t::Tt - """ kwargs """ - kwargs::Dict{Symbol,Any} # TODO move inside traits later """ Cache """ cache::C @@ -30,7 +28,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: traits, p, t, - kwargs, cache ) @@ -60,7 +57,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: traits, p, t, - kwargs, 
cache, ) end @@ -124,7 +120,6 @@ function FunctionOperator(op, sz = (size(output, 1), size(input, 1)) T = isnothing(T) ? promote_type(eltypes...) : T t = isnothing(t) ? zero(real(T)) : t - kwargs = Dict{Symbol, Any}() isinplace = if isnothing(isinplace) static_hasmethod(op, typeof((output, input, p, t))) @@ -189,6 +184,7 @@ function FunctionOperator(op, size = sz, eltypes = eltypes, accepted_kwargs = accepted_kwargs, + kwargs = Dict{Symbol, Any}(), ) L = FunctionOperator( @@ -199,7 +195,6 @@ function FunctionOperator(op, traits, p, t, - kwargs, cache ) @@ -212,11 +207,13 @@ end function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) + # update p, t @set! L.p = p @set! L.t = t + # filter and update kwargs filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) - @set! L.kwargs = Dict(filtered_kwargs) + @set! L.traits.kwargs = Dict{Symbol, Any}(filtered_kwargs) isconstant(L) && return L @@ -228,11 +225,13 @@ end function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) + # update p, t L.p = p L.t = t + # filter and update kwargs filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) - L.kwargs = Dict(filtered_kwargs) + L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs)) isconstant(L) && return @@ -289,7 +288,6 @@ function Base.adjoint(L::FunctionOperator) traits, L.p, L.t, - L.kwargs, cache, ) end @@ -330,7 +328,6 @@ function Base.inv(L::FunctionOperator) traits, L.p, L.t, - L.kwargs, cache, ) end @@ -387,27 +384,27 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin # operator application function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} - L.op(u, L.p, L.t; L.kwargs...) + L.op(u, L.p, L.t; L.traits.kwargs...) end function Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} - L.op_inverse(u, L.p, L.t; L.kwargs...) + L.op_inverse(u, L.p, L.t; L.traits.kwargs...) 
end function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat) _, co = L.cache du = zero(co) - L.op(du, u, L.p, L.t; L.kwargs...) + L.op(du, u, L.p, L.t; L.traits.kwargs...) end function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat) ci, _ = L.cache du = zero(ci) - L.op_inverse(du, u, L.p, L.t; L.kwargs...) + L.op_inverse(du, u, L.p, L.t; L.traits.kwargs...) end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat) - L.op(v, u, L.p, L.t; L.kwargs...) + L.op(v, u, L.p, L.t; L.traits.kwargs...) end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...) @@ -424,11 +421,11 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, end function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop} - L.op(v, u, L.p, L.t, α, β; L.kwargs...) + L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...) end function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat) - L.op_inverse(v, u, L.p, L.t; L.kwargs...) + L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...) end function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat) From 3854874316d2baae7030a7173e6dcb76c2efbcb2 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 29 May 2023 13:47:49 -0400 Subject: [PATCH 22/23] tests passing --- test/total.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/total.jl b/test/total.jl index 6dd2f645..5a1e9004 100644 --- a/test/total.jl +++ b/test/total.jl @@ -82,10 +82,10 @@ end A = rand(N,N) # Introduce update function for B - B = MatrixOperator(zeros(N,N); update_func=(A, u, p, t) -> (A .= p)) + B = MatrixOperator(zeros(N,N); update_func! 
= (A, u, p, t) -> (A .= p)) C = rand(N,N) # Introduce update function for D dependent on kwarg "matrix" - D = MatrixOperator(zeros(N,N); update_func=(A, u, p, t; matrix) -> (A .= p*t*matrix), + D = MatrixOperator(zeros(N,N); update_func! = (A, u, p, t; matrix) -> (A .= p*t*matrix), accepted_kwargs = (:matrix,)) u = rand(N2,K) @@ -99,8 +99,8 @@ end T1 = ⊗(A, B) T2 = ⊗(C, D) - D1 = DiagonalOperator(zeros(N2); update_func = (d, u, p, t) -> p) - D2 = DiagonalOperator(zeros(N2); update_func = (d, u, p, t; diag) -> p*t*diag, + D1 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t) -> d .= p) + D2 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t; diag) -> d .= p*t*diag, accepted_kwargs = (:diag,)) TT = [T1, T2] From 3a665ca9232586c7a3b03be3740aeecc8bcb6806 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 29 May 2023 13:58:29 -0400 Subject: [PATCH 23/23] Base.Pairs notdef in LTS --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 92f66552..b08c18a5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -22,7 +22,7 @@ end # Filter keyword arguments to those accepted by function. # Avoid throwing errors here if a keyword argument is not provided: defer # this to the function call for a more readable error. -function get_filtered_kwargs(kwargs::Base.Pairs, accepted_kwargs::NTuple{N,Symbol}) where{N} +function get_filtered_kwargs(kwargs::AbstractDict, accepted_kwargs::NTuple{N,Symbol}) where{N} (kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw)) end