Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ODEs as equations in JumpSystem #3181

Merged
merged 22 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#################################### system operations #####################################
get_continuous_events(sys::AbstractSystem) = SymbolicContinuousCallback[]
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
function get_continuous_events(sys::AbstractSystem)
has_continuous_events(sys) || return SymbolicContinuousCallback[]
getfield(sys, :continuous_events)
end
isaacsas marked this conversation as resolved.
Show resolved Hide resolved

has_discrete_events(sys::AbstractSystem) = isdefined(sys, :discrete_events)
function get_discrete_events(sys::AbstractSystem)
Expand Down Expand Up @@ -676,8 +678,8 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
end
end

function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
function generate_rootfinding_callback(sys::AbstractTimeDependentSystem,
dvs = unknowns(sys), ps = parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
Expand All @@ -687,7 +689,7 @@ Generate a single rootfinding callback; this happens if there is only one equati
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
"""
function generate_single_rootfinding_callback(
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
eq, cb, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
if !isequal(eq.lhs, 0)
eq = 0 ~ eq.lhs - eq.rhs
Expand Down Expand Up @@ -729,7 +731,7 @@ function generate_single_rootfinding_callback(
end

function generate_vector_rootfinding_callback(
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
cbs, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
ps = parameters(sys); rootfind = SciMLBase.RightRootFind,
reinitialization = SciMLBase.CheckInit(), kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
Expand Down Expand Up @@ -841,7 +843,7 @@ end
"""
Compile a single continuous callback affect function(s).
"""
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
function compile_affect_fn(cb, sys::AbstractTimeDependentSystem, dvs, ps, kwargs)
eq_aff = affects(cb)
eq_neg_aff = affect_negs(cb)
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
Expand All @@ -858,8 +860,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
function generate_rootfinding_callback(cbs, sys::AbstractTimeDependentSystem,
dvs = unknowns(sys), ps = parameters(sys); kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
total_eqs = sum(num_eqs)
Expand Down Expand Up @@ -1053,12 +1055,12 @@ merge_cb(x, ::Nothing) = x
merge_cb(x, y) = CallbackSet(x, y)

function process_events(sys; callback = nothing, kwargs...)
if has_continuous_events(sys)
if has_continuous_events(sys) && !isempty(continuous_events(sys))
contin_cb = generate_rootfinding_callback(sys; kwargs...)
else
contin_cb = nothing
end
if has_discrete_events(sys)
if has_discrete_events(sys) && !isempty(discrete_events(sys))
discrete_cb = generate_discrete_callbacks(sys; kwargs...)
else
discrete_cb = nothing
Expand Down
65 changes: 45 additions & 20 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
"""
connector_type::Any
"""
A `Vector{SymbolicContinuousCallback}` that model events.
The integrator will use root finding to guarantee that it steps at each zero crossing.
"""
continuous_events::Vector{SymbolicContinuousCallback}
"""
A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
analog to `SciMLBase.DiscreteCallback` that executes an affect when a given condition is
true at the end of an integration step. Note, one must make sure to call
Expand Down Expand Up @@ -120,8 +125,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem

function JumpSystem{U}(
tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, description,
systems,
defaults, connector_type, devents, parameter_dependencies,
systems, defaults, connector_type, cevents, devents, parameter_dependencies,
metadata = nothing, gui_metadata = nothing,
complete = false, index_cache = nothing, isscheduled = false;
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
Expand All @@ -136,8 +140,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
end
new{U}(tag, ap, iv, unknowns, ps, var_to_name,
observed, name, description, systems, defaults,
connector_type, devents, parameter_dependencies, metadata, gui_metadata,
complete, index_cache, isscheduled)
connector_type, cevents, devents, parameter_dependencies, metadata,
gui_metadata, complete, index_cache, isscheduled)
end
end
function JumpSystem(tag, ap, iv, states, ps, var_to_name, args...; kwargs...)
Expand Down Expand Up @@ -194,26 +198,28 @@ function JumpSystem(eqs, iv, unknowns, ps;
# this and the treatment of continuous events are the only part
# unique to JumpSystems
eqs = scalarize.(eqs)
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
ap = ArrayPartition(
MassActionJump[], ConstantRateJump[], VariableRateJump[], Equation[])
for eq in eqs
if eq isa MassActionJump
push!(ap.x[1], eq)
elseif eq isa ConstantRateJump
push!(ap.x[2], eq)
elseif eq isa VariableRateJump
push!(ap.x[3], eq)
elseif eq isa Equation
push!(ap.x[4], eq)
else
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, VariableRateJumps, or Equations.")
end
end

(continuous_events === nothing) ||
error("JumpSystems currently only support discrete events.")
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)

JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,
defaults, connector_type, disc_callbacks, parameter_dependencies,
defaults, connector_type, cont_callbacks, disc_callbacks, parameter_dependencies,
metadata, gui_metadata, checks = checks)
end

Expand Down Expand Up @@ -245,6 +251,7 @@ end
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
has_equations(js::JumpSystem) = !isempty(equations(js).x[4])

function generate_rate_function(js::JumpSystem, rate)
consts = collect_constants(rate)
Expand Down Expand Up @@ -281,7 +288,7 @@ function assemble_vrj(
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
eval_expression, eval_module)
VariableRateJump(rate, affect)
VariableRateJump(rate, affect; save_positions = vrj.save_positions)
end

function assemble_vrj_expr(js, vrj, unknowntoid)
Expand Down Expand Up @@ -390,6 +397,11 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end

if has_equations(sys) || (!isempty(continuous_events(sys)))
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
Expand Down Expand Up @@ -478,14 +490,24 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

f = (du, u, p, t) -> (du .= 0; nothing)
df = ODEFunction(f; sys, observed = observedfun)
ODEProblem(df, u0, tspan, p; kwargs...)
# forward everything to be an ODESystem but the jumps and discrete events
if has_equations(sys)
osys = ODESystem(equations(sys).x[4], get_iv(sys), unknowns(sys), parameters(sys);
observed = observed(sys), name = nameof(sys), description = description(sys),
systems = get_systems(sys), defaults = defaults(sys),
parameter_dependencies = parameter_dependencies(sys),
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
osys = complete(osys)
return ODEProblem(osys, u0map, tspan, parammap; check_length = false, kwargs...)
else
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
end

"""
Expand Down Expand Up @@ -521,8 +543,11 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
for j in eqs.x[2]]
vrjs = VariableRateJump[assemble_vrj(js, j, unknowntoid; eval_expression, eval_module)
for j in eqs.x[3]]
((prob isa DiscreteProblem) && !isempty(vrjs)) &&
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
if prob isa DiscreteProblem
if (!isempty(vrjs) || has_equations(js) || !isempty(continuous_events(js)))
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps, coupled differential equations, or continuous events.")
end
end
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)

