Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: major cleanup of *Problem construction #3121

Merged
merged 7 commits into from
Oct 16, 2024
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ include("systems/abstractsystem.jl")
include("systems/model_parsing.jl")
include("systems/connectors.jl")
include("systems/callbacks.jl")
include("systems/problem_utils.jl")

include("systems/nonlinear/nonlinearsystem.jl")
include("systems/diffeqs/odesystem.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2914,7 +2914,7 @@ function Base.eltype(::Type{<:TreeIterator{ModelingToolkit.AbstractSystem}})
end

function check_array_equations_unknowns(eqs, dvs)
if any(eq -> Symbolics.isarraysymbolic(eq.lhs), eqs)
if any(eq -> eq isa Equation && Symbolics.isarraysymbolic(eq.lhs), eqs)
throw(ArgumentError("The system has array equations. Call `structural_simplify` to handle such equations or scalarize them manually."))
end
if any(x -> Symbolics.isarraysymbolic(x), dvs)
Expand Down
313 changes: 9 additions & 304 deletions src/systems/diffeqs/abstractodesystem.jl

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`")
end
f, u0, p = process_DEProblem(
f, u0, p = process_SciMLProblem(
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
kwargs...)
cbs = process_events(sys; callback, kwargs...)
Expand Down Expand Up @@ -745,7 +745,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblemExpr`")
end
f, u0, p = process_DEProblem(SDEFunctionExpr{iip}, sys, u0map, parammap; check_length,
f, u0, p = process_SciMLProblem(
SDEFunctionExpr{iip}, sys, u0map, parammap; check_length,
kwargs...)
linenumbers = get(kwargs, :linenumbers, true)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
Expand Down
60 changes: 16 additions & 44 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,55 +236,25 @@ function generate_function(
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
end

function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
linenumbers = true, parallel = SerialForm(),
use_union = false,
tofloat = !use_union,
eval_expression = false, eval_module = @__MODULE__,
kwargs...)
function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
iv = get_iv(sys)
eqs = equations(sys)
dvs = unknowns(sys)
ps = parameters(sys)

if eltype(u0map) <: Number
u0map = unknowns(sys) .=> vec(u0map)
end
if u0map === nothing || isempty(u0map)
u0map = Dict()
end

trueu0map = Dict()
for (k, v) in u0map
k = unwrap(k)
updated = AnyDict()
for k in collect(keys(u0map))
v = u0map[k]
if !((op = operation(k)) isa Shift)
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
end
trueu0map[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
end
defs = ModelingToolkit.get_defaults(sys)
for var in dvs
if (op = operation(var)) isa Shift && !haskey(trueu0map, var)
root = arguments(var)[1]
haskey(defs, root) || error("Initial condition for $var not provided.")
trueu0map[var] = defs[root]
end
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
end
if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, trueu0map, parammap)
p = MTKParameters(sys, parammap, trueu0map)
else
u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union)
for var in unknowns(sys)
op = operation(var)
op isa Shift || continue
haskey(updated, var) && continue
root = first(arguments(var))
haskey(defs, root) || error("Initial condition for $var not provided.")
updated[var] = defs[root]
end

check_eqs_u0(eqs, dvs, u0; kwargs...)

f = constructor(sys, dvs, ps, u0;
linenumbers = linenumbers, parallel = parallel,
syms = Symbol.(dvs), paramsyms = Symbol.(ps),
eval_expression = eval_expression, eval_module = eval_module,
kwargs...)
return f, u0, p
return updated
end

"""
Expand All @@ -307,7 +277,9 @@ function SciMLBase.DiscreteProblem(
eqs = equations(sys)
iv = get_iv(sys)

f, u0, p = process_DiscreteProblem(
u0map = to_varmap(u0map, dvs)
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
f, u0, p = process_SciMLProblem(
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
u0 = f(u0, p, tspan[1])
DiscreteProblem(f, u0, tspan, p; kwargs...)
Expand Down
41 changes: 6 additions & 35 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end
dvs = unknowns(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
Expand Down Expand Up @@ -399,16 +387,9 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
end
dvs = unknowns(sys)
ps = parameters(sys)
defs = defaults(sys)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
# identity function to make syms works
quote
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
Expand Down Expand Up @@ -454,19 +435,9 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end
dvs = unknowns(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

Expand Down
9 changes: 8 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function generate_initializesystem(sys::ODESystem;
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
for (y, x) in u0map
function process_u0map_with_dummysubs(y, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why create a closure?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It reduces code duplication, since array symbolics need to be treated as arrays of scalars below

y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
Expand All @@ -53,6 +53,13 @@ function generate_initializesystem(sys::ODESystem;
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
end
for (y, x) in u0map
if Symbolics.isarraysymbolic(y)
process_u0map_with_dummysubs.(collect(y), collect(x))
else
process_u0map_with_dummysubs(y, x)
end
end
end

# 2) process other variables
Expand Down
40 changes: 5 additions & 35 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ function SciMLBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
end

function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing, p = nothing;
ps = parameters(sys), u0 = nothing; p = nothing,
version = nothing,
jac = false,
eval_expression = false,
Expand Down Expand Up @@ -408,36 +408,6 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
!linenumbers ? Base.remove_linenums!(ex) : ex
end

function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, parammap;
version = nothing,
jac = false,
checkbounds = false, sparse = false,
simplify = false,
linenumbers = true, parallel = SerialForm(),
eval_expression = false,
eval_module = @__MODULE__,
use_union = false,
tofloat = !use_union,
kwargs...)
eqs = equations(sys)
dvs = unknowns(sys)
ps = parameters(sys)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, u0map, parammap)
check_eqs_u0(eqs, dvs, u0; kwargs...)
p = MTKParameters(sys, parammap, u0map)
else
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
check_eqs_u0(eqs, dvs, u0; kwargs...)
end

f = constructor(sys, dvs, ps, u0, p; jac = jac, checkbounds = checkbounds,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression, eval_module = eval_module,
kwargs...)
return f, u0, p
end

"""
```julia
DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
Expand All @@ -461,7 +431,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
end
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
pt = something(get_metadata(sys), StandardNonlinearProblem())
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
Expand Down Expand Up @@ -490,7 +460,7 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearLeastSquaresProblem`")
end
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
pt = something(get_metadata(sys), StandardNonlinearProblem())
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
Expand Down Expand Up @@ -523,7 +493,7 @@ function NonlinearProblemExpr{iip}(sys::NonlinearSystem, u0map,
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblemExpr`")
end
f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
check_length, kwargs...)
linenumbers = get(kwargs, :linenumbers, true)

Expand Down Expand Up @@ -563,7 +533,7 @@ function NonlinearLeastSquaresProblemExpr{iip}(sys::NonlinearSystem, u0map,
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblemExpr`")
end
f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
check_length, kwargs...)
linenumbers = get(kwargs, :linenumbers, true)

Expand Down
Loading
Loading