diff --git a/lib/ModelingToolkitBase/src/systems/callbacks.jl b/lib/ModelingToolkitBase/src/systems/callbacks.jl index e70a62af62..3b0eee36a1 100644 --- a/lib/ModelingToolkitBase/src/systems/callbacks.jl +++ b/lib/ModelingToolkitBase/src/systems/callbacks.jl @@ -217,34 +217,29 @@ unPre(x::Symbolics.Arr) = unPre(unwrap(x)) unPre(x::SymbolicT) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x distribute_shift_into_operator(::Pre) = false -function (p::Pre)(x) - iw = Symbolics.iswrapped(x) - x = unwrap(x) - # non-symbolic values don't change - SU.isconst(x) && return x - if symbolic_type(x) == NotSymbolic() - return x - end - # differential variables are default-toterm-ed - if iscall(x) && operation(x) isa Differential - x = default_toterm(x) - end - # don't double wrap - iscall(x) && operation(x) isa Pre && return x - result = if iscall(x) && operation(x) === getindex - # instead of `Pre(x[1])` create `Pre(x)[1]` - # which allows parameter indexing to handle this case automatically. - arr = arguments(x)[1] - p(arr)[arguments(x)[2:end]...] - else - term(p, x; type = symtype(x), shape = SU.shape(x)) - end - # the result should be a parameter - result = toparam(result) - if iw - result = wrap(result) +(p::Pre)(x::Num) = Num(p(unwrap(x))) +(p::Pre)(x::Symbolics.Arr{T, N}) where {T, N} = Symbolics.Arr{T, N}(p(unwrap(x))) +function (p::Pre)(x::SymbolicT) + iscall(x) || return x + return Moshi.Match.@match x begin + BSImpl.Term(; f) && if f isa Pre end => return x + BSImpl.Term(; f) && if f isa Differential end => begin + return p(default_toterm(x)) + end + BSImpl.Term(; f, args, type, shape) && if f === getindex end => begin + newargs = copy(parent(args)) + newargs[1] = p(args[1]) + return toparam(BSImpl.Term{VartypeT}(f, newargs; type, shape)) + end + BSImpl.Term(; f, type, shape) && if f isa SymbolicT && !SU.is_function_symbolic(f) end => begin + return toparam(BSImpl.Term{VartypeT}(p, SArgsT((x,)); type, shape)) + end + _ => begin + op = operation(x) + args = map(p, arguments(x)) + return toparam(maketerm(SymbolicT, op, args, nothing; type = symtype(x))) + end end - return result end haspre(eq::Equation) = haspre(eq.lhs) || haspre(eq.rhs) haspre(O) = recursive_hasoperator(Pre, O) diff --git a/lib/ModelingToolkitBase/test/symbolic_events.jl b/lib/ModelingToolkitBase/test/symbolic_events.jl index 940bc183bb..124263d4d1 100644 --- a/lib/ModelingToolkitBase/test/symbolic_events.jl +++ b/lib/ModelingToolkitBase/test/symbolic_events.jl @@ -1824,4 +1824,11 @@ if !@isdefined(ModelingToolkit) @mtkcompile sys = System(eqs, t, [X], [p, Kᵢ, Kₐ, K]; discrete_events) @test_nowarn ODEProblem(sys, [X => 1, p => 1, Kᵢ => 1, Kₐ => 2], (0.0, 1.0)) end + + @testset "Issue:4095: `Pre` recurses into expressions" begin + @variables x(t) + @parameters p (f::Function)(..) + @discretes d(t) + @test isequal(Pre(2x^2 + 3sin(f(x)) - ifelse(p < 0, d, d + 2) + 2p), 2Pre(x)^2 + 3sin(f(Pre(x))) - ifelse(p < 0, Pre(d), Pre(d) + 2) + 2p) + end end