diff --git a/lib/ModelingToolkitBase/src/problems/bvproblem.jl b/lib/ModelingToolkitBase/src/problems/bvproblem.jl index 80b96879fd..bc20bc2542 100644 --- a/lib/ModelingToolkitBase/src/problems/bvproblem.jl +++ b/lib/ModelingToolkitBase/src/problems/bvproblem.jl @@ -22,6 +22,9 @@ checkbounds, time_dependent_init = false, expression, kwargs... ) + fcost = generate_bvp_cost(sys; expression = Val{false}, wrap_gfw = Val{false}, + eval_expression, eval_module, cse, checkbounds) + stidxmap = Dict([v => i for (i, v) in enumerate(dvs)]) u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) : [stidxmap[k] for (k, v) in op if haskey(stidxmap, k)] @@ -30,12 +33,18 @@ wrap_gfw = Val{true}, cse, checkbounds ) + n_controls = length(unbound_inputs(sys)) + f_prototype = n_controls > 0 ? zeros(eltype(u0), length(dvs) - n_controls) : nothing + bcresid_prototype = zeros(eltype(u0), length(u0_idxs) + length(constraints(sys))) + + bvpfn = BVPFunction{iip}(fode, fbc; cost = fcost, f_prototype, bcresid_prototype) + if (length(constraints(sys)) + length(op) > length(dvs)) @warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by op) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization." end kwargs = process_kwargs(sys; expression, kwargs...) - args = (; fode, fbc, u0, tspan, p) + args = (; bvpfn, u0, tspan, p) return maybe_codegen_scimlproblem(expression, BVProblem{iip}, args; kwargs...) end @@ -43,7 +52,6 @@ end function check_compatible_system(T::Type{BVProblem}, sys::System) check_time_dependent(sys, T) check_not_dde(sys) - check_no_cost(sys, T) check_no_jumps(sys, T) check_no_noise(sys, T) check_is_continuous(sys, T) diff --git a/lib/ModelingToolkitBase/src/systems/codegen.jl b/lib/ModelingToolkitBase/src/systems/codegen.jl index c96e5801a1..6bf8ea4e04 100644 --- a/lib/ModelingToolkitBase/src/systems/codegen.jl +++ b/lib/ModelingToolkitBase/src/systems/codegen.jl @@ -699,6 +699,56 @@ function generate_cost( ) end +""" + $(TYPEDSIGNATURES) + +Generate the cost function for a BVP [`System`](@ref). The generated function has the +signature `cost(sol, p)` where `sol` is a solution interpolation object (callable as +`sol(t)` to get state at time `t`) and `p` is the parameter object. + +# Keyword Arguments + +$GENERATE_X_KWARGS + +All other keyword arguments are forwarded to [`build_function_wrapper`](@ref). +""" +function generate_bvp_cost( + sys::System; expression = Val{true}, wrap_gfw = Val{false}, + eval_expression = false, eval_module = @__MODULE__, cse = true, + checkbounds = false, kwargs... + ) + obj = cost(sys) + _iszero(obj) && return nothing + + iv = get_iv(sys) + sts = unknowns(sys) + ps = reorder_parameters(sys) + stidxmap = Dict([v => i for (i, v) in enumerate(sts)]) + + # Substitute x(t_val) -> BVP_SOLUTION(t_val)[idx] for all state evaluations + costsubs = Dict() + get_constraint_unknown_subs!(costsubs, [obj], stidxmap, iv, BVP_SOLUTION) + obj = substitute(obj, costsubs) + + # Build function with signature (sol, p) where sol = BVP_SOLUTION + # The histfn mechanism replaces BVP_SOLUTION with the sol argument + res = build_function_wrapper( + sys, obj, ps...; + expression = Val{true}, + p_start = 1, # sol goes before parameters + p_end = length(ps), + wrap_delays = true, + histfn = (p, t) -> BVP_SOLUTION(t), + histfn_symbolic = BVP_SOLUTION, + cse, checkbounds, kwargs... + ) + + # (2, 2, is_split) means: 2 args out-of-place, 2 original args, split status + return maybe_compile_function( + expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module + ) +end + """ $(TYPEDSIGNATURES) @@ -1113,7 +1163,7 @@ end Generates a function that computes the observed value(s) `ts` in the system `sys`, while making the assumption that there are no cycles in the equations. -## Arguments +## Arguments - `sys`: The system for which to generate the function - `ts`: The symbolic observed values whose value should be computed diff --git a/lib/ModelingToolkitBase/test/bvproblem.jl b/lib/ModelingToolkitBase/test/bvproblem.jl index 44a714ebe4..e8a87db5b0 100644 --- a/lib/ModelingToolkitBase/test/bvproblem.jl +++ b/lib/ModelingToolkitBase/test/bvproblem.jl @@ -324,6 +324,15 @@ end _t = tspan[2] @test costfn(sol, prob.p, _t) ≈ (sol(0.6; idxs = x(t)) + 3)^2 + sol(0.3; idxs = x(t))^2 + bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [u0map; parammap], tspan) + sol = solve(bvp, MIRK4(), dt = 0.05) + @test SciMLBase.successful_retcode(sol) + + costfn = ModelingToolkitBase.generate_bvp_cost( + lksys; expression = Val{false}, wrap_gfw = Val{true} + ) + @test costfn(sol, bvp.p) ≈ (sol(0.6; idxs = x(t)) + 3)^2 + sol(0.3; idxs = x(t))^2 + ### With a parameter @parameters t_c costs = [y(t_c) + x(0.0), x(0.4)^2] @@ -338,4 +347,13 @@ end ) @test costfn(sol, prob.p, _t) ≈ log(sol(0.56; idxs = y(t)) + sol(0.0; idxs = x(t))) - sol(0.4; idxs = x(t))^2 + + bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [u0map; parammap], tspan) + sol = solve(bvp, MIRK4(), dt = 0.05) + @test SciMLBase.successful_retcode(sol) + + costfn = ModelingToolkitBase.generate_bvp_cost( + lksys; expression = Val{false}, wrap_gfw = Val{true} + ) + @test costfn(sol, bvp.p) ≈ log(sol(0.56; idxs = y(t)) + sol(0.0; idxs = x(t))) - sol(0.4; idxs = x(t))^2 end