Skip to content

Commit 9476141

Browse files
feat: implement OptimizationProblem, OptimizationFunction for System
1 parent c55c6c4 commit 9476141

File tree

3 files changed

+274
-0
lines changed

3 files changed

+274
-0
lines changed

Diff for: src/problems/compatibility.jl

+18
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ function check_no_cost(sys::System, T)
6161
end
6262
end
6363

64+
function check_has_cost(sys::System, T)
65+
cost = ModelingToolkit.cost(sys)
66+
if _iszero(cost)
67+
throw(SystemCompatibilityError("""
68+
A system without cost cannot be used to construct a `$T`.
69+
"""))
70+
end
71+
end
72+
6473
function check_no_constraints(sys::System, T)
6574
if !isempty(constraints(sys))
6675
throw(SystemCompatibilityError("""
@@ -140,3 +149,12 @@ function check_is_implicit(sys::System, T, altT)
140149
"""))
141150
end
142151
end
152+
153+
function check_no_equations(sys::System, T)
154+
if !isempty(equations(sys))
155+
throw(SystemCompatibilityError("""
156+
A system with equations cannot be used to construct a `$T`. Consider turning the
157+
equations into constraints instead.
158+
"""))
159+
end
160+
end

Diff for: src/problems/optimizationproblem.jl

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
function SciMLBase.OptimizationFunction(sys::System, args...; kwargs...)
2+
return OptimizationFunction{true}(sys, args...; kwargs...)
3+
end
4+
5+
function SciMLBase.OptimizationFunction{iip}(sys::System,
6+
_d = nothing, u0 = nothing, p = nothing; grad = false, hess = false,
7+
sparse = false, cons_j = false, cons_h = false, cons_sparse = false,
8+
linenumbers = true, eval_expression = false, eval_module = @__MODULE__,
9+
simplify = false, check_compatibility = true, checkbounds = false, cse = true,
10+
kwargs...) where {iip}
11+
check_complete(sys, OptimizationFunction)
12+
check_compatibility && check_compatible_system(OptimizationFunction, sys)
13+
dvs = unknowns(sys)
14+
ps = parameters(sys)
15+
cstr = constraints(sys)
16+
17+
f = generate_cost(sys; expression = Val{false}, eval_expression,
18+
eval_module, checkbounds, cse, kwargs...)
19+
20+
if grad
21+
_grad = generate_cost_gradient(sys; expression = Val{false}, eval_expression,
22+
eval_module, checkbounds, cse, kwargs...)
23+
else
24+
_grad = nothing
25+
end
26+
if hess
27+
_hess, hess_prototype = generate_cost_hessian(
28+
sys; expression = Val{false}, eval_expression, eval_module,
29+
checkbounds, cse, sparse, simplify, return_sparsity = true, kwargs...)
30+
else
31+
_hess = hess_prototype = nothing
32+
end
33+
if isempty(cstr)
34+
cons = lcons = ucons = _cons_j = cons_jac_prototype = _cons_h = nothing
35+
cons_hess_prototype = cons_expr = nothing
36+
else
37+
cons = generate_cons(sys; expression = Val{false}, eval_expression,
38+
eval_module, checkbounds, cse, kwargs...)
39+
if cons_j
40+
_cons_j, cons_jac_prototype = generate_constraint_jacobian(
41+
sys; expression = Val{false}, eval_expression, eval_module, checkbounds,
42+
cse, simplify, sparse = cons_sparse, return_sparsity = true, kwargs...)
43+
else
44+
_cons_j = cons_jac_prototype = nothing
45+
end
46+
if cons_h
47+
_cons_h, cons_hess_prototype = generate_constraint_hessian(
48+
sys; expression = Val{false}, eval_expression, eval_module, checkbounds,
49+
cse, simplify, sparse = cons_sparse, return_sparsity = true, kwargs...)
50+
else
51+
_cons_h = cons_hess_prototype = nothing
52+
end
53+
cons_expr = toexpr.(subs_constants(cstr))
54+
end
55+
56+
obj_expr = subs_constants(cost(sys))
57+
58+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds, cse)
59+
60+
return OptimizationFunction{iip}(f, SciMLBase.NoAD();
61+
sys = sys,
62+
grad = _grad,
63+
hess = _hess,
64+
hess_prototype = hess_prototype,
65+
cons = cons,
66+
cons_j = _cons_j,
67+
cons_jac_prototype = cons_jac_prototype,
68+
cons_h = _cons_h,
69+
cons_hess_prototype = cons_hess_prototype,
70+
cons_expr = cons_expr,
71+
expr = obj_expr,
72+
observed = observedfun)
73+
end
74+
75+
function SciMLBase.OptimizationProblem(sys::System, args...; kwargs...)
76+
return OptimizationProblem{true}(sys, args...; kwargs...)
77+
end
78+
79+
function SciMLBase.OptimizationProblem{iip}(
80+
sys::System, u0map, parammap = SciMLBase.NullParameters(); lb = nothing, ub = nothing,
81+
check_compatibility = true, kwargs...) where {iip}
82+
check_complete(sys, OptimizationProblem)
83+
check_compatibility && check_compatible_system(OptimizationProblem, sys)
84+
85+
f, u0, p = process_SciMLProblem(OptimizationFunction{iip}, sys, u0map, parammap;
86+
check_compatibility, tofloat = false, check_length = false, kwargs...)
87+
88+
dvs = unknowns(sys)
89+
int = symtype.(unwrap.(dvs)) .<: Integer
90+
if lb === nothing && ub === nothing
91+
lb = first.(getbounds.(dvs))
92+
ub = last.(getbounds.(dvs))
93+
isboolean = symtype.(unwrap.(dvs)) .<: Bool
94+
lb[isboolean] .= 0
95+
ub[isboolean] .= 1
96+
else
97+
xor(isnothing(lb), isnothing(ub)) &&
98+
throw(ArgumentError("Expected both `lb` and `ub` to be supplied"))
99+
!isnothing(lb) && length(lb) != length(dvs) &&
100+
throw(ArgumentError("Expected both `lb` to be of the same length as the vector of optimization variables"))
101+
!isnothing(ub) && length(ub) != length(dvs) &&
102+
throw(ArgumentError("Expected both `ub` to be of the same length as the vector of optimization variables"))
103+
end
104+
105+
ps = parameters(sys)
106+
defs = merge(defaults(sys), to_varmap(parammap, ps), to_varmap(u0map, dvs))
107+
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false)
108+
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false)
109+
110+
if !isnothing(lb) && all(lb .== -Inf) && !isnothing(ub) && all(ub .== Inf)
111+
lb = nothing
112+
ub = nothing
113+
end
114+
115+
cstr = constraints(sys)
116+
if isempty(cstr)
117+
lcons = ucons = nothing
118+
else
119+
lcons = fill(-Inf, length(cstr))
120+
ucons = zeros(length(cstr))
121+
lcons[findall(Base.Fix2(isa, Equation), cstr)] .= 0.0
122+
end
123+
124+
kwargs = process_kwargs(sys; kwargs...)
125+
# Call `remake` so it runs initialization if it is trivial
126+
return remake(OptimizationProblem{iip}(f, u0, p; lb, ub, int, lcons, ucons, kwargs...))
127+
end
128+
129+
function check_compatible_system(
130+
T::Union{Type{OptimizationFunction}, Type{OptimizationProblem}}, sys::System)
131+
check_time_independent(sys, T)
132+
check_not_dde(sys)
133+
check_has_cost(sys, T)
134+
check_no_jumps(sys, T)
135+
check_no_noise(sys, T)
136+
check_no_equations(sys, T)
137+
end

