Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,19 @@ function JumpSystem(eqs, iv, unknowns, ps;
# this and the treatment of continuous events are the only part
# unique to JumpSystems
eqs = scalarize.(eqs)
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
ap = ArrayPartition(
MassActionJump[], ConstantRateJump[], VariableRateJump[], Equation[])
for eq in eqs
if eq isa MassActionJump
push!(ap.x[1], eq)
elseif eq isa ConstantRateJump
push!(ap.x[2], eq)
elseif eq isa VariableRateJump
push!(ap.x[3], eq)
elseif eq isa Equation
push!(ap.x[4], eq)
else
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, VariableRateJumps, or Equations.")
end
end

Expand Down Expand Up @@ -245,6 +248,7 @@ end
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
has_equations(js::JumpSystem) = !isempty(equations(js).x[4])

function generate_rate_function(js::JumpSystem, rate)
consts = collect_constants(rate)
Expand Down Expand Up @@ -390,6 +394,11 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end

if has_equations(sys)
error("The passed in JumpSystem contains `Equations`, please use a problem type that supports equations such as such as ODEProblem.")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
Expand Down Expand Up @@ -473,19 +482,31 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
use_union = false,
eval_expression = false,
eval_module = @__MODULE__,
check_length = false,
kwargs...)
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

f = (du, u, p, t) -> (du .= 0; nothing)
df = ODEFunction(f; sys, observed = observedfun)
ODEProblem(df, u0, tspan, p; kwargs...)
# forward everything to be an ODESystem but the jumps
if has_equations(sys)
osys = ODESystem(equations(sys).x[4], get_iv(sys), unknowns(sys), parameters(sys);
observed = observed(sys), name = nameof(sys), description = description(sys),
systems = get_systems(sys), defaults = defaults(sys),
discrete_events = discrete_events(sys),
parameter_dependencies = parameter_dependencies(sys),
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
osys = complete(osys)
return ODEProblem(osys, u0map, tspan, parammap; check_length, kwargs...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed like the best way to handle generating the underlying ODEProblem, but happy to consider alternative approaches if anyone has feedback.

else
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; check_length, kwargs...)
end
end

"""
Expand Down Expand Up @@ -521,8 +542,8 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
for j in eqs.x[2]]
vrjs = VariableRateJump[assemble_vrj(js, j, unknowntoid; eval_expression, eval_module)
for j in eqs.x[3]]
((prob isa DiscreteProblem) && !isempty(vrjs)) &&
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
((prob isa DiscreteProblem) && (!isempty(vrjs) || has_equations(js))) &&
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps and/or coupled differential equations.")
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)

# dep graphs are only for constant rate jumps
Expand Down
22 changes: 22 additions & 0 deletions test/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,25 @@ let
@test issetequal(us, [x5])
@test issetequal(ps, [p5])
end

# PDMP test
let
@variables X(t) Y(t)
@parameters k1 k2
rate1 = k1 * X
affect1! = [X ~ X - 1]
rate2 = k1
affect2! = [Y ~ Y + 1]
eqs = [D(X) ~ k2, D(Y) ~ -k2/100*Y]
vrj1 = VariableRateJump(rate1, affect1!)
vrj2 = VariableRateJump(rate2, affect2!)
@named jsys = JumpSystem([vrj1, vrj2, eqs[1], eqs[2]], t, [X, Y], [k1, k2])
jsys = complete(jsys)
u0 = [X => 0.0, Y => 0.0]
p = [k1 => 1.0, k2 => 20.0]
tspan = (0.0, 20.0)
oprob = ODEProblem(jsys, u0, tspan, p)
jprob = JumpProblem(jsys, oprob; rng)
sol = solve(jprob, Tsit5())

end