Skip to content

Commit f7fb339

Browse files
SebastianM-Cclaude
andcommitted
Add codegen for BVProblems with cost functions
Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 5faf961 commit f7fb339

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

lib/ModelingToolkitBase/src/problems/bvproblem.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
checkbounds, time_dependent_init = false, expression, kwargs...
2323
)
2424

25+
fcost = generate_bvp_cost(sys; expression = Val{false}, wrap_gfw = Val{false},
26+
eval_expression, eval_module, cse, checkbounds)
27+
2528
stidxmap = Dict([v => i for (i, v) in enumerate(dvs)])
2629
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) :
2730
[stidxmap[k] for (k, v) in op if haskey(stidxmap, k)]
@@ -30,20 +33,25 @@
3033
wrap_gfw = Val{true}, cse, checkbounds
3134
)
3235

36+
n_controls = length(unbound_inputs(sys))
37+
f_prototype = n_controls > 0 ? zeros(eltype(u0), length(dvs) - n_controls) : nothing
38+
bcresid_prototype = zeros(eltype(u0), length(u0_idxs) + length(constraints(sys)))
39+
40+
bvpfn = BVPFunction{iip}(fode, fbc; cost = fcost, f_prototype, bcresid_prototype)
41+
3342
if (length(constraints(sys)) + length(op) > length(dvs))
3443
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by op) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
3544
end
3645

3746
kwargs = process_kwargs(sys; expression, kwargs...)
38-
args = (; fode, fbc, u0, tspan, p)
47+
args = (; bvpfn, u0, tspan, p)
3948

4049
return maybe_codegen_scimlproblem(expression, BVProblem{iip}, args; kwargs...)
4150
end
4251

4352
function check_compatible_system(T::Type{BVProblem}, sys::System)
4453
check_time_dependent(sys, T)
4554
check_not_dde(sys)
46-
check_no_cost(sys, T)
4755
check_no_jumps(sys, T)
4856
check_no_noise(sys, T)
4957
check_is_continuous(sys, T)

lib/ModelingToolkitBase/src/systems/codegen.jl

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,56 @@ function generate_cost(
699699
)
700700
end
701701
702+
"""
703+
$(TYPEDSIGNATURES)
704+
705+
Generate the cost function for a BVP [`System`](@ref). The generated function has the
706+
signature `cost(sol, p)` where `sol` is a solution interpolation object (callable as
707+
`sol(t)` to get state at time `t`) and `p` is the parameter object.
708+
709+
# Keyword Arguments
710+
711+
$GENERATE_X_KWARGS
712+
713+
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
714+
"""
715+
function generate_bvp_cost(
716+
sys::System; expression = Val{true}, wrap_gfw = Val{false},
717+
eval_expression = false, eval_module = @__MODULE__, cse = true,
718+
checkbounds = false, kwargs...
719+
)
720+
obj = cost(sys)
721+
_iszero(obj) && return nothing
722+
723+
iv = get_iv(sys)
724+
sts = unknowns(sys)
725+
ps = reorder_parameters(sys)
726+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
727+
728+
# Substitute x(t_val) -> BVP_SOLUTION(t_val)[idx] for all state evaluations
729+
costsubs = Dict()
730+
get_constraint_unknown_subs!(costsubs, [obj], stidxmap, iv, BVP_SOLUTION)
731+
obj = substitute(obj, costsubs)
732+
733+
# Build function with signature (sol, p) where sol = BVP_SOLUTION
734+
# The histfn mechanism replaces BVP_SOLUTION with the sol argument
735+
res = build_function_wrapper(
736+
sys, obj, ps...;
737+
expression = Val{true},
738+
p_start = 1, # sol goes before parameters
739+
p_end = length(ps),
740+
wrap_delays = true,
741+
histfn = (p, t) -> BVP_SOLUTION(t),
742+
histfn_symbolic = BVP_SOLUTION,
743+
cse, checkbounds, kwargs...
744+
)
745+
746+
# (2, 2, is_split) means: 2 args out-of-place, 2 original args, split status
747+
return maybe_compile_function(
748+
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
749+
)
750+
end
751+
702752
"""
703753
$(TYPEDSIGNATURES)
704754
@@ -1113,7 +1163,7 @@ end
11131163
11141164
Generates a function that computes the observed value(s) `ts` in the system `sys`, while making the assumption that there are no cycles in the equations.
11151165
1116-
## Arguments
1166+
## Arguments
11171167
- `sys`: The system for which to generate the function
11181168
- `ts`: The symbolic observed values whose value should be computed
11191169

0 commit comments

Comments
 (0)