Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] - Support ODEs as equations in JumpSystem #3181

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
20 changes: 11 additions & 9 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#################################### system operations #####################################
get_continuous_events(sys::AbstractSystem) = SymbolicContinuousCallback[]
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
function get_continuous_events(sys::AbstractSystem)
has_continuous_events(sys) || return SymbolicContinuousCallback[]
getfield(sys, :continuous_events)
end
Comment on lines +3 to +6
Copy link
Member Author

@isaacsas isaacsas Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the changes below are to allow non-AbstractODESystems to also use continuous events (JumpSystems in this case).


has_discrete_events(sys::AbstractSystem) = isdefined(sys, :discrete_events)
function get_discrete_events(sys::AbstractSystem)
Expand Down Expand Up @@ -675,7 +677,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
end
end

function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
function generate_rootfinding_callback(sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
Expand All @@ -686,7 +688,7 @@ Generate a single rootfinding callback; this happens if there is only one equati
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
"""
function generate_single_rootfinding_callback(
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
eq, cb, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
if !isequal(eq.lhs, 0)
eq = 0 ~ eq.lhs - eq.rhs
Expand Down Expand Up @@ -728,7 +730,7 @@ function generate_single_rootfinding_callback(
end

function generate_vector_rootfinding_callback(
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
cbs, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
ps = parameters(sys); rootfind = SciMLBase.RightRootFind,
reinitialization = SciMLBase.CheckInit(), kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
Expand Down Expand Up @@ -840,7 +842,7 @@ end
"""
Compile a single continuous callback affect function(s).
"""
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
function compile_affect_fn(cb, sys::AbstractTimeDependentSystem, dvs, ps, kwargs)
eq_aff = affects(cb)
eq_neg_aff = affect_negs(cb)
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
Expand All @@ -857,7 +859,7 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
function generate_rootfinding_callback(cbs, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
Expand Down Expand Up @@ -1052,12 +1054,12 @@ merge_cb(x, ::Nothing) = x
merge_cb(x, y) = CallbackSet(x, y)

function process_events(sys; callback = nothing, kwargs...)
if has_continuous_events(sys)
if has_continuous_events(sys) && !isempty(continuous_events(sys))
contin_cb = generate_rootfinding_callback(sys; kwargs...)
else
contin_cb = nothing
end
if has_discrete_events(sys)
if has_discrete_events(sys) && !isempty(discrete_events(sys))
discrete_cb = generate_discrete_callbacks(sys; kwargs...)
else
discrete_cb = nothing
Expand Down
65 changes: 45 additions & 20 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
"""
connector_type::Any
"""
A `Vector{SymbolicContinuousCallback}` that model events.
The integrator will use root finding to guarantee that it steps at each zero crossing.
"""
continuous_events::Vector{SymbolicContinuousCallback}
"""
A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
analog to `SciMLBase.DiscreteCallback` that executes an affect when a given condition is
true at the end of an integration step. Note, one must make sure to call
Expand Down Expand Up @@ -120,8 +125,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem

function JumpSystem{U}(
tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, description,
systems,
defaults, connector_type, devents, parameter_dependencies,
systems, defaults, connector_type, cevents, devents, parameter_dependencies,
metadata = nothing, gui_metadata = nothing,
complete = false, index_cache = nothing, isscheduled = false;
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
Expand All @@ -136,8 +140,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
end
new{U}(tag, ap, iv, unknowns, ps, var_to_name,
observed, name, description, systems, defaults,
connector_type, devents, parameter_dependencies, metadata, gui_metadata,
complete, index_cache, isscheduled)
connector_type, cevents, devents, parameter_dependencies, metadata,
gui_metadata, complete, index_cache, isscheduled)
end
end
function JumpSystem(tag, ap, iv, states, ps, var_to_name, args...; kwargs...)
Expand Down Expand Up @@ -194,26 +198,28 @@ function JumpSystem(eqs, iv, unknowns, ps;
# this and the treatment of continuous events are the only part
# unique to JumpSystems
eqs = scalarize.(eqs)
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
ap = ArrayPartition(
MassActionJump[], ConstantRateJump[], VariableRateJump[], Equation[])
for eq in eqs
if eq isa MassActionJump
push!(ap.x[1], eq)
elseif eq isa ConstantRateJump
push!(ap.x[2], eq)
elseif eq isa VariableRateJump
push!(ap.x[3], eq)
elseif eq isa Equation
push!(ap.x[4], eq)
else
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, VariableRateJumps, or Equations.")
end
end

(continuous_events === nothing) ||
error("JumpSystems currently only support discrete events.")
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)

JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,
defaults, connector_type, disc_callbacks, parameter_dependencies,
defaults, connector_type, cont_callbacks, disc_callbacks, parameter_dependencies,
metadata, gui_metadata, checks = checks)
end

Expand Down Expand Up @@ -245,6 +251,7 @@ end
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
has_equations(js::JumpSystem) = !isempty(equations(js).x[4])

function generate_rate_function(js::JumpSystem, rate)
consts = collect_constants(rate)
Expand Down Expand Up @@ -281,7 +288,7 @@ function assemble_vrj(
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
eval_expression, eval_module)
VariableRateJump(rate, affect)
VariableRateJump(rate, affect; save_positions = vrj.save_positions)
end

function assemble_vrj_expr(js, vrj, unknowntoid)
Expand Down Expand Up @@ -390,6 +397,11 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end

if has_equations(sys) || (!isempty(continuous_events(sys)))
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
Expand Down Expand Up @@ -478,14 +490,24 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

f = (du, u, p, t) -> (du .= 0; nothing)
df = ODEFunction(f; sys, observed = observedfun)
ODEProblem(df, u0, tspan, p; kwargs...)
# forward everything to be an ODESystem but the jumps and discrete events
if has_equations(sys)
osys = ODESystem(equations(sys).x[4], get_iv(sys), unknowns(sys), parameters(sys);
observed = observed(sys), name = nameof(sys), description = description(sys),
systems = get_systems(sys), defaults = defaults(sys),
parameter_dependencies = parameter_dependencies(sys),
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
osys = complete(osys)
return ODEProblem(osys, u0map, tspan, parammap; check_length = false, kwargs...)
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
else
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
end

"""
Expand Down Expand Up @@ -521,8 +543,11 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
for j in eqs.x[2]]
vrjs = VariableRateJump[assemble_vrj(js, j, unknowntoid; eval_expression, eval_module)
for j in eqs.x[3]]
((prob isa DiscreteProblem) && !isempty(vrjs)) &&
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
if prob isa DiscreteProblem
if (!isempty(vrjs) || has_equations(js) || !isempty(continuous_events(js)))
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps, coupled differential equations, or continuous events.")
end
end
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)

