diff --git a/src/basic.jl b/src/basic.jl index cc93b92e..961f1d08 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -176,6 +176,7 @@ end ScaledOperator (λ L)*(u) = λ * L(u) + """ struct ScaledOperator{T, λType, @@ -538,15 +539,28 @@ end #Base.:*(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u) #Base.:\(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op \ acc, L.ops; init=u) -function Base.:\(L::ComposedOperator, u::AbstractVecOrMat) +function (L::ComposedOperator)(u, p, t) v = u - for op in L.ops - v = op \ v + for op in reverse(L.ops) + update_coefficients!(op, v, p, t) + v = op * v end v end +function (L::ComposedOperator)(v, u, p, t) + @assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)). + set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)" + + vecs = (v, L.cache[1:end-1]..., u) + for i in reverse(1:length(L.ops)) + update_coefficients!(L.ops[i], vecs[i+1], p, t) + mul!(vecs[i], L.ops[i], vecs[i+1]) + end + v +end + function Base.:*(L::ComposedOperator, u::AbstractVecOrMat) v = u for op in reverse(L.ops) @@ -556,6 +570,15 @@ function Base.:*(L::ComposedOperator, u::AbstractVecOrMat) v end +function Base.:\(L::ComposedOperator, u::AbstractVecOrMat) + v = u + for op in L.ops + v = op \ v + end + + v +end + function cache_self(L::ComposedOperator, u::AbstractVecOrMat) if has_mul(L) vec = zero(u) diff --git a/test/basic.jl b/test/basic.jl index 5b29eb53..7b0d1052 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -235,6 +235,60 @@ end @test ldiv!(rand(N), op, u) ≈ op \ u end +@testset "ComposedOperator nonlinear operator composition test" begin + u = rand(N) + p = nothing + t = 0.0 + + square(u) = u .^ 2 + square(u, p, t) = u .^ 2 + square(v, u, p, t) = v .= u .* u + + root(u) = u .^ 2 + root(u, p, t) = u .^ 2 + root(v, u, p, t) = v .= u .* u + + F = FunctionOperator(square, u; islinear = false, op_inverse = root) + + A = DiagonalOperator(zeros(N); update_func = (d, u, p, t) -> copy!(d, u)) # u .^2 + B = DiagonalOperator(zeros(N); update_func = (d, u, p, t) -> copy!(d, u)) + C = DiagonalOperator(zeros(N); update_func = (d, u, p, t) -> copy!(d, u)) + + L = A ∘ B ∘ C + F3 = F ∘ F ∘ F + + sq = u |> square |> square |> square + + @test A(B(C(u, p, t), p, t), p, t) ≈ sq + @test L(u, p, t) ≈ sq + @test F3(u, p, t) ≈ sq + + L = cache_operator(L, u) + v = rand(N); @test L(v, u, p, t) ≈ sq + + Fi = inv(F) + F3i = inv(F3) + + rt = u |> root |> root |> root + @test F3i(u, p, t) ≈ rt + + Ai = inv(A) + Bi = inv(B) + Ci = inv(C) + + Li = inv(L) + Fi = inv(F) + for op in (Ai, Bi, Ci, Li) + @test op isa SciMLOperators.InvertedOperator + end + + rt = Ai(Bi(Ci(u, p, t), p, t), p, t) + @test Ai(u, p, t) ≈ ones(N) + # TODO - overwrite L(u, p, t) for InvertedOperator + @test_broken Li(u, p, t) ≈ ones(N) + v = rand(N); @test_broken Li(v, u, p, t) ≈ ones(N) +end + @testset "Adjoint, Transpose" begin for (op,