Skip to content

Commit

Permalink
Merge pull request #2728 from AayushSabharwal/as/discrete-save
Browse files Browse the repository at this point in the history
feat: implement new discrete saving functionality
  • Loading branch information
ChrisRackauckas authored Jul 29, 2024
2 parents 110ebfd + 3a073ec commit 171c43e
Show file tree
Hide file tree
Showing 18 changed files with 865 additions and 533 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Expand Down Expand Up @@ -81,6 +82,7 @@ DocStringExtensions = "0.7, 0.8, 0.9"
DomainSets = "0.6, 0.7"
DynamicQuantities = "^0.11.2, 0.12, 0.13"
ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappersWrappers = "0.1"
Expand All @@ -98,18 +100,18 @@ NonlinearSolve = "3.12"
OrderedCollections = "1"
OrdinaryDiffEq = "6.82.0"
PrecompileTools = "1"
RecursiveArrayTools = "2.3, 3"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.28.0"
SciMLBase = "2.46"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.12"
SymbolicIndexingInterface = "0.3.26"
SymbolicUtils = "2.1"
Symbolics = "5.32"
URIs = "1"
Expand Down
16 changes: 8 additions & 8 deletions docs/src/tutorials/SampledData.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ A clock can be seen as an *event source*, i.e., when the clock ticks, an event i
- [`Hold`](@ref)
- [`ShiftIndex`](@ref)

When a continuous-time variable `x` is sampled using `xd = Sample(x, dt)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc.
When a continuous-time variable `x` is sampled using `xd = Sample(dt)(x)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc.

To make a discrete-time variable available to the continuous partition, the [`Hold`](@ref) operator is used. `xc = Hold(xd)` creates a continuous-time variable `xc` that is updated whenever the clock associated with `xd` ticks, and holds its value constant between ticks.