Diff for: src/systems/codegen.jl

+119
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,122 @@ function generate_boundary_conditions(sys::System, u0, u0_idxs, t0; expression =
369369
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
370370
return GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
371371
end
372+
373+
function generate_cost(sys::System; expression = Val{true}, eval_expression = false,
374+
eval_module = @__MODULE__, kwargs...)
375+
obj = cost(sys)
376+
dvs = unknowns(sys)
377+
ps = reorder_parameters(sys)
378+
res = build_function_wrapper(sys, obj, dvs, ps...; expression = Val{true}, kwargs...)
379+
if expression == Val{true}
380+
return res
381+
end
382+
f_oop = eval_or_rgf(res; eval_expression, eval_module)
383+
return GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, nothing)
384+
end
385+
386+
function generate_cost_gradient(
387+
sys::System; expression = Val{true}, eval_expression = false,
388+
eval_module = @__MODULE__, simplify = false, kwargs...)
389+
obj = cost(sys)
390+
dvs = unknowns(sys)
391+
ps = reorder_parameters(sys)
392+
exprs = Symbolics.gradient(obj, dvs; simplify)
393+
res = build_function_wrapper(sys, exprs, dvs, ps...; expression = Val{true}, kwargs...)
394+
if expression == Val{true}
395+
return res
396+
end
397+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
398+
return GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
399+
end
400+
401+
function generate_cost_hessian(
402+
sys::System; expression = Val{true}, eval_expression = false,
403+
eval_module = @__MODULE__, simplify = false,
404+
sparse = false, return_sparsity = false, kwargs...)
405+
obj = cost(sys)
406+
dvs = unknowns(sys)
407+
ps = reorder_parameters(sys)
408+
sparsity = nothing
409+
if sparse
410+
exprs = Symbolics.sparsehessian(obj, dvs; simplify)::AbstractSparseArray
411+
sparsity = similar(exprs, Float64)
412+
else
413+
exprs = Symbolics.hessian(obj, dvs; simplify)
414+
end
415+
res = build_function_wrapper(sys, exprs, dvs, ps...; expression = Val{true}, kwargs...)
416+
if expression == Val{true}
417+
return return_sparsity ? (res, sparsity) : res
418+
end
419+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
420+
fn = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
421+
return return_sparsity ? (fn, sparsity) : fn
422+
end
423+
424+
function canonical_constraints(sys::System)
425+
return map(constraints(sys)) do cstr
426+
Symbolics.canonical_form(cstr).lhs
427+
end
428+
end
429+
430+
function generate_cons(sys::System; expression = Val{true}, eval_expression = false,
431+
eval_module = @__MODULE__, kwargs...)
432+
cons = canonical_constraints(sys)
433+
dvs = unknowns(sys)
434+
ps = reorder_parameters(sys)
435+
res = build_function_wrapper(sys, cons, dvs, ps...; expression = Val{true}, kwargs...)
436+
if expression == Val{true}
437+
return res
438+
end
439+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
440+
fn = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
441+
return fn
442+
end
443+
444+
function generate_constraint_jacobian(
445+
sys::System; expression = Val{true}, eval_expression = false,
446+
eval_module = @__MODULE__, return_sparsity = false,
447+
simplify = false, sparse = false, kwargs...)
448+
cons = canonical_constraints(sys)
449+
dvs = unknowns(sys)
450+
ps = reorder_parameters(sys)
451+
sparsity = nothing
452+
if sparse
453+
jac = Symbolics.sparsejacobian(cons, dvs; simplify)::AbstractSparseArray
454+
sparsity = similar(jac, Float64)
455+
else
456+
jac = Symbolics.jacobian(cons, dvs; simplify)
457+
end
458+
res = build_function_wrapper(sys, jac, dvs, ps...; expression = Val{true}, kwargs...)
459+
if expression == Val{true}
460+
return return_sparsity ? (res, sparsity) : res
461+
end
462+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
463+
fn = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
464+
return return_sparsity ? (fn, sparsity) : fn
465+
end
466+
467+
function generate_constraint_hessian(
468+
sys::System; expression = Val{true}, eval_expression = false,
469+
eval_module = @__MODULE__, return_sparsity = false,
470+
simplify = false, sparse = false, kwargs...)
471+
cons = canonical_constraints(sys)
472+
dvs = unknowns(sys)
473+
ps = reorder_parameters(sys)
474+
sparsity = nothing
475+
if sparse
476+
hess = map(cons) do cstr
477+
Symbolics.sparsehessian(cstr, dvs; simplify)::AbstractSparseArray
478+
end
479+
sparsity = similar.(hess, Float64)
480+
else
481+
hess = [Symbolics.hessian(cstr, dvs; simplify) for cstr in cons]
482+
end
483+
res = build_function_wrapper(sys, hess, dvs, ps...; expression = Val{true}, kwargs...)
484+
if expression == Val{true}
485+
return return_sparsity ? (res, sparsity) : res
486+
end
487+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
488+
fn = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
489+
return return_sparsity ? (fn, sparsity) : fn
490+
end

0 commit comments

Comments
 (0)