# dep graphs are only for constant rate jumps
Expand Down
36 changes: 35 additions & 1 deletion test/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ end
# basic VariableRateJump test
let
N = 1000 # number of simulations for testing solve accuracy
Random.seed!(rng, 1111)
Random.seed!(rng, 1111)
@variables A(t) B(t) C(t)
@parameters k
vrj = VariableRateJump(k * (sin(t) + 1), [A ~ A + 1, C ~ C + 2])
Expand Down Expand Up @@ -422,3 +422,37 @@ let
@test issetequal(us, [x5])
@test issetequal(ps, [p5])
end

# PDMP test
let
@variables X(t) Y(t)
@parameters k1 k2
vrj1 = VariableRateJump(k1 * X, [X ~ X - 1]; save_positions = (false, false))
vrj2 = VariableRateJump(k1, [Y ~ Y + 1]; save_positions = (false, false))
eqs = [D(X) ~ k2, D(Y) ~ -k2/100*Y]
@named jsys = JumpSystem([vrj1, vrj2, eqs[1], eqs[2]], t, [X, Y], [k1, k2])
jsys = complete(jsys)
X0 = 0.0; Y0 = 0.0
u0 = [X => X0, Y => Y0]
k1val = 1.0; k2val = 20.0
p = [k1 => k1val, k2 => k2val]
tspan = (0.0, 10.0)
oprob = ODEProblem(jsys, u0, tspan, p)
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))

times = range(0.0, tspan[2], length = 100)
Nsims = 4000
Xv = zeros(length(times))
Yv = copy(Xv)
for n in 1:Nsims
sol = solve(jprob, Tsit5(); saveat = times)
Xv .+= sol[1,:]
Yv .+= sol[2,:]
end
Xv ./= Nsims; Yv ./= Nsims;

Xact(t) = X0 * exp(-k1val * t) + (k2val / k1val) * (1 - exp(-k1val * t))
Yact(t) = Y0 * exp(-k2val/100 * t) + (k1val / (k2val/100)) * (1 - exp(-k2val/100 * t))
@test all(abs.(Xv .- Xact.(times)) .<= 0.05 .* Xv)
@test all(abs.(Yv .- Yact.(times)) .<= 0.05 .* Yv)
end
Loading