Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1ed74b6
Recursively propagate kwargs through update_coefficients!
gaurav-arya Jan 30, 2023
293d5eb
Rename accepted_kwarg_fields -> accepted_kwargs
gaurav-arya Mar 12, 2023
1712594
Allow accepted_kwargs=nothing to indicate no wrapping
gaurav-arya Mar 12, 2023
cd70503
Tweak keyword filtering logic
gaurav-arya Mar 12, 2023
17298fb
Test operator update (including kwarg update) in operator algebra test
gaurav-arya Mar 12, 2023
dc3e31d
Support kwargs in function operator
gaurav-arya Mar 12, 2023
5b47ef9
Propagate kwargs for out-of-place function operator update_coefficien…
gaurav-arya Mar 12, 2023
88d5050
Catch function operator error for empty kwargs
gaurav-arya Mar 12, 2023
c7fcd51
Address code review suggestions on diag op construction
gaurav-arya Mar 12, 2023
24aef00
Improve logic for normalizing kwargs
gaurav-arya Mar 12, 2023
44d66bb
Test operator application form in operator algebra test set
gaurav-arya Mar 12, 2023
9d8fdff
Support kwargs in function operator functionals
gaurav-arya Mar 12, 2023
f793c60
Add example
gaurav-arya Mar 12, 2023
3312967
Remove unncessary function call
gaurav-arya Mar 12, 2023
69f0ecc
Rename kwargs_for_op -> accepted_kwargs
gaurav-arya Mar 12, 2023
fef7618
Fix function operator out-of-place update coefficients
gaurav-arya Mar 12, 2023
ceca67c
Use NoKwargFilter() to bypass keyword filtering (rather than nothing)
gaurav-arya Mar 12, 2023
0efa2ce
Remove debug line
gaurav-arya Mar 12, 2023
2adb43f
Merge branch 'master' into ag-kwargs
vpuri3 May 27, 2023
e63e445
fix diagonaloperator update
vpuri3 May 27, 2023
e89d5a1
function op working
vpuri3 May 29, 2023
4818a66
moved kwargs to FunctionOp.traits
vpuri3 May 29, 2023
3854874
tests passing
vpuri3 May 29, 2023
3a665ca
Base.Pairs notdef in LTS
vpuri3 May 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,10 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth
affine as well, and all sorts of things. Thus affine operators have no matrix representation but they
are still compatible with essentially any Krylov method which would otherwise be compatible with
matrix-free representations, hence their support in the SciMLOperators interface.

## Note about keyword arguments to `update_coefficients!`

In rare cases, an operator may be used in a context where additional state is expected to be provided
to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional
state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator.
For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument that defaults to an empty tuple.
22 changes: 12 additions & 10 deletions src/batch.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
#
"""
BatchedDiagonalOperator(diag, [; update_func])
BatchedDiagonalOperator(diag; update_func=nothing, accepted_kwargs=())

Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
by `update_coefficients!` and is assumed to have the following signature:

update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
update_func(diag::AbstractVector,u,p,t; <accepted kwarg fields>) -> [modifies diag]
"""
struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T}
diag::D
update_func::F

function BatchedDiagonalOperator(
diag::AbstractArray;
update_func=DEFAULT_UPDATE_FUNC
update_func=nothing,
accepted_kwargs=()
)
_update_func = preprocess_update_func(update_func, accepted_kwargs)
new{
eltype(diag),
typeof(diag),
typeof(update_func)
typeof(_update_func)
}(
diag, update_func,
diag, _update_func,
)
end
end

function DiagonalOperator(u::AbstractArray; update_func=DEFAULT_UPDATE_FUNC)
BatchedDiagonalOperator(u; update_func=update_func)
function DiagonalOperator(u::AbstractArray; update_func=nothing, accepted_kwargs=())
BatchedDiagonalOperator(u; update_func, accepted_kwargs)
end

# traits
Expand All @@ -40,7 +42,7 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
update_func = if isreal(L)
L.update_func
else
(L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t))
(L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...))
end
BatchedDiagonalOperator(diag; update_func=update_func)
end
Expand All @@ -57,15 +59,15 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
end
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))

isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC
isconstant(L::BatchedDiagonalOperator) = update_func_isconstant(L.update_func)
islinear(::BatchedDiagonalOperator) = true
has_adjoint(L::BatchedDiagonalOperator) = true
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)

getops(L::BatchedDiagonalOperator) = (L.diag,)

update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func(L.diag,u,p,t); nothing)
update_coefficients!(L::BatchedDiagonalOperator,u,p,t; kwargs...) = (L.update_func(L.diag,u,p,t; kwargs...); nothing)