Expand All @@ -34,7 +34,7 @@ using ModelingToolkit
using ModelingToolkit: t_nounits as t
@variables x(t) y(t) u(t)
dt = 0.1 # Sample interval
clock = Clock(t, dt) # A periodic clock with tick rate dt
clock = Clock(dt) # A periodic clock with tick rate dt
k = ShiftIndex(clock)
eqs = [
Expand Down Expand Up @@ -99,7 +99,7 @@ may thus be modeled as
```julia
t = ModelingToolkit.t_nounits
@variables y(t) [description = "Output"] u(t) [description = "Input"]
k = ShiftIndex(Clock(t, dt))
k = ShiftIndex(Clock(dt))
eqs = [
a2 * y(k) + a1 * y(k - 1) + a0 * y(k - 2) ~ b2 * u(k) + b1 * u(k - 1) + b0 * u(k - 2)
]
Expand Down Expand Up @@ -128,10 +128,10 @@ requires specification of the initial condition for both `x(k-1)` and `x(k-2)`.
Multi-rate systems are easy to model using multiple different clocks. The following set of equations is valid, and defines *two different discrete-time partitions*, each with its own clock:

```julia
yd1 ~ Sample(t, dt1)(y)
ud1 ~ kp * (Sample(t, dt1)(r) - yd1)
yd2 ~ Sample(t, dt2)(y)
ud2 ~ kp * (Sample(t, dt2)(r) - yd2)
yd1 ~ Sample(dt1)(y)
ud1 ~ kp * (Sample(dt1)(r) - yd1)
yd2 ~ Sample(dt2)(y)
ud2 ~ kp * (Sample(dt2)(r) - yd2)
```

`yd1` and `ud1` belong to the same clock which ticks with an interval of `dt1`, while `yd2` and `ud2` belong to a different clock which ticks with an interval of `dt2`. The two clocks are *not synchronized*, i.e., they are not *guaranteed* to tick at the same point in time, even if one tick interval is a rational multiple of the other. Mechanisms for synchronization of clocks are not yet implemented.
Expand All @@ -148,7 +148,7 @@ using ModelingToolkit: t_nounits as t
using ModelingToolkit: D_nounits as D
dt = 0.5 # Sample interval
@variables r(t)
clock = Clock(t, dt)
clock = Clock(dt)
k = ShiftIndex(clock)
function plant(; name)
Expand Down
5 changes: 3 additions & 2 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ using SciMLStructures
using Compat
using AbstractTrees
using DiffEqBase, SciMLBase, ForwardDiff
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain,
PeriodicClock, Clock, SolverStepClock, Continuous
using Distributed
import JuliaFormatter
using MLStyle
Expand Down Expand Up @@ -272,6 +273,6 @@ export debug_system
#export has_discrete_domain, has_continuous_domain
#export is_discrete_domain, is_continuous_domain, is_hybrid_domain
export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime
export Clock #, InferredDiscrete,
export Clock, SolverStepClock, TimeDomain

end # module
101 changes: 33 additions & 68 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
abstract type TimeDomain end
abstract type AbstractDiscrete <: TimeDomain end
module InferredClock

Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d)
export InferredTimeDomain

struct Inferred <: TimeDomain end
struct InferredDiscrete <: AbstractDiscrete end
struct Continuous <: TimeDomain end
using Expronicon.ADT: @adt, @match
using SciMLBase: TimeDomain

Symbolics.option_to_metadata_type(::Val{:timedomain}) = TimeDomain
@adt InferredTimeDomain begin
Inferred
InferredDiscrete
end

Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)

end

using .InferredClock

struct VariableTimeDomain end
Symbolics.option_to_metadata_type(::Val{:timedomain}) = VariableTimeDomain

is_concrete_time_domain(::TimeDomain) = true
is_concrete_time_domain(_) = false

"""
is_continuous_domain(x)
Expand All @@ -16,15 +29,15 @@ true if `x` contains only continuous-domain signals.
See also [`has_continuous_domain`](@ref)
"""
function is_continuous_domain(x)
issym(x) && return getmetadata(x, TimeDomain, false) isa Continuous
issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous
!has_discrete_domain(x) && has_continuous_domain(x)
end

function get_time_domain(x)
if iscall(x) && operation(x) isa Operator
output_timedomain(x)
else
getmetadata(x, TimeDomain, nothing)
getmetadata(x, VariableTimeDomain, nothing)
end
end
get_time_domain(x::Num) = get_time_domain(value(x))
Expand All @@ -37,14 +50,14 @@ Determine if variable `x` has a time-domain attributed to it.
function has_time_domain(x::Symbolic)
# getmetadata(x, Continuous, nothing) !== nothing ||
# getmetadata(x, Discrete, nothing) !== nothing
getmetadata(x, TimeDomain, nothing) !== nothing
getmetadata(x, VariableTimeDomain, nothing) !== nothing
end
has_time_domain(x::Num) = has_time_domain(value(x))
has_time_domain(x) = false

for op in [Differential]
@eval input_timedomain(::$op, arg = nothing) = Continuous()
@eval output_timedomain(::$op, arg = nothing) = Continuous()
@eval input_timedomain(::$op, arg = nothing) = Continuous
@eval output_timedomain(::$op, arg = nothing) = Continuous
end

"""
Expand Down Expand Up @@ -83,12 +96,17 @@ true if `x` contains only discrete-domain signals.
See also [`has_discrete_domain`](@ref)
"""
function is_discrete_domain(x)
if hasmetadata(x, TimeDomain) || issym(x)
return getmetadata(x, TimeDomain, false) isa AbstractDiscrete
if hasmetadata(x, VariableTimeDomain) || issym(x)
return is_discrete_time_domain(getmetadata(x, VariableTimeDomain, false))
end
!has_discrete_domain(x) && has_continuous_domain(x)
end

sampletime(c) = @match c begin
PeriodicClock(dt, _...) => dt
_ => nothing
end

struct ClockInferenceException <: Exception
msg::Any
end
Expand All @@ -97,57 +115,4 @@ function Base.showerror(io::IO, cie::ClockInferenceException)
print(io, "ClockInferenceException: ", cie.msg)
end

abstract type AbstractClock <: AbstractDiscrete end

"""
Clock <: AbstractClock
Clock([t]; dt)
The default periodic clock with independent variables `t` and tick interval `dt`.
If `dt` is left unspecified, it will be inferred (if possible).
"""
struct Clock <: AbstractClock
"Independent variable"
t::Union{Nothing, Symbolic}
"Period"
dt::Union{Nothing, Float64}
Clock(t::Union{Num, Symbolic}, dt = nothing) = new(value(t), dt)
Clock(t::Nothing, dt = nothing) = new(t, dt)
end
Clock(dt::Real) = Clock(nothing, dt)
Clock() = Clock(nothing, nothing)

sampletime(c) = isdefined(c, :dt) ? c.dt : nothing
Base.hash(c::Clock, seed::UInt) = hash(c.dt, seed 0x953d7a9a18874b90)
function Base.:(==)(c1::Clock, c2::Clock)
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) && c1.dt == c2.dt
end

is_concrete_time_domain(x) = x isa Union{AbstractClock, Continuous}

"""
SolverStepClock <: AbstractClock
SolverStepClock()
SolverStepClock(t)
A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). This clock **does generally not have equidistant tick intervals**, instead, the tick interval depends on the adaptive step-size selection of the continuous solver, as well as any continuous event handling. If adaptivity of the solver is turned off and there are no continuous events, the tick interval will be given by the fixed solver time step `dt`.
Due to possibly non-equidistant tick intervals, this clock should typically not be used with discrete-time systems that assume a fixed sample time, such as PID controllers and digital filters.
"""
struct SolverStepClock <: AbstractClock
"Independent variable"
t::Union{Nothing, Symbolic}
"Period"
SolverStepClock(t::Union{Num, Symbolic}) = new(value(t))
end
SolverStepClock() = SolverStepClock(nothing)

Base.hash(c::SolverStepClock, seed::UInt) = seed 0x953d7b9a18874b91
function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock)
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t))
end

struct IntegerSequence <: AbstractClock
t::Union{Nothing, Symbolic}
IntegerSequence(t::Union{Num, Symbolic}) = new(value(t))
end
struct IntegerSequence end
53 changes: 29 additions & 24 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ $(TYPEDEF)
Represents a sample operator. A discrete-time signal is created by sampling a continuous-time signal.
# Constructors
`Sample(clock::TimeDomain = InferredDiscrete())`
`Sample([t], dt::Real)`
`Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete)`
`Sample(dt::Real)`
`Sample(x::Num)`, with a single argument, is shorthand for `Sample()(x)`.
Expand All @@ -100,16 +100,23 @@ julia> using Symbolics
julia> t = ModelingToolkit.t_nounits
julia> Δ = Sample(t, 0.01)
julia> Δ = Sample(0.01)
(::Sample) (generic function with 2 methods)
```
"""
struct Sample <: Operator
clock::Any
Sample(clock::TimeDomain = InferredDiscrete()) = new(clock)
Sample(t, dt::Real) = new(Clock(t, dt))
Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete) = new(clock)
end

function Sample(arg::Real)
arg = unwrap(arg)
if symbolic_type(arg) == NotSymbolic()
Sample(Clock(arg))
else
Sample()(arg)
end
end
Sample(x) = Sample()(x)
(D::Sample)(x) = Term{symtype(x)}(D, Any[x])
(D::Sample)(x::Num) = Num(D(value(x)))
SymbolicUtils.promote_symtype(::Sample, x) = x
Expand Down Expand Up @@ -176,15 +183,18 @@ julia> x(k) # no shift
x(t)
julia> x(k+1) # shift
Shift(t, 1)(x(t))
Shift(1)(x(t))
```
"""
struct ShiftIndex
clock::TimeDomain
clock::Union{InferredTimeDomain, TimeDomain, IntegerSequence}
steps::Int
ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps)
ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps)
ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps)
function ShiftIndex(
clock::Union{TimeDomain, InferredTimeDomain, IntegerSequence} = Inferred, steps::Int = 0)
new(clock, steps)
end
ShiftIndex(dt::Real, steps::Int = 0) = new(Clock(dt), steps)
ShiftIndex(::Num, steps::Int) = new(IntegerSequence(), steps)
end