# dep graphs are only for constant rate jumps
Expand Down
113 changes: 112 additions & 1 deletion test/jumpsystem.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra
using Random, StableRNGs
using Random, StableRNGs, NonlinearSolve
using OrdinaryDiffEq
using ModelingToolkit: t_nounits as t, D_nounits as D
MT = ModelingToolkit
Expand Down Expand Up @@ -422,3 +422,114 @@ let
@test issetequal(us, [x5])
@test issetequal(ps, [p5])
end

# PDMP test
let
seed = 1111
Random.seed!(rng, seed)
@variables X(t) Y(t)
@parameters k1 k2
vrj1 = VariableRateJump(k1 * X, [X ~ X - 1]; save_positions = (false, false))
vrj2 = VariableRateJump(k1, [Y ~ Y + 1]; save_positions = (false, false))
eqs = [D(X) ~ k2, D(Y) ~ -k2 / 10 * Y]
@named jsys = JumpSystem([vrj1, vrj2, eqs[1], eqs[2]], t, [X, Y], [k1, k2])
jsys = complete(jsys)
X0 = 0.0
Y0 = 3.0
u0 = [X => X0, Y => Y0]
k1val = 1.0
k2val = 20.0
p = [k1 => k1val, k2 => k2val]
tspan = (0.0, 10.0)
oprob = ODEProblem(jsys, u0, tspan, p)
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))