# operator application
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
Expand Down
38 changes: 24 additions & 14 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
Matrix free operators (given by a function)
"""
mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: AbstractSciMLOperator{T}
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
op::F
""" Adjoint operator"""
Expand All @@ -17,6 +17,8 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
p::P
""" Time """
t::Tt
""" Keyword arguments """
kwargs::K
""" Cache """
cache::C

Expand All @@ -28,6 +30,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
traits,
p,
t,
kwargs_for_op,
cache
)

Expand All @@ -46,6 +49,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
typeof(traits),
typeof(p),
typeof(t),
typeof(kwargs_for_op),
typeof(cache),
}(
op,
Expand All @@ -55,6 +59,7 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
traits,
p,
t,
kwargs_for_op,
cache,
)
end
Expand Down Expand Up @@ -82,6 +87,7 @@ function FunctionOperator(op,
FunctionOperator(op, input, output; kwargs...)
end

# TODO: document constructor and revisit design as needed (e.g. for "kwargs_for_op")
function FunctionOperator(op,
input::AbstractVecOrMat,
output::AbstractVecOrMat = input;
Expand All @@ -96,6 +102,7 @@ function FunctionOperator(op,

p=nothing,
t::Union{Number,Nothing}=nothing,
kwargs_for_op=(),

ifcache::Bool = true,

Expand Down Expand Up @@ -169,17 +176,18 @@ function FunctionOperator(op,
traits,
p,
t,
cache,
# automatically convert NamedTuple's to pairs
pairs(kwargs_for_op),
cache
)

ifcache ? cache_operator(L, input, output) : L
end

function update_coefficients(L::FunctionOperator, u, p, t)
op = update_coefficients(L.op, u, p, t)
op_adjoint = update_coefficients(L.op_adjoint, u, p, t)
op_inverse = update_coefficients(L.op_inverse, u, p, t)
op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t)
function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
for op in getops(L)
op = update_coefficients(op, u, p, t; kwargs...)
end

FunctionOperator(op,
op_adjoint,
Expand All @@ -188,18 +196,20 @@ function update_coefficients(L::FunctionOperator, u, p, t)
L.traits,
p,
t,
kwargs_for_op=kwargs,
L.cache
)
end

function update_coefficients!(L::FunctionOperator, u, p, t)
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
ops = getops(L)
for op in ops
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; kwargs...)
end

L.p = p
L.t = t
L.kwargs = kwargs

nothing
end
Expand Down Expand Up @@ -326,22 +336,22 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin

# operator application
Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t)
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t)
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t; L.kwargs...)

function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
_, co = L.cache
du = zero(co)
L.op(du, u, L.p, L.t)
L.op(du, u, L.p, L.t; L.kwargs...)
end

function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
ci, _ = L.cache
du = zero(ci)
L.op_inverse(du, u, L.p, L.t)
L.op_inverse(du, u, L.p, L.t; L.kwargs...)
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
L.op(v, u, L.p, L.t)
L.op(v, u, L.p, L.t; L.kwargs...)
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...)
Expand All @@ -358,7 +368,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::A
end

function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
L.op_inverse(v, u, L.p, L.t)
L.op_inverse(v, u, L.p, L.t; L.kwargs...)
end

function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat)
Expand Down
26 changes: 20 additions & 6 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,33 @@ out-of-place form B = update_coefficients(A,u,p,t).
"""
function (::AbstractSciMLOperator) end

# Utilities for update functions
DEFAULT_UPDATE_FUNC(A,u,p,t) = A
function preprocess_update_func(update_func, accepted_kwargs)
update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func
# accepted_kwargs can be passed as nothing to indicate that we should not filter
# (e.g. if the function already accepts all kwargs...).
return (accepted_kwargs === nothing) ? update_func : FilterKwargs(update_func, accepted_kwargs)
end
function update_func_isconstant(update_func)
if update_func isa FilterKwargs
return update_func.f == DEFAULT_UPDATE_FUNC
else
return update_func == DEFAULT_UPDATE_FUNC
end
end

update_coefficients!(L,u,p,t) = nothing
update_coefficients(L,u,p,t) = L
function update_coefficients!(L::AbstractSciMLOperator, u, p, t)
update_coefficients!(L,u,p,t; kwargs...) = nothing
update_coefficients(L,u,p,t; kwargs...) = L
function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...)
for op in getops(L)
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; kwargs...)
end
nothing
end

(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)
(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u))
(L::AbstractSciMLOperator)(u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); L * u)
(L::AbstractSciMLOperator)(du, u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u))

###
# caching interface
Expand Down
Loading