function (xn::Num)(k::ShiftIndex)
Expand All @@ -197,18 +207,13 @@ function (xn::Num)(k::ShiftIndex)
args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t)
length(args) == 1 ||
error("Cannot shift an expression with multiple independent variables $x.")
t = args[]
if hasfield(typeof(clock), :t)
isequal(t, clock.t) ||
error("Independent variable of $xn is not the same as that of the ShiftIndex $(k.t)")
end

# d, _ = propagate_time_domain(xn)
# if d != clock # this is only required if the variable has another clock
# xn = Sample(t, clock)(xn)
# end
# QUESTION: should we return a variable with time domain set to k.clock?
xn = setmetadata(xn, TimeDomain, k.clock)
xn = setmetadata(xn, VariableTimeDomain, k.clock)
if steps == 0
return xn # x(k) needs no shift operator if the step of k is 0
end
Expand All @@ -221,37 +226,37 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
"""
input_timedomain(op::Operator)
Return the time-domain type (`Continuous()` or `Discrete()`) that `op` operates on.
Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` operates on.
"""
function input_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete()
InferredDiscrete
end

"""
output_timedomain(op::Operator)
Return the time-domain type (`Continuous()` or `Discrete()`) that `op` results in.
Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` results in.
"""
function output_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete()
InferredDiscrete
end

input_timedomain(::Sample, arg = nothing) = Continuous()
input_timedomain(::Sample, arg = nothing) = Continuous
output_timedomain(s::Sample, arg = nothing) = s.clock

function input_timedomain(h::Hold, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete() # the Hold accepts any discrete
InferredDiscrete # the Hold accepts any discrete
end
output_timedomain(::Hold, arg = nothing) = Continuous()
output_timedomain(::Hold, arg = nothing) = Continuous

sampletime(op::Sample, arg = nothing) = sampletime(op.clock)
sampletime(op::ShiftIndex, arg = nothing) = sampletime(op.clock)
Expand Down
Loading

0 comments on commit 171c43e

Please sign in to comment.