Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions lib/ModelingToolkitBase/src/problems/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
checkbounds, time_dependent_init = false, expression, kwargs...
)

fcost = generate_bvp_cost(sys; expression = Val{false}, wrap_gfw = Val{false},
eval_expression, eval_module, cse, checkbounds)

stidxmap = Dict([v => i for (i, v) in enumerate(dvs)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) :
[stidxmap[k] for (k, v) in op if haskey(stidxmap, k)]
Expand All @@ -30,20 +33,25 @@
wrap_gfw = Val{true}, cse, checkbounds
)

n_controls = length(unbound_inputs(sys))
f_prototype = n_controls > 0 ? zeros(eltype(u0), length(dvs) - n_controls) : nothing
bcresid_prototype = zeros(eltype(u0), length(u0_idxs) + length(constraints(sys)))

bvpfn = BVPFunction{iip}(fode, fbc; cost = fcost, f_prototype, bcresid_prototype)

if (length(constraints(sys)) + length(op) > length(dvs))
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by op) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
end

kwargs = process_kwargs(sys; expression, kwargs...)
args = (; fode, fbc, u0, tspan, p)
args = (; bvpfn, u0, tspan, p)

return maybe_codegen_scimlproblem(expression, BVProblem{iip}, args; kwargs...)
end

function check_compatible_system(T::Type{BVProblem}, sys::System)
check_time_dependent(sys, T)
check_not_dde(sys)
check_no_cost(sys, T)
check_no_jumps(sys, T)
check_no_noise(sys, T)
check_is_continuous(sys, T)
Expand Down
52 changes: 51 additions & 1 deletion lib/ModelingToolkitBase/src/systems/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,56 @@ function generate_cost(
)
end

"""
$(TYPEDSIGNATURES)

Generate the cost function for a BVP [`System`](@ref). The generated function has the
signature `cost(sol, p)` where `sol` is a solution interpolation object (callable as
`sol(t)` to get state at time `t`) and `p` is the parameter object.

# Keyword Arguments

$GENERATE_X_KWARGS

All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
"""
function generate_bvp_cost(
sys::System; expression = Val{true}, wrap_gfw = Val{false},
eval_expression = false, eval_module = @__MODULE__, cse = true,
checkbounds = false, kwargs...
)
obj = cost(sys)
_iszero(obj) && return nothing

iv = get_iv(sys)
sts = unknowns(sys)
ps = reorder_parameters(sys)
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])

# Substitute x(t_val) -> BVP_SOLUTION(t_val)[idx] for all state evaluations
costsubs = Dict()
get_constraint_unknown_subs!(costsubs, [obj], stidxmap, iv, BVP_SOLUTION)
obj = substitute(obj, costsubs)

# Build function with signature (sol, p) where sol = BVP_SOLUTION
# The histfn mechanism replaces BVP_SOLUTION with the sol argument
res = build_function_wrapper(
sys, obj, ps...;
expression = Val{true},
p_start = 1, # sol goes before parameters
p_end = length(ps),
wrap_delays = true,
histfn = (p, t) -> BVP_SOLUTION(t),
histfn_symbolic = BVP_SOLUTION,
cse, checkbounds, kwargs...
)

# (2, 2, is_split) means: 2 args out-of-place, 2 original args, split status
return maybe_compile_function(
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
)
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -1113,7 +1163,7 @@ end

Generates a function that computes the observed value(s) `ts` in the system `sys`, while making the assumption that there are no cycles in the equations.

## Arguments
## Arguments
- `sys`: The system for which to generate the function
- `ts`: The symbolic observed values whose value should be computed

Expand Down
18 changes: 18 additions & 0 deletions lib/ModelingToolkitBase/test/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,15 @@ end
_t = tspan[2]
@test costfn(sol, prob.p, _t) (sol(0.6; idxs = x(t)) + 3)^2 + sol(0.3; idxs = x(t))^2

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [u0map; parammap], tspan)
sol = solve(bvp, MIRK4(), dt = 0.05)
@test SciMLBase.successful_retcode(sol)

costfn = ModelingToolkitBase.generate_bvp_cost(
lksys; expression = Val{false}, wrap_gfw = Val{true}
)
@test costfn(sol, bvp.p) (sol(0.6; idxs = x(t)) + 3)^2 + sol(0.3; idxs = x(t))^2

### With a parameter
@parameters t_c
costs = [y(t_c) + x(0.0), x(0.4)^2]
Expand All @@ -338,4 +347,13 @@ end
)
@test costfn(sol, prob.p, _t)
log(sol(0.56; idxs = y(t)) + sol(0.0; idxs = x(t))) - sol(0.4; idxs = x(t))^2

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [u0map; parammap], tspan)
sol = solve(bvp, MIRK4(), dt = 0.05)
@test SciMLBase.successful_retcode(sol)

costfn = ModelingToolkitBase.generate_bvp_cost(
lksys; expression = Val{false}, wrap_gfw = Val{true}
)
@test costfn(sol, bvp.p) log(sol(0.56; idxs = y(t)) + sol(0.0; idxs = x(t))) - sol(0.4; idxs = x(t))^2
end
Loading