diff --git a/Project.toml b/Project.toml index e6fd39d9e4..b412c502b6 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" @@ -81,6 +82,7 @@ DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6, 0.7" DynamicQuantities = "^0.11.2, 0.12, 0.13" ExprTools = "0.1.10" +Expronicon = "0.8" FindFirstFunctions = "1" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" @@ -98,10 +100,10 @@ NonlinearSolve = "3.12" OrderedCollections = "1" OrdinaryDiffEq = "6.82.0" PrecompileTools = "1" -RecursiveArrayTools = "2.3, 3" +RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" -SciMLBase = "2.28.0" +SciMLBase = "2.46" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" @@ -109,7 +111,7 @@ SimpleNonlinearSolve = "0.1.0, 1" SparseArrays = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" -SymbolicIndexingInterface = "0.3.12" +SymbolicIndexingInterface = "0.3.26" SymbolicUtils = "2.1" Symbolics = "5.32" URIs = "1" diff --git a/docs/src/tutorials/SampledData.md b/docs/src/tutorials/SampledData.md index 614e8b65c7..a72fd1698b 100644 --- a/docs/src/tutorials/SampledData.md +++ b/docs/src/tutorials/SampledData.md @@ -16,7 +16,7 @@ A clock can be seen as an *event source*, i.e., when the clock ticks, an event i - [`Hold`](@ref) - [`ShiftIndex`](@ref) -When a continuous-time variable `x` is sampled using `xd = Sample(x, dt)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc. +When a continuous-time variable `x` is sampled using `xd = Sample(dt)(x)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc. To make a discrete-time variable available to the continuous partition, the [`Hold`](@ref) operator is used. `xc = Hold(xd)` creates a continuous-time variable `xc` that is updated whenever the clock associated with `xd` ticks, and holds its value constant between ticks. @@ -34,7 +34,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t @variables x(t) y(t) u(t) dt = 0.1 # Sample interval -clock = Clock(t, dt) # A periodic clock with tick rate dt +clock = Clock(dt) # A periodic clock with tick rate dt k = ShiftIndex(clock) eqs = [ @@ -99,7 +99,7 @@ may thus be modeled as ```julia t = ModelingToolkit.t_nounits @variables y(t) [description = "Output"] u(t) [description = "Input"] -k = ShiftIndex(Clock(t, dt)) +k = ShiftIndex(Clock(dt)) eqs = [ a2 * y(k) + a1 * y(k - 1) + a0 * y(k - 2) ~ b2 * u(k) + b1 * u(k - 1) + b0 * u(k - 2) ] @@ -128,10 +128,10 @@ requires specification of the initial condition for both `x(k-1)` and `x(k-2)`. Multi-rate systems are easy to model using multiple different clocks. The following set of equations is valid, and defines *two different discrete-time partitions*, each with its own clock: ```julia -yd1 ~ Sample(t, dt1)(y) -ud1 ~ kp * (Sample(t, dt1)(r) - yd1) -yd2 ~ Sample(t, dt2)(y) -ud2 ~ kp * (Sample(t, dt2)(r) - yd2) +yd1 ~ Sample(dt1)(y) +ud1 ~ kp * (Sample(dt1)(r) - yd1) +yd2 ~ Sample(dt2)(y) +ud2 ~ kp * (Sample(dt2)(r) - yd2) ``` `yd1` and `ud1` belong to the same clock which ticks with an interval of `dt1`, while `yd2` and `ud2` belong to a different clock which ticks with an interval of `dt2`. The two clocks are *not synchronized*, i.e., they are not *guaranteed* to tick at the same point in time, even if one tick interval is a rational multiple of the other. Mechanisms for synchronization of clocks are not yet implemented. @@ -148,7 +148,7 @@ using ModelingToolkit: t_nounits as t using ModelingToolkit: D_nounits as D dt = 0.5 # Sample interval @variables r(t) -clock = Clock(t, dt) +clock = Clock(dt) k = ShiftIndex(clock) function plant(; name) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index e70991ad3e..211c184130 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -43,7 +43,8 @@ using SciMLStructures using Compat using AbstractTrees using DiffEqBase, SciMLBase, ForwardDiff -using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap +using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain, + PeriodicClock, Clock, SolverStepClock, Continuous using Distributed import JuliaFormatter using MLStyle @@ -272,6 +273,6 @@ export debug_system #export has_discrete_domain, has_continuous_domain #export is_discrete_domain, is_continuous_domain, is_hybrid_domain export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime -export Clock #, InferredDiscrete, +export Clock, SolverStepClock, TimeDomain end # module diff --git a/src/clock.jl b/src/clock.jl index 5df6cfb022..26ea5832da 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,13 +1,26 @@ -abstract type TimeDomain end -abstract type AbstractDiscrete <: TimeDomain end +module InferredClock -Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) +export InferredTimeDomain -struct Inferred <: TimeDomain end -struct InferredDiscrete <: AbstractDiscrete end -struct Continuous <: TimeDomain end +using Expronicon.ADT: @adt, @match +using SciMLBase: TimeDomain -Symbolics.option_to_metadata_type(::Val{:timedomain}) = TimeDomain +@adt InferredTimeDomain begin + Inferred + InferredDiscrete +end + +Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x) + +end + +using .InferredClock + +struct VariableTimeDomain end +Symbolics.option_to_metadata_type(::Val{:timedomain}) = VariableTimeDomain + +is_concrete_time_domain(::TimeDomain) = true +is_concrete_time_domain(_) = false """ is_continuous_domain(x) @@ -16,7 +29,7 @@ true if `x` contains only continuous-domain signals. See also [`has_continuous_domain`](@ref) """ function is_continuous_domain(x) - issym(x) && return getmetadata(x, TimeDomain, false) isa Continuous + issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous !has_discrete_domain(x) && has_continuous_domain(x) end @@ -24,7 +37,7 @@ function get_time_domain(x) if iscall(x) && operation(x) isa Operator output_timedomain(x) else - getmetadata(x, TimeDomain, nothing) + getmetadata(x, VariableTimeDomain, nothing) end end get_time_domain(x::Num) = get_time_domain(value(x)) @@ -37,14 +50,14 @@ Determine if variable `x` has a time-domain attributed to it. function has_time_domain(x::Symbolic) # getmetadata(x, Continuous, nothing) !== nothing || # getmetadata(x, Discrete, nothing) !== nothing - getmetadata(x, TimeDomain, nothing) !== nothing + getmetadata(x, VariableTimeDomain, nothing) !== nothing end has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = Continuous() - @eval output_timedomain(::$op, arg = nothing) = Continuous() + @eval input_timedomain(::$op, arg = nothing) = Continuous + @eval output_timedomain(::$op, arg = nothing) = Continuous end """ @@ -83,12 +96,17 @@ true if `x` contains only discrete-domain signals. See also [`has_discrete_domain`](@ref) """ function is_discrete_domain(x) - if hasmetadata(x, TimeDomain) || issym(x) - return getmetadata(x, TimeDomain, false) isa AbstractDiscrete + if hasmetadata(x, VariableTimeDomain) || issym(x) + return is_discrete_time_domain(getmetadata(x, VariableTimeDomain, false)) end !has_discrete_domain(x) && has_continuous_domain(x) end +sampletime(c) = @match c begin + PeriodicClock(dt, _...) => dt + _ => nothing +end + struct ClockInferenceException <: Exception msg::Any end @@ -97,57 +115,4 @@ function Base.showerror(io::IO, cie::ClockInferenceException) print(io, "ClockInferenceException: ", cie.msg) end -abstract type AbstractClock <: AbstractDiscrete end - -""" - Clock <: AbstractClock - Clock([t]; dt) - -The default periodic clock with independent variables `t` and tick interval `dt`. -If `dt` is left unspecified, it will be inferred (if possible). -""" -struct Clock <: AbstractClock - "Independent variable" - t::Union{Nothing, Symbolic} - "Period" - dt::Union{Nothing, Float64} - Clock(t::Union{Num, Symbolic}, dt = nothing) = new(value(t), dt) - Clock(t::Nothing, dt = nothing) = new(t, dt) -end -Clock(dt::Real) = Clock(nothing, dt) -Clock() = Clock(nothing, nothing) - -sampletime(c) = isdefined(c, :dt) ? c.dt : nothing -Base.hash(c::Clock, seed::UInt) = hash(c.dt, seed ⊻ 0x953d7a9a18874b90) -function Base.:(==)(c1::Clock, c2::Clock) - ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) && c1.dt == c2.dt -end - -is_concrete_time_domain(x) = x isa Union{AbstractClock, Continuous} - -""" - SolverStepClock <: AbstractClock - SolverStepClock() - SolverStepClock(t) - -A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). This clock **does generally not have equidistant tick intervals**, instead, the tick interval depends on the adaptive step-size selection of the continuous solver, as well as any continuous event handling. If adaptivity of the solver is turned off and there are no continuous events, the tick interval will be given by the fixed solver time step `dt`. - -Due to possibly non-equidistant tick intervals, this clock should typically not be used with discrete-time systems that assume a fixed sample time, such as PID controllers and digital filters. -""" -struct SolverStepClock <: AbstractClock - "Independent variable" - t::Union{Nothing, Symbolic} - "Period" - SolverStepClock(t::Union{Num, Symbolic}) = new(value(t)) -end -SolverStepClock() = SolverStepClock(nothing) - -Base.hash(c::SolverStepClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91 -function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock) - ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) -end - -struct IntegerSequence <: AbstractClock - t::Union{Nothing, Symbolic} - IntegerSequence(t::Union{Num, Symbolic}) = new(value(t)) -end +struct IntegerSequence end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index cb723e159f..facb151d77 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -85,8 +85,8 @@ $(TYPEDEF) Represents a sample operator. A discrete-time signal is created by sampling a continuous-time signal. # Constructors -`Sample(clock::TimeDomain = InferredDiscrete())` -`Sample([t], dt::Real)` +`Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete)` +`Sample(dt::Real)` `Sample(x::Num)`, with a single argument, is shorthand for `Sample()(x)`. @@ -100,16 +100,23 @@ julia> using Symbolics julia> t = ModelingToolkit.t_nounits -julia> Δ = Sample(t, 0.01) +julia> Δ = Sample(0.01) (::Sample) (generic function with 2 methods) ``` """ struct Sample <: Operator clock::Any - Sample(clock::TimeDomain = InferredDiscrete()) = new(clock) - Sample(t, dt::Real) = new(Clock(t, dt)) + Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete) = new(clock) +end + +function Sample(arg::Real) + arg = unwrap(arg) + if symbolic_type(arg) == NotSymbolic() + Sample(Clock(arg)) + else + Sample()(arg) + end end -Sample(x) = Sample()(x) (D::Sample)(x) = Term{symtype(x)}(D, Any[x]) (D::Sample)(x::Num) = Num(D(value(x))) SymbolicUtils.promote_symtype(::Sample, x) = x @@ -176,15 +183,18 @@ julia> x(k) # no shift x(t) julia> x(k+1) # shift -Shift(t, 1)(x(t)) +Shift(1)(x(t)) ``` """ struct ShiftIndex - clock::TimeDomain + clock::Union{InferredTimeDomain, TimeDomain, IntegerSequence} steps::Int - ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps) - ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps) - ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps) + function ShiftIndex( + clock::Union{TimeDomain, InferredTimeDomain, IntegerSequence} = Inferred, steps::Int = 0) + new(clock, steps) + end + ShiftIndex(dt::Real, steps::Int = 0) = new(Clock(dt), steps) + ShiftIndex(::Num, steps::Int) = new(IntegerSequence(), steps) end function (xn::Num)(k::ShiftIndex) @@ -197,18 +207,13 @@ function (xn::Num)(k::ShiftIndex) args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t) length(args) == 1 || error("Cannot shift an expression with multiple independent variables $x.") - t = args[] - if hasfield(typeof(clock), :t) - isequal(t, clock.t) || - error("Independent variable of $xn is not the same as that of the ShiftIndex $(k.t)") - end # d, _ = propagate_time_domain(xn) # if d != clock # this is only required if the variable has another clock # xn = Sample(t, clock)(xn) # end # QUESTION: should we return a variable with time domain set to k.clock? - xn = setmetadata(xn, TimeDomain, k.clock) + xn = setmetadata(xn, VariableTimeDomain, k.clock) if steps == 0 return xn # x(k) needs no shift operator if the step of k is 0 end @@ -221,37 +226,37 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i) """ input_timedomain(op::Operator) -Return the time-domain type (`Continuous()` or `Discrete()`) that `op` operates on. +Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` operates on. """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + InferredDiscrete end """ output_timedomain(op::Operator) -Return the time-domain type (`Continuous()` or `Discrete()`) that `op` results in. +Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` results in. """ function output_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + InferredDiscrete end -input_timedomain(::Sample, arg = nothing) = Continuous() +input_timedomain(::Sample, arg = nothing) = Continuous output_timedomain(s::Sample, arg = nothing) = s.clock function input_timedomain(h::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() # the Hold accepts any discrete + InferredDiscrete # the Hold accepts any discrete end -output_timedomain(::Hold, arg = nothing) = Continuous() +output_timedomain(::Hold, arg = nothing) = Continuous sampletime(op::Sample, arg = nothing) = sampletime(op.clock) sampletime(op::ShiftIndex, arg = nothing) = sampletime(op.clock) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 97d5dfc970..11292752cc 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -446,8 +446,9 @@ end function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return is_parameter(ic, sym) || - iscall(sym) && operation(sym) === getindex && + return sym isa ParameterIndex || is_parameter(ic, sym) || + iscall(sym) && + operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end if unwrap(sym) isa Int @@ -461,21 +462,34 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return is_parameter(ic, sym) end - return any(isequal(sym), getname.(parameter_symbols(sys))) || + + named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)] + return any(isequal(sym), named_parameters) || count(NAMESPACE_SEPARATOR, string(sym)) == 1 && count(isequal(sym), - Symbol.(nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys)))) == - 1 + Symbol.(nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, named_parameters)) == 1 end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return if (idx = parameter_index(ic, sym)) !== nothing - idx + return if sym isa ParameterIndex + sym + elseif (idx = parameter_index(ic, sym)) !== nothing + if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 + return nothing + else + idx + end elseif iscall(sym) && operation(sym) === getindex && (idx = parameter_index(ic, first(arguments(sym)))) !== nothing - ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) + if idx.portion isa SciMLStructures.Discrete && + idx.idx[2] == idx.idx[3] == nothing + return nothing + else + ParameterIndex( + idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) + end else nothing end @@ -493,7 +507,13 @@ end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return parameter_index(ic, sym) + idx = parameter_index(ic, sym) + if idx === nothing || + idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 + return nothing + else + return idx + end end idx = findfirst(isequal(sym), getname.(parameter_symbols(sys))) if idx !== nothing @@ -506,6 +526,111 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym return nothing end +function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym) + is_time_dependent(sys) || return false + has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false + is_timeseries_parameter(ic, sym) +end + +function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym) + is_time_dependent(sys) || return nothing + has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing + timeseries_parameter_index(ic, sym) +end + +function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + rawobs = build_explicit_observed_function( + sys, sym; param_only = true, return_inplace = true) + if rawobs isa Tuple + if is_time_dependent(sys) + obsfn = let oop = rawobs[1], iip = rawobs[2] + f1a(p::MTKParameters, t) = oop(p..., t) + f1a(out, p::MTKParameters, t) = iip(out, p..., t) + end + else + obsfn = let oop = rawobs[1], iip = rawobs[2] + f1b(p::MTKParameters) = oop(p...) + f1b(out, p::MTKParameters) = iip(out, p...) + end + end + else + if is_time_dependent(sys) + obsfn = let rawobs = rawobs + f2a(p::MTKParameters, t) = rawobs(p..., t) + end + else + obsfn = let rawobs = rawobs + f2b(p::MTKParameters) = rawobs(p...) + end + end + end + else + obsfn = build_explicit_observed_function(sys, sym; param_only = true) + end + return obsfn +end + +function has_observed_with_lhs(sys, sym) + has_observed(sys) || return false + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + return any(isequal(sym), ic.observed_syms) + else + return any(isequal(sym), [eq.lhs for eq in observed(sys)]) + end +end + +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym) + if is_variable(sys, sym) || is_independent_variable(sys, sym) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, sym) + push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx) + end +end +# Need this to avoid ambiguity with the array case +for traitT in [ + ScalarSymbolic, + ArraySymbolic +] + @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) + allsyms = vars(sym; op = Symbolics.Operator) + for s in allsyms + s = unwrap(s) + if is_variable(sys, s) || is_independent_variable(sys, s) || + has_observed_with_lhs(sys, s) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, s) + push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + end + end + end +end +function _all_ts_idxs!(ts_idxs, ::ScalarSymbolic, sys, sym::Symbol) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + return _all_ts_idxs!(ts_idxs, sys, ic.symbol_to_variable[sym]) + elseif is_variable(sys, sym) || is_independent_variable(sys, sym) || + any(isequal(sym), [getname(eq.lhs) for eq in observed(sys)]) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, sym) + push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + end +end +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray) + for s in sym + _all_ts_idxs!(ts_idxs, sys, s) + end +end +_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, symbolic_type(sym), sys, sym) + +function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym) + if !is_time_dependent(sys) + return Set() + end + ts_idxs = Set() + _all_ts_idxs!(ts_idxs, sys, sym) + return ts_idxs +end + function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem) return full_parameters(sys) end @@ -523,7 +648,7 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSys end function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym) - return !is_variable(sys, sym) && !is_parameter(sys, sym) && + return !is_variable(sys, sym) && parameter_index(sys, sym) === nothing && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic() end @@ -562,6 +687,8 @@ function SymbolicIndexingInterface.observed( return let _fn = _fn fn2(u, p) = _fn(u, p) fn2(u, p::MTKParameters) = _fn(u, p...) + fn2(::Nothing, p) = _fn([], p) + fn2(::Nothing, p::MTKParameters) = _fn([], p...) fn2 end end diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index dc1d612d73..dfdef69034 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -8,8 +8,8 @@ end function ClockInference(ts::TransformationState) @unpack structure = ts @unpack graph = structure - eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)] - var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)] + eq_domain = TimeDomain[Continuous for _ in 1:nsrcs(graph)] + var_domain = TimeDomain[Continuous for _ in 1:ndsts(graph)] inferred = BitSet() for (i, v) in enumerate(get_fullvars(ts)) d = get_time_domain(v) @@ -151,7 +151,7 @@ function split_system(ci::ClockInference{S}) where {S} get!(clock_to_id, d) do cid = (cid_counter[] += 1) push!(id_to_clock, d) - if d isa Continuous + if d == Continuous continuous_id[] = cid end cid @@ -186,6 +186,13 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end + if continuous_id != 0 + tss[continuous_id], tss[end] = tss[end], tss[continuous_id] + inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] + id_to_clock[continuous_id], id_to_clock[end] = id_to_clock[end], + id_to_clock[continuous_id] + continuous_id = lastindex(tss) + end return tss, inputs, continuous_id, id_to_clock end @@ -196,19 +203,14 @@ function generate_discrete_affect( @static if VERSION < v"1.7" error("The `generate_discrete_affect` function requires at least Julia 1.7") end - use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing + has_index_cache(osys) && get_index_cache(osys) !== nothing || + error("Hybrid systems require `split = true`") out = Sym{Any}(:out) appended_parameters = full_parameters(syss[continuous_id]) offset = length(appended_parameters) - param_to_idx = if use_index_cache - Dict{Any, ParameterIndex}(p => parameter_index(osys, p) - for p in appended_parameters) - else - Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters)) - end + param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p) + for p in appended_parameters) affect_funs = [] - init_funs = [] - svs = [] clocks = TimeDomain[] for (i, (sys, input)) in enumerate(zip(syss, inputs)) i == continuous_id && continue @@ -224,11 +226,7 @@ function generate_discrete_affect( push!(fullvars, s) end needed_disc_to_cont_obs = [] - if use_index_cache - disc_to_cont_idxs = ParameterIndex[] - else - disc_to_cont_idxs = Int[] - end + disc_to_cont_idxs = ParameterIndex[] for v in inputs[continuous_id] _v = arguments(v)[1] if _v in fullvars @@ -248,7 +246,7 @@ function generate_discrete_affect( end append!(appended_parameters, input) cont_to_disc_obs = build_explicit_observed_function( - use_index_cache ? osys : syss[continuous_id], + osys, needed_cont_to_disc_obs, throw = false, expression = true, @@ -270,76 +268,20 @@ function generate_discrete_affect( ], [], let_block) |> toexpr - if use_index_cache - cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] - disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] - else - cont_to_disc_idxs = (offset + 1):(offset += ni) - input_offset = offset - disc_range = (offset + 1):(offset += ns) - end - save_vec = Expr(:ref, :Float64) - if use_index_cache - for unk in unknowns(sys) - idx = parameter_index(osys, unk) - push!(save_vec.args, :($(parameter_values)(p, $idx))) - end - else - for i in 1:ns - push!(save_vec.args, :(p[$(input_offset + i)])) - end - end + cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] + disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] + save_expr = :($(SciMLBase.save_discretes!)(integrator, $i)) empty_disc = isempty(disc_range) - disc_init = if use_index_cache - :(function (u, p, t) - c2d_obs = $cont_to_disc_obs - d2c_obs = $disc_to_cont_obs - result = c2d_obs(u, p..., t) - for (val, i) in zip(result, $cont_to_disc_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end - - disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range) - result = d2c_obs(disc_state, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - # prevent multiple updates to dependents - _set_parameter_unchecked!(p, val, i; update_dependent = false) - end - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) # to force recalculation of dependents - end) - else - :(function (u, p, t) - c2d_obs = $cont_to_disc_obs - d2c_obs = $disc_to_cont_obs - c2d_view = view(p, $cont_to_disc_idxs) - d2c_view = view(p, $disc_to_cont_idxs) - disc_unknowns = view(p, $disc_range) - copyto!(c2d_view, c2d_obs(u, p, t)) - copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)) - end) - end # @show disc_to_cont_idxs # @show cont_to_disc_idxs # @show disc_range - affect! = :(function (integrator, saved_values) + affect! = :(function (integrator) @unpack u, p, t = integrator c2d_obs = $cont_to_disc_obs d2c_obs = $disc_to_cont_obs - $( - if use_index_cache - :(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]) - else - quote - c2d_view = view(p, $cont_to_disc_idxs) - d2c_view = view(p, $disc_to_cont_idxs) - disc_unknowns = view(p, $disc_range) - end - end - ) # TODO: find a way to do this without allocating + disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range] disc = $disc # Write continuous into to discrete: handles `Sample` @@ -351,79 +293,43 @@ function generate_discrete_affect( # d2c comes last # @show t # @show "incoming", p - $( - if use_index_cache + result = c2d_obs(u, p..., t) + for (val, i) in zip(result, $cont_to_disc_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + $(if !empty_disc quote - result = c2d_obs(integrator.u, p..., t) - for (val, i) in zip(result, $cont_to_disc_idxs) + disc(disc_unknowns, u, p..., t) + for (val, i) in zip(disc_unknowns, $disc_range) $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end end - else - :(copyto!(c2d_view, c2d_obs(integrator.u, p, t))) - end - ) + end) # @show "after c2d", p - $( - if use_index_cache - quote - if !$empty_disc - disc(disc_unknowns, integrator.u, p..., t) - for (val, i) in zip(disc_unknowns, $disc_range) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end - end - end - else - :($empty_disc || disc(disc_unknowns, disc_unknowns, p, t)) - end - ) # @show "after state update", p - $( - if use_index_cache - quote - result = d2c_obs(disc_unknowns, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end - end - else - :(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))) + result = d2c_obs(disc_unknowns, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end - ) - push!(saved_values.t, t) - push!(saved_values.saveval, $save_vec) + $save_expr # @show "after d2c", p - $( - if use_index_cache - quote - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) - end - end - ) + discretes, repack, _ = $(SciMLStructures.canonicalize)( + $(SciMLStructures.Discrete()), p) + repack(discretes) end) - sv = SavedValues(Float64, Vector{Float64}) + push!(affect_funs, affect!) - push!(init_funs, disc_init) - push!(svs, sv) end if eval_expression affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs) - inits = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), init_funs) else affects = map(affect_funs) do a drop_expr(RuntimeGeneratedFunction( eval_module, eval_module, toexpr(LiteralExpr(a)))) end - inits = map(init_funs) do a - drop_expr(RuntimeGeneratedFunction( - eval_module, eval_module, toexpr(LiteralExpr(a)))) - end end defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs)) - return affects, inits, clocks, svs, appended_parameters, defaults + return affects, clocks, appended_parameters, defaults end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 11abadad5f..5f69266f7e 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -825,7 +825,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first if sys isa ODESystem && build_initializeprob && (((implicit_dae || !isempty(missingvars)) && - all(isequal(Continuous()), ci.var_domain) && + all(==(Continuous), ci.var_domain) && ModelingToolkit.get_tearing_state(sys) !== nothing) || !isempty(initialization_equations(sys))) && t !== nothing if eltype(u0map) <: Number @@ -1008,18 +1008,15 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) inits = [] if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - elseif clock isa SolverStepClock - affect = DiscreteSaveAffect(affect, sv) - DiscreteCallback(Returns(true), affect, + &SolverStepClock => DiscreteCallback(Returns(true), affect, initialize = (c, u, t, integrator) -> affect(integrator)) - else - error("$clock is not a supported clock type.") + _ => error("$clock is not a supported clock type.") end end if cbs === nothing @@ -1031,8 +1028,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = else cbs = CallbackSet(cbs, discrete_cbs...) end - else - svs = nothing end kwargs = filter_kwargs(kwargs) pt = something(get_metadata(sys), StandardODEProblem()) @@ -1041,17 +1036,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = if cbs !== nothing kwargs1 = merge(kwargs1, (callback = cbs,)) end - if svs !== nothing - kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) - end - prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) - if !isempty(inits) - for init in inits - # init(prob.u0, prob.p, tspan[1]) - end - end - prob + return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] @@ -1124,14 +1110,15 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], u0 = h(p, tspan[1]) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - else - error("$clock is not a supported clock type.") + &SolverStepClock => DiscreteCallback(Returns(true), affect, + initialize = (c, u, t, integrator) -> affect(integrator)) + _ => error("$clock is not a supported clock type.") end end if cbs === nothing @@ -1186,14 +1173,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], u0 = h(p, tspan[1]) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - else - error("$clock is not a supported clock type.") + &SolverStepClock => DiscreteCallback(Returns(true), affect, + initialize = (c, u, t, integrator) -> affect(integrator)) + _ => error("$clock is not a supported clock type.") end end if cbs === nothing diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 5d1bae95ec..e28f1ece3b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -387,6 +387,7 @@ function build_explicit_observed_function(sys, ts; drop_expr = drop_expr, ps = full_parameters(sys), return_inplace = false, + param_only = false, op = Operator, throw = true) if (isscalar = symbolic_type(ts) !== NotSymbolic()) @@ -399,7 +400,16 @@ function build_explicit_observed_function(sys, ts; ivs = independent_variables(sys) dep_vars = scalarize(setdiff(vars, ivs)) - obs = observed(sys) + obs = param_only ? Equation[] : observed(sys) + if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing + # each subsystem is topologically sorted independently. We can append the + # equations to override the `lhs ~ 0` equations in `observed(sys)` + syss, _, continuous_id, _... = dss + for (i, subsys) in enumerate(syss) + i == continuous_id && continue + append!(obs, observed(subsys)) + end + end cs = collect_constants(obs) if !isempty(cs) > 0 @@ -407,8 +417,9 @@ function build_explicit_observed_function(sys, ts; obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs) end - sts = Set(unknowns(sys)) - sts = union(sts, + sts = param_only ? Set() : Set(unknowns(sys)) + sts = param_only ? Set() : + union(sts, Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex)) observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs)) @@ -420,7 +431,8 @@ function build_explicit_observed_function(sys, ts; Set(arguments(p)[1] for p in param_set_ns if iscall(p) && operation(p) === getindex)) namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs) - namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys)) + namespaced_to_sts = param_only ? Dict() : + Dict(unknowns(sys, x) => x for x in unknowns(sys)) # FIXME: This is a rather rough estimate of dependencies. We assume # the expression depends on everything before the `maxidx`. @@ -485,11 +497,11 @@ function build_explicit_observed_function(sys, ts; end dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds) if inputs === nothing - args = [dvs, ps..., ivs...] + args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...] else inputs = unwrap.(inputs) ipts = DestructuredArgs(inputs, inbounds = !checkbounds) - args = [dvs, ipts, ps..., ivs...] + args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...] end pre = get_postprocess_fbody(sys) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 75c8a7e235..899bba4aa5 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -21,18 +21,19 @@ end ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false) -const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}} +const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} const UnknownIndexMap = Dict{ - Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}} + BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} struct IndexCache unknown_idx::UnknownIndexMap - discrete_idx::ParamIndexMap + discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}} tunable_idx::ParamIndexMap constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap - discrete_buffer_sizes::Vector{BufferTemplate} + observed_syms::Set{BasicSymbolic} + discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} dependent_buffer_sizes::Vector{BufferTemplate} @@ -48,17 +49,14 @@ function IndexCache(sys::AbstractSystem) let idx = 1 for sym in unks usym = unwrap(sym) + rsym = renamespace(sys, usym) sym_idx = if Symbolics.isarraysymbolic(sym) reshape(idx:(idx + length(sym) - 1), size(sym)) else idx end unk_idxs[usym] = sym_idx - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - name = getname(usym) - unk_idxs[name] = sym_idx - symbol_to_variable[name] = sym - end + unk_idxs[rsym] = sym_idx idx += length(sym) end for sym in unks @@ -71,22 +69,28 @@ function IndexCache(sys::AbstractSystem) if idxs == idxs[begin]:idxs[end] idxs = reshape(idxs[begin]:idxs[end], size(idxs)) end + rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs - if hasname(arrsym) - name = getname(arrsym) - unk_idxs[name] = idxs - symbol_to_variable[name] = arrsym - end + unk_idxs[rsym] = idxs end end + observed_syms = Set{Union{Symbol, BasicSymbolic}}() for eq in observed(sys) - if symbolic_type(eq.lhs) != NotSymbolic() && hasname(eq.lhs) - symbol_to_variable[getname(eq.lhs)] = eq.lhs + if symbolic_type(eq.lhs) != NotSymbolic() + sym = eq.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + push!(observed_syms, sym) + push!(observed_syms, ttsym) + push!(observed_syms, rsym) + push!(observed_syms, rttsym) end end - disc_buffers = Dict{Any, Set{BasicSymbolic}}() + disc_buffers = Dict{Int, Dict{Any, Set{BasicSymbolic}}}() + disc_clocks = Dict{Union{Symbol, BasicSymbolic}, Int}() tunable_buffers = Dict{Any, Set{BasicSymbolic}}() constant_buffers = Dict{Any, Set{BasicSymbolic}}() dependent_buffers = Dict{Any, Set{BasicSymbolic}}() @@ -99,27 +103,123 @@ function IndexCache(sys::AbstractSystem) push!(buf, sym) end + if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing + syss, inputs, continuous_id, _ = get_discrete_subsystems(sys) + + for (i, (inps, disc_sys)) in enumerate(zip(inputs, syss)) + i == continuous_id && continue + disc_buffers[i] = Dict{Any, Set{BasicSymbolic}}() + + for inp in inps + inp = unwrap(inp) + ttinp = default_toterm(inp) + rinp = renamespace(sys, inp) + rttinp = renamespace(sys, ttinp) + is_parameter(sys, inp) || + error("Discrete subsystem $i input $inp is not a parameter") + + disc_clocks[inp] = i + disc_clocks[ttinp] = i + disc_clocks[rinp] = i + disc_clocks[rttinp] = i + + insert_by_type!(disc_buffers[i], inp) + end + + for sym in unknowns(disc_sys) + sym = unwrap(sym) + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + is_parameter(sys, sym) || + error("Discrete subsystem $i unknown $sym is not a parameter") + + disc_clocks[sym] = i + disc_clocks[ttsym] = i + disc_clocks[rsym] = i + disc_clocks[rttsym] = i + + insert_by_type!(disc_buffers[i], sym) + end + t = get_iv(sys) + for eq in observed(disc_sys) + # TODO: Is this a valid check + # FIXME: This shouldn't be necessary + eq.rhs === -0.0 && continue + sym = eq.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + if iscall(sym) && operation(sym) == Shift(t, 1) + sym = only(arguments(sym)) + end + disc_clocks[sym] = i + disc_clocks[ttsym] = i + disc_clocks[rsym] = i + disc_clocks[rttsym] = i + end + end + + for par in inputs[continuous_id] + is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") + par = unwrap(par) + ttpar = default_toterm(par) + rpar = renamespace(sys, par) + rttpar = renamespace(sys, ttpar) + iscall(par) && operation(par) isa Hold || + error("Continuous subsystem input is not a Hold") + if haskey(disc_clocks, par) + sym = par + else + sym = first(arguments(par)) + end + haskey(disc_clocks, sym) || + error("Variable $par not part of a discrete subsystem") + disc_clocks[par] = disc_clocks[sym] + disc_clocks[ttpar] = disc_clocks[sym] + disc_clocks[rpar] = disc_clocks[sym] + disc_clocks[rttpar] = disc_clocks[sym] + insert_by_type!(disc_buffers[disc_clocks[sym]], par) + end + end + affs = vcat(affects(continuous_events(sys)), affects(discrete_events(sys))) + user_affect_clock = maximum(values(disc_clocks); init = 0) + 1 for affect in affs if affect isa Equation is_parameter(sys, affect.lhs) || continue - insert_by_type!(disc_buffers, affect.lhs) + sym = affect.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + + disc_clocks[sym] = user_affect_clock + disc_clocks[ttsym] = user_affect_clock + disc_clocks[rsym] = user_affect_clock + disc_clocks[rttsym] = user_affect_clock + + buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) + insert_by_type!(buffer, affect.lhs) else discs = discretes(affect) for disc in discs is_parameter(sys, disc) || error("Expected discrete variable $disc in callback to be a parameter") - insert_by_type!(disc_buffers, disc) + disc = unwrap(disc) + ttdisc = default_toterm(disc) + rdisc = renamespace(sys, disc) + rttdisc = renamespace(sys, ttdisc) + disc_clocks[disc] = user_affect_clock + disc_clocks[ttdisc] = user_affect_clock + disc_clocks[rdisc] = user_affect_clock + disc_clocks[rttdisc] = user_affect_clock + + buffer = get!( + disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) + insert_by_type!(buffer, disc) end end end - if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing - _, inputs, continuous_id, _ = get_discrete_subsystems(sys) - for par in inputs[continuous_id] - is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") - insert_by_type!(disc_buffers, par) - end - end if has_parameter_dependencies(sys) pdeps = parameter_dependencies(sys) @@ -132,13 +232,11 @@ function IndexCache(sys::AbstractSystem) for p in parameters(sys) p = unwrap(p) ctype = symtype(p) - haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue + haskey(disc_clocks, p) && continue haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} - if is_discrete_domain(p) - disc_buffers - elseif istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown() + if istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown() tunable_buffers else constant_buffers @@ -150,30 +248,63 @@ function IndexCache(sys::AbstractSystem) ) end + disc_idxs = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}}() + disc_buffer_sizes = [BufferTemplate[] for _ in 1:length(disc_buffers)] + disc_buffer_types = Set() + for buffer in values(disc_buffers) + union!(disc_buffer_types, keys(buffer)) + end + + for (clockidx, buffer) in disc_buffers + for (i, btype) in enumerate(disc_buffer_types) + if !haskey(buffer, btype) + push!(disc_buffer_sizes[clockidx], BufferTemplate(btype, 0)) + continue + end + push!(disc_buffer_sizes[clockidx], BufferTemplate(btype, length(buffer[btype]))) + for (j, sym) in enumerate(buffer[btype]) + disc_idxs[sym] = (clockidx, i, j) + disc_idxs[default_toterm(sym)] = (clockidx, i, j) + end + end + end + for (sym, clockid) in disc_clocks + haskey(disc_idxs, sym) && continue + disc_idxs[sym] = (clockid, 0, 0) + disc_idxs[default_toterm(sym)] = (clockid, 0, 0) + end + function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}}) idxs = ParamIndexMap() buffer_sizes = BufferTemplate[] for (i, (T, buf)) in enumerate(buffers) for (j, p) in enumerate(buf) + ttp = default_toterm(p) + rp = renamespace(sys, p) + rttp = renamespace(sys, ttp) idxs[p] = (i, j) - idxs[default_toterm(p)] = (i, j) - if hasname(p) && (!iscall(p) || operation(p) !== getindex) - idxs[getname(p)] = (i, j) - symbol_to_variable[getname(p)] = p - idxs[getname(default_toterm(p))] = (i, j) - symbol_to_variable[getname(default_toterm(p))] = p - end + idxs[ttp] = (i, j) + idxs[rp] = (i, j) + idxs[rttp] = (i, j) end push!(buffer_sizes, BufferTemplate(T, length(buf))) end return idxs, buffer_sizes end - disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers) + tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers) const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers) dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers) nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers) + for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), + keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs), + observed_syms, independent_variable_symbols(sys))) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end + return IndexCache( unk_idxs, disc_idxs, @@ -181,7 +312,8 @@ function IndexCache(sys::AbstractSystem) const_idxs, dependent_idxs, nonnumeric_idxs, - discrete_buffer_sizes, + observed_syms, + disc_buffer_sizes, tunable_buffer_sizes, const_buffer_sizes, dependent_buffer_sizes, @@ -191,14 +323,26 @@ function IndexCache(sys::AbstractSystem) end function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.unknown_idx, sym) !== nothing end function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + end return check_index_map(ic.unknown_idx, sym) end function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.tunable_idx, sym) !== nothing || check_index_map(ic.discrete_idx, sym) !== nothing || check_index_map(ic.constant_idx, sym) !== nothing || @@ -208,7 +352,8 @@ end function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) if sym isa Symbol - sym = ic.symbol_to_variable[sym] + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing end validate_size = Symbolics.isarraysymbolic(sym) && Symbolics.shape(sym) !== Symbolics.Unknown() @@ -227,6 +372,25 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) end end +function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end + return check_index_map(ic.discrete_idx, sym) !== nothing +end + +function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + end + idx = check_index_map(ic.discrete_idx, sym) + idx === nothing && return nothing + clockid, partitionid... = idx + return ParameterTimeseriesIndex(clockid, partitionid) +end + function check_index_map(idxmap, sym) if (idx = get(idxmap, sym, nothing)) !== nothing return idx @@ -249,10 +413,14 @@ end function discrete_linear_index(ic::IndexCache, idx::ParameterIndex) idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected") ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0) + for clockbuftemps in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1) + ind += sum(temp.length for temp in clockbuftemps; init = 0) + end ind += sum( - temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1); + temp.length + for temp in Iterators.take(ic.discrete_buffer_sizes[idx.idx[1]], idx.idx[2] - 1); init = 0) - ind += idx.idx[2] + ind += idx.idx[3] return ind end @@ -271,30 +439,31 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false) param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.tunable_buffer_sizes) disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.discrete_buffer_sizes) + for temp in Iterators.flatten(ic.discrete_buffer_sizes)) const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.constant_buffer_sizes) dep_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.dependent_buffer_sizes) nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.nonnumeric_buffer_sizes) - for p in ps + p = unwrap(p) if haskey(ic.discrete_idx, p) - i, j = ic.discrete_idx[p] - disc_buf[i][j] = unwrap(p) + disc_offset = length(first(ic.discrete_buffer_sizes)) + i, j, k = ic.discrete_idx[p] + disc_buf[(i - 1) * disc_offset + j][k] = p elseif haskey(ic.tunable_idx, p) i, j = ic.tunable_idx[p] - param_buf[i][j] = unwrap(p) + param_buf[i][j] = p elseif haskey(ic.constant_idx, p) i, j = ic.constant_idx[p] - const_buf[i][j] = unwrap(p) + const_buf[i][j] = p elseif haskey(ic.dependent_idx, p) i, j = ic.dependent_idx[p] - dep_buf[i][j] = unwrap(p) + dep_buf[i][j] = p elseif haskey(ic.nonnumeric_idx, p) i, j = ic.nonnumeric_idx[p] - nonnumeric_buf[i][j] = unwrap(p) + nonnumeric_buf[i][j] = p else error("Invalid parameter $p") end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 1cc944e1d9..43ccdb7e56 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -100,8 +100,11 @@ function MTKParameters( tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.tunable_buffer_sizes) - disc_buffer = Tuple(Vector{temp.type}(undef, temp.length) - for temp in ic.discrete_buffer_sizes) + disc_buffer = SizedArray{Tuple{length(ic.discrete_buffer_sizes)}}([Tuple(Vector{temp.type}( + undef, + temp.length) + for temp in subbuffer_sizes) + for subbuffer_sizes in ic.discrete_buffer_sizes]) const_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.constant_buffer_sizes) dep_buffer = Tuple(Vector{temp.type}(undef, temp.length) @@ -114,8 +117,8 @@ function MTKParameters( i, j = ic.tunable_idx[sym] tunable_buffer[i][j] = val elseif haskey(ic.discrete_idx, sym) - i, j = ic.discrete_idx[sym] - disc_buffer[i][j] = val + i, j, k = ic.discrete_idx[sym] + disc_buffer[i][j][k] = val elseif haskey(ic.constant_idx, sym) i, j = ic.constant_idx[sym] const_buffer[i][j] = val @@ -132,7 +135,6 @@ function MTKParameters( end return done end - for (sym, val) in p sym = unwrap(sym) val = unwrap(val) @@ -156,7 +158,7 @@ function MTKParameters( end end tunable_buffer = narrow_buffer_type.(tunable_buffer) - disc_buffer = narrow_buffer_type.(disc_buffer) + disc_buffer = broadcast.(narrow_buffer_type, disc_buffer) const_buffer = narrow_buffer_type.(const_buffer) # Don't narrow nonnumeric types nonnumeric_buffer = nonnumeric_buffer @@ -220,11 +222,16 @@ function _split_helper(buf_v::T, recurse, raw, idx) where {T} _split_helper(eltype(T), buf_v, recurse, raw, idx) end -function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{true}, raw, idx) - map(b -> _split_helper(eltype(b), b, Val(false), raw, idx), buf_v) +function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{N}, raw, idx) where {N} + map(b -> _split_helper(eltype(b), b, Val(N - 1), raw, idx), buf_v) +end + +function _split_helper(::Type{<:AbstractArray}, buf_v::Tuple, ::Val{N}, raw, idx) where {N} + ntuple(i -> _split_helper(eltype(buf_v[i]), buf_v[i], Val(N - 1), raw, idx), + Val(length(buf_v))) end -function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{false}, raw, idx) +function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{0}, raw, idx) _split_helper((), buf_v, (), raw, idx) end @@ -234,7 +241,7 @@ function _split_helper(_, buf_v, _, raw, idx) return res end -function split_into_buffers(raw::AbstractArray, buf, recurse = Val(true)) +function split_into_buffers(raw::AbstractArray, buf, recurse = Val(1)) idx = Ref(1) ntuple(i -> _split_helper(buf[i], recurse, raw, idx), Val(length(buf))) end @@ -262,10 +269,10 @@ SciMLStructures.isscimlstructure(::MTKParameters) = true SciMLStructures.ismutablescimlstructure(::MTKParameters) = true -for (Portion, field) in [(SciMLStructures.Tunable, :tunable) - (SciMLStructures.Discrete, :discrete) - (SciMLStructures.Constants, :constant) - (Nonnumeric, :nonnumeric)] +for (Portion, field, recurse) in [(SciMLStructures.Tunable, :tunable, 1) + (SciMLStructures.Discrete, :discrete, 2) + (SciMLStructures.Constants, :constant, 1) + (Nonnumeric, :nonnumeric, 1)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) as_vector = buffer_to_arraypartition(p.$field) repack = let as_vector = as_vector, p = p @@ -283,7 +290,7 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) end @eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals) - @set! p.$field = split_into_buffers(newvals, p.$field) + @set! p.$field = split_into_buffers(newvals, p.$field, Val($recurse)) if p.dependent_update_oop !== nothing raw = p.dependent_update_oop(p...) @set! p.dependent = split_into_buffers(raw, p.dependent, Val(false)) @@ -302,7 +309,8 @@ end function Base.copy(p::MTKParameters) tunable = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.tunable) - discrete = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.discrete) + discrete = typeof(p.discrete)([Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) + for buf in clockbuf) for clockbuf in p.discrete]) constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant) dependent = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.dependent) nonnumeric = copy.(p.nonnumeric) @@ -323,7 +331,8 @@ function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::Para if portion isa SciMLStructures.Tunable return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...] elseif portion isa SciMLStructures.Discrete - return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...] + k, l... = k + return isempty(l) ? p.discrete[i][j][k] : p.discrete[i][j][k][l...] elseif portion isa SciMLStructures.Constants return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...] elseif portion === DEPENDENT_PORTION @@ -349,13 +358,14 @@ function SymbolicIndexingInterface.set_parameter!( p.tunable[i][j][k...] = val end elseif portion isa SciMLStructures.Discrete - if isempty(k) - if validate_size && size(val) !== size(p.discrete[i][j]) - throw(InvalidParameterSizeException(size(p.discrete[i][j]), size(val))) + k, l... = k + if isempty(l) + if validate_size && size(val) !== size(p.discrete[i][j][k]) + throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) end - p.discrete[i][j] = val + p.discrete[i][j][k] = val else - p.discrete[i][j][k...] = val + p.discrete[i][j][k][l...] = val end elseif portion isa SciMLStructures.Constants if isempty(k) @@ -393,10 +403,11 @@ function _set_parameter_unchecked!( p.tunable[i][j][k...] = val end elseif portion isa SciMLStructures.Discrete - if isempty(k) - p.discrete[i][j] = val + k, l... = k + if isempty(l) + p.discrete[i][j][k] = val else - p.discrete[i][j][k...] = val + p.discrete[i][j][k][l...] = val end elseif portion isa SciMLStructures.Constants if isempty(k) @@ -499,8 +510,10 @@ end function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, vals::Dict) newbuf = @set oldbuf.tunable = Tuple(Vector{Any}(undef, length(buf)) for buf in oldbuf.tunable) - @set! newbuf.discrete = Tuple(Vector{Any}(undef, length(buf)) - for buf in newbuf.discrete) + @set! newbuf.discrete = SizedVector{length(newbuf.discrete)}([Tuple(Vector{Any}(undef, + length(buf)) + for buf in clockbuf) + for clockbuf in newbuf.discrete]) @set! newbuf.constant = Tuple(Vector{Any}(undef, length(buf)) for buf in newbuf.constant) @set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf)) @@ -542,8 +555,11 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va @set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.( oldbuf.tunable, newbuf.tunable) - @set! newbuf.discrete = narrow_buffer_type_and_fallback_undefs.( - oldbuf.discrete, newbuf.discrete) + @set! newbuf.discrete = SizedVector{length(newbuf.discrete)}([narrow_buffer_type_and_fallback_undefs.( + oldclockbuf, + newclockbuf) + for (oldclockbuf, newclockbuf) in zip( + oldbuf.discrete, newbuf.discrete)]) @set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.( oldbuf.constant, newbuf.constant) @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( @@ -552,11 +568,61 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va @set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.( oldbuf.dependent, split_into_buffers( - newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false))) + newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(0))) end return newbuf end +struct NestedGetIndex{T} + x::T +end + +function Base.getindex(ngi::NestedGetIndex, idx::Tuple) + i, j, k... = idx + return ngi.x[i][j][k...] +end + +# Required for DiffEqArray constructor to work during interpolation +Base.size(::NestedGetIndex) = () + +function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( + ::AbstractSystem, ps::MTKParameters, args::Pair{A, B}...) where { + A, B <: NestedGetIndex} + for (i, val) in args + ps.discrete[i] = val.x + end + return ps +end + +function SciMLBase.create_parameter_timeseries_collection( + sys::AbstractSystem, ps::MTKParameters, tspan) + ic = get_index_cache(sys) # this exists because the parameters are `MTKParameters` + has_discrete_subsystems(sys) || return nothing + (dss = get_discrete_subsystems(sys)) === nothing && return nothing + _, _, _, id_to_clock = dss + buffers = [] + + for (i, partition) in enumerate(ps.discrete) + clock = id_to_clock[i] + @match clock begin + PeriodicClock(dt, _...) => begin + ts = tspan[1]:(dt):tspan[2] + push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1))) + end + &SolverStepClock => push!(buffers, + DiffEqArray(NestedGetIndex{typeof(partition)}[], eltype(tspan)[], (1, 1))) + &Continuous => continue + _ => error("Unhandled clock $clock") + end + end + + return ParameterTimeseriesCollection(Tuple(buffers), copy(ps)) +end + +function SciMLBase.get_saveable_values(ps::MTKParameters, timeseries_idx) + return NestedGetIndex(deepcopy(ps.discrete[timeseries_idx])) +end + function DiffEqBase.anyeltypedual( p::MTKParameters, ::Type{Val{counter}} = Val{0}) where {counter} DiffEqBase.anyeltypedual(p.tunable) @@ -582,8 +648,10 @@ function Base.getindex(buf::MTKParameters, i) i -= _num_subarrays(buf.tunable) end if !isempty(buf.discrete) - i <= _num_subarrays(buf.discrete) && return _subarrays(buf.discrete)[i] - i -= _num_subarrays(buf.discrete) + for clockbuf in buf.discrete + i <= _num_subarrays(clockbuf) && return _subarrays(clockbuf)[i] + i -= _num_subarrays(clockbuf) + end end if !isempty(buf.constant) i <= _num_subarrays(buf.constant) && return _subarrays(buf.constant)[i] @@ -612,7 +680,7 @@ function Base.setindex!(p::MTKParameters, val, i) end done end - _helper(p.tunable) || _helper(p.discrete) || _helper(p.constant) || + _helper(p.tunable) || _helper(Iterators.flatten(p.discrete)) || _helper(p.constant) || _helper(p.nonnumeric) || throw(BoundsError(p, i)) if p.dependent_update_iip !== nothing p.dependent_update_iip(ArrayPartition(p.dependent), p...) @@ -620,26 +688,7 @@ function Base.setindex!(p::MTKParameters, val, i) end function Base.getindex(p::MTKParameters, pind::ParameterIndex) - (; portion, idx) = pind - i, j, k... = idx - if isempty(k) - indexer = (v) -> v[i][j] - else - indexer = (v) -> v[i][j][k...] - end - if portion isa SciMLStructures.Tunable - indexer(p.tunable) - elseif portion isa SciMLStructures.Discrete - indexer(p.discrete) - elseif portion isa SciMLStructures.Constants - indexer(p.constant) - elseif portion === DEPENDENT_PORTION - indexer(p.dependent) - elseif portion === NONNUMERIC_PORTION - indexer(p.nonnumeric) - else - error("Unhandled portion ", portion) - end + parameter_values(p, pind) end function Base.setindex!(p::MTKParameters, val, pind::ParameterIndex) @@ -649,7 +698,9 @@ end function Base.iterate(buf::MTKParameters, state = 1) total_len = 0 total_len += _num_subarrays(buf.tunable) - total_len += _num_subarrays(buf.discrete) + for clockbuf in buf.discrete + total_len += _num_subarrays(clockbuf) + end total_len += _num_subarrays(buf.constant) total_len += _num_subarrays(buf.nonnumeric) total_len += _num_subarrays(buf.dependent) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index ff26552c79..2cbf820d0d 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -8,6 +8,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, isparameter, isconstant, independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, + InferredTimeDomain, VariableType, getvariabletype, has_equations, ODESystem using ..BipartiteGraphs import ..BipartiteGraphs: invview, complete @@ -331,7 +332,7 @@ function TearingState(sys; quick_cancel = false, check = true) !isdifferential(var) && (it = input_timedomain(var)) !== nothing set_incidence = false var = only(arguments(var)) - var = setmetadata(var, TimeDomain, it) + var = setmetadata(var, VariableTimeDomain, it) @goto ANOTHER_VAR end end @@ -660,7 +661,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals @set! sys.defaults = merge(ModelingToolkit.defaults(sys), Dict(v => 0.0 for v in Iterators.flatten(inputs))) end - ps = [setmetadata(sym, TimeDomain, get(time_domains, sym, Continuous())) + ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous)) for sym in get_ps(sys)] @set! sys.ps = ps else diff --git a/src/utils.jl b/src/utils.jl index 5d6af76b77..acd7a8686d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -370,7 +370,9 @@ function vars(exprs::Symbolic; op = Differential) end vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op) -vars(exprs; op = Differential) = foldl((x, y) -> vars!(x, y; op = op), exprs; init = Set()) +function vars(exprs; op = Differential) + foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set()) +end vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) diff --git a/test/clock.jl b/test/clock.jl index 86967365ad..5bf5e917aa 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -14,7 +14,7 @@ dt = 0.1 @parameters kp # u(n + 1) := f(u(n)) -eqs = [yd ~ Sample(t, dt)(y) +eqs = [yd ~ Sample(dt)(y) ud ~ kp * (r - yd) r ~ 1.0 @@ -64,40 +64,41 @@ By inference: ci, varmap = infer_clocks(sys) eqmap = ci.eq_domain -tss, inputs = ModelingToolkit.split_system(deepcopy(ci)) -sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) +tss, inputs, continuous_id = ModelingToolkit.split_system(deepcopy(ci)) +sss, = ModelingToolkit._structural_simplify!( + deepcopy(tss[continuous_id]), (inputs[continuous_id], ())) @test equations(sss) == [D(x) ~ u - x] -sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[2]), (inputs[2], ())) +sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test isempty(equations(sss)) -d = Clock(t, dt) +d = Clock(dt) k = ShiftIndex(d) -@test observed(sss) == [yd(k + 1) ~ Sample(t, dt)(y); r(k + 1) ~ 1.0; +@test observed(sss) == [yd(k + 1) ~ Sample(dt)(y); r(k + 1) ~ 1.0; ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))] -d = Clock(t, dt) +d = Clock(dt) # Note that TearingState reorders the equations -@test eqmap[1] == Continuous() +@test eqmap[1] == Continuous @test eqmap[2] == d @test eqmap[3] == d @test eqmap[4] == d -@test eqmap[5] == Continuous() -@test eqmap[6] == Continuous() +@test eqmap[5] == Continuous +@test eqmap[6] == Continuous @test varmap[yd] == d @test varmap[ud] == d @test varmap[r] == d -@test varmap[x] == Continuous() -@test varmap[y] == Continuous() -@test varmap[u] == Continuous() +@test varmap[x] == Continuous +@test varmap[y] == Continuous +@test varmap[u] == Continuous @info "Testing shift normalization" dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) @parameters kp -d = Clock(t, dt) +d = Clock(dt) k = ShiftIndex(d) -eqs = [yd ~ Sample(t, dt)(y) +eqs = [yd ~ Sample(dt)(y) ud ~ kp * yd + ud(k - 2) # plant (time continuous part) @@ -170,10 +171,10 @@ eqs = [yd ~ Sample(t, dt)(y) eqs = [ # controller (time discrete part `dt=0.1`) - yd1 ~ Sample(t, dt)(y) - ud1 ~ kp * (Sample(t, dt)(r) - yd1) - yd2 ~ Sample(t, dt2)(y) - ud2 ~ kp * (Sample(t, dt2)(r) - yd2) + yd1 ~ Sample(dt)(y) + ud1 ~ kp * (Sample(dt)(r) - yd1) + yd2 ~ Sample(dt2)(y) + ud2 ~ kp * (Sample(dt2)(r) - yd2) # plant (time continuous part) u ~ Hold(ud1) + Hold(ud2) @@ -182,8 +183,8 @@ eqs = [yd ~ Sample(t, dt)(y) @named sys = ODESystem(eqs, t) ci, varmap = infer_clocks(sys) - d = Clock(t, dt) - d2 = Clock(t, dt2) + d = Clock(dt) + d2 = Clock(dt2) #@test get_eq_domain(eqs[1]) == d #@test get_eq_domain(eqs[3]) == d2 @@ -191,15 +192,15 @@ eqs = [yd ~ Sample(t, dt)(y) @test varmap[ud1] == d @test varmap[yd2] == d2 @test varmap[ud2] == d2 - @test varmap[r] == Continuous() - @test varmap[x] == Continuous() - @test varmap[y] == Continuous() - @test varmap[u] == Continuous() + @test varmap[r] == Continuous + @test varmap[x] == Continuous + @test varmap[y] == Continuous + @test varmap[u] == Continuous @info "test composed systems" dt = 0.5 - d = Clock(t, dt) + d = Clock(dt) k = ShiftIndex(d) timevec = 0:0.1:4 @@ -239,16 +240,16 @@ eqs = [yd ~ Sample(t, dt)(y) ci, varmap = infer_clocks(cl) - @test varmap[f.x] == Clock(t, 0.5) - @test varmap[p.x] == Continuous() - @test varmap[p.y] == Continuous() - @test varmap[c.ud] == Clock(t, 0.5) - @test varmap[c.yd] == Clock(t, 0.5) - @test varmap[c.y] == Continuous() - @test varmap[f.y] == Clock(t, 0.5) - @test varmap[f.u] == Clock(t, 0.5) - @test varmap[p.u] == Continuous() - @test varmap[c.r] == Clock(t, 0.5) + @test varmap[f.x] == Clock(0.5) + @test varmap[p.x] == Continuous + @test varmap[p.y] == Continuous + @test varmap[c.ud] == Clock(0.5) + @test varmap[c.yd] == Clock(0.5) + @test varmap[c.y] == Continuous + @test varmap[f.y] == Clock(0.5) + @test varmap[f.u] == Clock(0.5) + @test varmap[p.u] == Continuous + @test varmap[c.r] == Clock(0.5) ## Multiple clock rates @info "Testing multi-rate hybrid system" @@ -259,10 +260,10 @@ eqs = [yd ~ Sample(t, dt)(y) eqs = [ # controller (time discrete part `dt=0.1`) - yd1 ~ Sample(t, dt)(y) + yd1 ~ Sample(dt)(y) ud1 ~ kp * (r - yd1) # controller (time discrete part `dt=0.2`) - yd2 ~ Sample(t, dt2)(y) + yd2 ~ Sample(dt2)(y) ud2 ~ kp * (r - yd2) # plant (time continuous part) @@ -272,8 +273,8 @@ eqs = [yd ~ Sample(t, dt)(y) @named cl = ODESystem(eqs, t) - d = Clock(t, dt) - d2 = Clock(t, dt2) + d = Clock(dt) + d2 = Clock(dt2) ci, varmap = infer_clocks(cl) @test varmap[yd1] == d @@ -330,8 +331,8 @@ eqs = [yd ~ Sample(t, dt)(y) using ModelingToolkitStandardLibrary.Blocks dt = 0.05 - d = Clock(t, dt) - k = ShiftIndex() + d = Clock(dt) + k = ShiftIndex(d) @mtkmodel DiscretePI begin @components begin @@ -361,7 +362,7 @@ eqs = [yd ~ Sample(t, dt)(y) output = RealOutput() end @equations begin - output.u ~ Sample(t, dt)(input.u) + output.u ~ Sample(dt)(input.u) end end @@ -473,7 +474,7 @@ eqs = [yd ~ Sample(t, dt)(y) ## Test continuous clock - c = ModelingToolkit.SolverStepClock(t) + c = ModelingToolkit.SolverStepClock k = ShiftIndex(c) @mtkmodel CounterSys begin diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 30bbb27ede..b3b170df18 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -2,6 +2,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters using SymbolicIndexingInterface using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants +using StaticArrays: SizedVector using OrdinaryDiffEq using ForwardDiff using JET @@ -307,3 +308,32 @@ end newoprob = remake(oprob_scal_scal; p = ps_vec) @test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0] end + +# Parameter timeseries +ps = MTKParameters(([1.0, 1.0],), SizedVector{2}([([0.0, 0.0],), ([0.0, 0.0],)]), + (), (), (), nothing, nothing) +with_updated_parameter_timeseries_values( + sys, ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],))) +@test ps.discrete[1][1] == [5.0, 10.0] +with_updated_parameter_timeseries_values( + sys, ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)), + 2 => ModelingToolkit.NestedGetIndex(([4.0, 40.0],))) +@test ps.discrete[1][1] == [3.0, 30.0] +@test ps.discrete[2][1] == [4.0, 40.0] +@test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1] + +# With multiple types and clocks +ps = MTKParameters( + (), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), + (), (), (), nothing, nothing) +@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector} +tsidx1 = 1 +tsidx2 = 2 +@test length(ps.discrete[tsidx1][1]) == 3 +@test length(ps.discrete[tsidx1][2]) == 1 +@test length(ps.discrete[tsidx2][1]) == 3 +@test length(ps.discrete[tsidx2][2]) == 0 +with_updated_parameter_timeseries_values( + sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) +@test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0] +@test ps.discrete[tsidx1][2][] == false diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index ef446e9630..fc03f53d74 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -157,10 +157,10 @@ end dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t) @parameters kp kq - d = Clock(t, dt) + d = Clock(dt) k = ShiftIndex(d) - eqs = [yd ~ Sample(t, dt)(y) + eqs = [yd ~ Sample(dt)(y) ud ~ kp * (r - yd) + kq * z r ~ 1.0 u ~ Hold(ud) @@ -175,7 +175,7 @@ end prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) - @test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent) + @test_nowarn solve(prob, Tsit5()) @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp], discrete_events = [[0.5] => [kp ~ 2.0]]) @@ -184,11 +184,11 @@ end yd(k - 2) => 2.0]) @test prob.ps[kp] == 1.0 @test prob.ps[kq] == 2.0 - @test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) + @test_nowarn solve(prob, Tsit5()) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) - integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent) + integ = init(prob, Tsit5()) @test integ.ps[kp] == 1.0 @test integ.ps[kq] == 2.0 step!(integ, 0.6) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index f707959135..01011828ab 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -4,6 +4,7 @@ using OrdinaryDiffEq using ModelingToolkit: t_nounits as t, D_nounits as D using ModelingToolkit: MTKParameters, ParameterIndex, DEPENDENT_PORTION, NONNUMERIC_PORTION using SciMLStructures: Tunable, Discrete, Constants +using StaticArrays: SizedVector x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)] @@ -194,7 +195,7 @@ S = get_sensitivity(closed_loop, :u) @testset "Indexing MTKParameters with ParameterIndex" begin ps = MTKParameters(([1.0, 2.0], [3, 4]), - ([true, false], [[1 2; 3 4]]), + SizedVector{2}([([true, false], [[1 2; 3 4]]), ([false, true], [[2 4; 6 8]])]), ([5, 6],), ([7.0, 8.0],), (["hi", "bye"], [:lie, :die]), @@ -202,14 +203,14 @@ S = get_sensitivity(closed_loop, :u) nothing) @test ps[ParameterIndex(Tunable(), (1, 2))] === 2.0 @test ps[ParameterIndex(Tunable(), (2, 2))] === 4 - @test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] === 4 - @test ps[ParameterIndex(Discrete(), (2, 1))] == [1 2; 3 4] + @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] === 4 + @test ps[ParameterIndex(Discrete(), (2, 2, 1))] == [2 4; 6 8] @test ps[ParameterIndex(Constants(), (1, 1))] === 5 @test ps[ParameterIndex(DEPENDENT_PORTION, (1, 1))] === 7.0 @test ps[ParameterIndex(NONNUMERIC_PORTION, (2, 2))] === :die ps[ParameterIndex(Tunable(), (1, 2))] = 3.0 - ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] = 5 + ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] = 5 @test ps[ParameterIndex(Tunable(), (1, 2))] === 3.0 - @test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] === 5 + @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] === 5 end diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 7fd57c0474..10d24fd6f2 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -1,90 +1,161 @@ using ModelingToolkit, SymbolicIndexingInterface, SciMLBase -using ModelingToolkit: t_nounits as t, D_nounits as D +using ModelingToolkit: t_nounits as t, D_nounits as D, ParameterIndex +using SciMLStructures: Tunable + +@testset "ODESystem" begin + @parameters a b + @variables x(t)=1.0 y(t)=2.0 xy(t) + eqs = [D(x) ~ a * y + t, D(y) ~ b * t] + @named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y]) + odesys = complete(odesys) + @test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y])) + @test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b])) + @test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == + [1, 2, nothing, nothing, nothing, 1, 2, 1, 2, nothing, nothing] + @test isequal(variable_symbols(odesys), [x, y]) + @test all(is_parameter.((odesys,), [a, b, ParameterIndex(Tunable(), (1, 1)), :a, :b])) + @test all(.!is_parameter.((odesys,), [x, y, t, 3, 0, :x, :y])) + @test parameter_index(odesys, a) == parameter_index(odesys, :a) + @test parameter_index(odesys, a) isa ParameterIndex{Tunable, Tuple{Int, Int}} + @test parameter_index(odesys, b) == parameter_index(odesys, :b) + @test parameter_index(odesys, b) isa ParameterIndex{Tunable, Tuple{Int, Int}} + @test parameter_index.( + (odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y]) == + [nothing, nothing, nothing, ParameterIndex(Tunable(), (1, 1)), nothing, nothing] + @test isequal(parameter_symbols(odesys), [a, b]) + @test all(is_independent_variable.((odesys,), [t, :t])) + @test all(.!is_independent_variable.((odesys,), [x, y, a, :x, :y, :a])) + @test isequal(independent_variable_symbols(odesys), [t]) + @test is_time_dependent(odesys) + @test constant_structure(odesys) + @test !isempty(default_values(odesys)) + @test default_values(odesys)[x] == 1.0 + @test default_values(odesys)[y] == 2.0 + @test isequal(default_values(odesys)[xy], x + y) + + @named odesys = ODESystem( + eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y]) + odesys = complete(odesys) + @test default_values(odesys)[xy] == 3.0 + pobs = parameter_observed(odesys, a + b) + @test isempty(get_all_timeseries_indexes(odesys, a + b)) + @test pobs( + ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ 3.0 + pobs = parameter_observed(odesys, [a + b, a - b]) + @test isempty(get_all_timeseries_indexes(odesys, [a + b, a - b])) + @test pobs( + ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ [3.0, -1.0] +end + +# @testset "Clock system" begin +# dt = 0.1 +# dt2 = 0.2 +# @variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0 +# @parameters kp=1 r=1 + +# eqs = [ +# # controller (time discrete part `dt=0.1`) +# yd1 ~ Sample(t, dt)(y) +# ud1 ~ kp * (r - yd1) +# # controller (time discrete part `dt=0.2`) +# yd2 ~ Sample(t, dt2)(y) +# ud2 ~ kp * (r - yd2) + +# # plant (time continuous part) +# u ~ Hold(ud1) + Hold(ud2) +# D(x) ~ -x + u +# y ~ x] + +# @mtkbuild cl = ODESystem(eqs, t) +# partition1_params = [Hold(ud1), Sample(t, dt)(y), ud1, yd1] +# partition2_params = [Hold(ud2), Sample(t, dt2)(y), ud2, yd2] +# @test all( +# Base.Fix1(is_timeseries_parameter, cl), vcat(partition1_params, partition2_params)) +# @test allequal(timeseries_parameter_index(cl, p).timeseries_idx +# for p in partition1_params) +# @test allequal(timeseries_parameter_index(cl, p).timeseries_idx +# for p in partition2_params) +# tsidx1 = timeseries_parameter_index(cl, partition1_params[1]).timeseries_idx +# tsidx2 = timeseries_parameter_index(cl, partition2_params[1]).timeseries_idx +# @test tsidx1 != tsidx2 +# ps = ModelingToolkit.MTKParameters(cl, [kp => 1.0, Sample(t, dt)(y) => 1.0]) +# pobs = parameter_observed(cl, Shift(t, 1)(yd1)) +# @test pobs.timeseries_idx == tsidx1 +# @test pobs.observed_fn(ps, 0.0) == 1.0 +# pobs = parameter_observed(cl, [Shift(t, 1)(yd1), Shift(t, 1)(ud1)]) +# @test pobs.timeseries_idx == tsidx1 +# @test pobs.observed_fn(ps, 0.0) == [1.0, 0.0] +# pobs = parameter_observed(cl, [Shift(t, 1)(yd1), Shift(t, 1)(ud2)]) +# @test pobs.timeseries_idx === nothing +# @test pobs.observed_fn(ps, 0.0) == [1.0, 1.0] +# end + +@testset "Nonlinear system" begin + @variables x y z + @parameters σ ρ β + + eqs = [0 ~ σ * (y - x), + 0 ~ x * (ρ - z) - y, + 0 ~ x * y - β * z] + @named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β]) + ns = complete(ns) + @test !is_time_dependent(ns) + ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0]) + pobs = parameter_observed(ns, σ + ρ) + @test isempty(get_all_timeseries_indexes(ns, σ + ρ)) + @test pobs(ps) == 3.0 + pobs = parameter_observed(ns, [σ + ρ, ρ + β]) + @test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β])) + @test pobs(ps) == [3.0, 5.0] +end + +@testset "PDESystem" begin + @parameters x + @variables u(..) + Dxx = Differential(x)^2 + Dtt = Differential(t)^2 + Dt = D + + #2D PDE + C = 1 + eq = Dtt(u(t, x)) ~ C^2 * Dxx(u(t, x)) + + # Initial and boundary conditions + bcs = [u(t, 0) ~ 0.0,# for all t > 0 + u(t, 1) ~ 0.0,# for all t > 0 + u(0, x) ~ x * (1.0 - x), #for all 0 < x < 1 + Dt(u(0, x)) ~ 0.0] #for all 0 < x < 1] + + # Space and time domains + domains = [t ∈ (0.0, 1.0), + x ∈ (0.0, 1.0)] -@parameters a b -@variables x(t)=1.0 y(t)=2.0 xy(t) -eqs = [D(x) ~ a * y + t, D(y) ~ b * t] -@named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y]) - -@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y])) -@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b])) -@test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == - [1, 2, nothing, nothing, nothing, 1, 2, 1, 2, nothing, nothing] -@test isequal(variable_symbols(odesys), [x, y]) -@test all(is_parameter.((odesys,), [a, b, 1, 2, :a, :b])) -@test all(.!is_parameter.((odesys,), [x, y, t, 3, 0, :x, :y])) -@test parameter_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == - [nothing, nothing, 1, 2, nothing, 1, 2, nothing, nothing, 1, 2] -@test isequal(parameter_symbols(odesys), [a, b]) -@test all(is_independent_variable.((odesys,), [t, :t])) -@test all(.!is_independent_variable.((odesys,), [x, y, a, :x, :y, :a])) -@test isequal(independent_variable_symbols(odesys), [t]) -@test is_time_dependent(odesys) -@test constant_structure(odesys) -@test !isempty(default_values(odesys)) -@test default_values(odesys)[x] == 1.0 -@test default_values(odesys)[y] == 2.0 -@test isequal(default_values(odesys)[xy], x + y) - -@named odesys = ODESystem( - eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y]) -@test default_values(odesys)[xy] == 3.0 - -@variables x y z -@parameters σ ρ β - -eqs = [0 ~ σ * (y - x), - 0 ~ x * (ρ - z) - y, - 0 ~ x * y - β * z] -@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β]) - -@test !is_time_dependent(ns) - -@parameters x -@variables u(..) -Dxx = Differential(x)^2 -Dtt = Differential(t)^2 -Dt = D - -#2D PDE -C = 1 -eq = Dtt(u(t, x)) ~ C^2 * Dxx(u(t, x)) - -# Initial and boundary conditions -bcs = [u(t, 0) ~ 0.0,# for all t > 0 - u(t, 1) ~ 0.0,# for all t > 0 - u(0, x) ~ x * (1.0 - x), #for all 0 < x < 1 - Dt(u(0, x)) ~ 0.0] #for all 0 < x < 1] - -# Space and time domains -domains = [t ∈ (0.0, 1.0), - x ∈ (0.0, 1.0)] - -@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u]) - -@test pde_system.ps == SciMLBase.NullParameters() -@test parameter_symbols(pde_system) == [] - -@parameters x -@constants h = 1 -@variables u(..) -Dt = D -Dxx = Differential(x)^2 -eq = Dt(u(t, x)) ~ h * Dxx(u(t, x)) -bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x), - u(t, 0) ~ 0, u(t, 1) ~ 0] - -domains = [t ∈ (0.0, 1.0), - x ∈ (0.0, 1.0)] - -analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] -analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) - -@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) - -@test isequal(pdesys.ps, [h]) -@test isequal(parameter_symbols(pdesys), [h]) -@test isequal(parameters(pdesys), [h]) + @named pde_system = PDESystem(eq, bcs, domains, [t, x], [u]) + + @test pde_system.ps == SciMLBase.NullParameters() + @test parameter_symbols(pde_system) == [] + + @parameters x + @constants h = 1 + @variables u(..) + Dt = D + Dxx = Differential(x)^2 + eq = Dt(u(t, x)) ~ h * Dxx(u(t, x)) + bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x), + u(t, 0) ~ 0, u(t, 1) ~ 0] + + domains = [t ∈ (0.0, 1.0), + x ∈ (0.0, 1.0)] + + analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] + analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) + + @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) + + @test isequal(pdesys.ps, [h]) + @test isequal(parameter_symbols(pdesys), [h]) + @test isequal(parameters(pdesys), [h]) +end # Issue#2767 using ModelingToolkit