times = range(0.0, tspan[2], length = 100)
Nsims = 4000
Xv = zeros(length(times))
Yv = zeros(length(times))
for n in 1:Nsims
sol = solve(jprob, Tsit5(); saveat = times, seed)
Xv .+= sol[1, :]
Yv .+= sol[2, :]
seed += 1
end
Xv ./= Nsims
Yv ./= Nsims

Xact(t) = X0 * exp(-k1val * t) + (k2val / k1val) * (1 - exp(-k1val * t))
function Yact(t)
Y0 * exp(-k2val / 10 * t) + (k1val / (k2val / 10)) * (1 - exp(-k2val / 10 * t))
end
@test all(abs.(Xv .- Xact.(times)) .<= 0.05 .* Xv)
@test all(abs.(Yv .- Yact.(times)) .<= 0.1 .* Yv)
end

# that mixes ODEs and jump types, and then contin events
let
seed = 1111
Random.seed!(rng, seed)
@variables X(t) Y(t)
@parameters α β
vrj = VariableRateJump(β * X, [X ~ X - 1]; save_positions = (false, false))
crj = ConstantRateJump(β * Y, [Y ~ Y - 1])
maj = MassActionJump(α, [0 => 1], [Y => 1])
eqs = [D(X) ~ α * (1 + Y)]
@named jsys = JumpSystem([maj, crj, vrj, eqs[1]], t, [X, Y], [α, β])
jsys = complete(jsys)
p = (α = 6.0, β = 2.0, X₀ = 2.0, Y₀ = 1.0)
u0map = [X => p.X₀, Y => p.Y₀]
pmap = [α => p.α, β => p.β]
tspan = (0.0, 20.0)
oprob = ODEProblem(jsys, u0map, tspan, pmap)
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
times = range(0.0, tspan[2], length = 100)
Nsims = 4000
Xv = zeros(length(times))
Yv = zeros(length(times))
for n in 1:Nsims
sol = solve(jprob, Tsit5(); saveat = times, seed)
Xv .+= sol[1, :]
Yv .+= sol[2, :]
seed += 1
end
Xv ./= Nsims
Yv ./= Nsims

function Yf(t, p)
local α, β, X₀, Y₀ = p
return (α / β) + (Y₀ - α / β) * exp(-β * t)
end
function Xf(t, p)
local α, β, X₀, Y₀ = p
return (α / β) + (α^2 / β^2) + α * (Y₀ - α / β) * t * exp(-β * t) +
(X₀ - α / β - α^2 / β^2) * exp(-β * t)
end
Xact = [Xf(t, p) for t in times]
Yact = [Yf(t, p) for t in times]
@test all(abs.(Xv .- Xact) .<= 0.05 .* Xv)
@test all(abs.(Yv .- Yact) .<= 0.05 .* Yv)

function affect!(integ, u, p, ctx)
savevalues!(integ, true)
terminate!(integ)
nothing
end
cevents = [t ~ 0.2] => (affect!, [], [], [], nothing)
@named jsys = JumpSystem([maj, crj, vrj, eqs[1]], t, [X, Y], [α, β];
continuous_events = cevents)
jsys = complete(jsys)
tspan = (0.0, 200.0)
oprob = ODEProblem(jsys, u0map, tspan, pmap)
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
Xsamp = 0.0
Nsims = 4000
for n in 1:Nsims
sol = solve(jprob, Tsit5(); saveat = tspan[2], seed)
@test sol.retcode == ReturnCode.Terminated
Xsamp += sol[1, end]
seed += 1
end
Xsamp /= Nsims
@test abs(Xsamp - Xf(0.2, p) < 0.05 * Xf(0.2, p))
end
Loading