Skip to content
4 changes: 3 additions & 1 deletion lib/BoundaryValueDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqCore"
uuid = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
authors = ["Qingyu Qu <[email protected]>"]
version = "1.11.1"
authors = ["Qingyu Qu <[email protected]>"]

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -20,6 +20,7 @@ PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
Expand All @@ -45,6 +46,7 @@ PreallocationTools = "0.4.24"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2"
SciMLBase = "2.130.0"
SciMLStructures = "1.7.0"
Setfield = "1"
SparseArrays = "1.10"
SparseConnectivityTracer = "0.6.13, 1"
Expand Down
1 change: 1 addition & 0 deletions lib/BoundaryValueDiffEqCore/src/BoundaryValueDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Setfield: @set!, @set
using SparseArrays: sparse
using SparseConnectivityTracer: SparseConnectivityTracer, TracerLocalSparsityDetector
using SparseMatrixColorings: GreedyColoringAlgorithm
using SciMLStructures: SciMLStructures

@reexport using NonlinearSolveFirstOrder, SciMLBase

Expand Down
39 changes: 32 additions & 7 deletions lib/BoundaryValueDiffEqCore/src/internal_problems.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,37 @@
@inline __default_cost(::Nothing) = (x, p) -> 0.0
@inline __default_cost(f) = f
@inline __build_cost(::Nothing, _, _, _) = (x, p) -> 0.0
@inline function __build_cost(fun, cache, mesh, M)
cost_fun = function (u, p)
# simple recursive unflatten
newy = [u[i:(i + M - 1)] for i in 1:M:(length(u) - M + 1)]
eval_sol = EvalSol(newy, mesh, cache)
return fun(eval_sol, p)
@inline __build_cost(::Nothing, cache, mesh, M; kwargs...) = (x, p) -> 0.0
@inline function __build_cost(fun, cache, mesh, M; fit_parameters = false, p = nothing)
if fit_parameters && SciMLStructures.isscimlstructure(p)
# When fit_parameters=true, the state vector is augmented with tunable params
# Extract them and use SciMLStructures.replace to update p for the cost function
tunable_part, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
l_params = length(tunable_part)
length_u = M - l_params
cost_fun = @views function (u, p_orig)
newy = eachcol(reshape(u, M, :))
# Extract tunable params from first mesh point (same at all points)
params_from_u = u[(length_u + 1):M]
new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p_orig, params_from_u)
eval_sol = EvalSol(newy, mesh, cache)
return fun(eval_sol, new_p)
end
elseif fit_parameters && !isnothing(p)
length_u = M - length(p)
cost_fun = @views function (u, p)
# When fit_parameters=true, the state vector is augmented with tunable params
newy = eachcol(reshape(u, M, :))
params_from_u = u[(length_u + 1):M]
eval_sol = EvalSol(newy, mesh, cache)
return fun(eval_sol, params_from_u)
end
else
cost_fun = @views function (u, p)
# simple recursive unflatten
newy = eachcol(reshape(u, M, :))
eval_sol = EvalSol(newy, mesh, cache)
return fun(eval_sol, p)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original unflatten operation is really allocating a lot, maybe we can change to use iterators to construct immediate solution from one-dimensional state vectors, for example reshape state vectors and use eachcol iterator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to that in d5f757c.

end
return cost_fun
end
Expand Down
17 changes: 13 additions & 4 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ function __extract_problem_details(
if fit_parameters
prob.p isa SciMLBase.NullParameters &&
throw(ArgumentError("`fit_parameters` is true but `prob.p` is not set."))
new_u = vcat(u0, prob.p)
new_u = vcat(u0, __tunable_part(prob.p))
return Val(false), eltype(new_u), length(new_u), Int(cld(t₁ - t₀, dt)), new_u
end
return Val(false), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), prob.u0
Expand All @@ -319,7 +319,7 @@ function __extract_problem_details(
if fit_parameters
prob.p isa SciMLBase.NullParameters &&
throw(ArgumentError("`fit_parameters` is true but `prob.p` is not set."))
new_u = vcat(_u0, prob.p)
new_u = vcat(_u0, __tunable_part(prob.p))
return Val(false), eltype(new_u), length(new_u), Int(cld(t₁ - t₀, dt)), new_u
end
return Val(true), eltype(_u0), length(_u0), (length(_t) - 1), _u0
Expand All @@ -330,7 +330,7 @@ function __initial_guess(f::F, p::P, t::T; fit_parameters = false) where {F, P,
if fit_parameters
p isa SciMLBase.NullParameters &&
throw(ArgumentError("`fit_parameters` is true but `prob.p` is not set."))
return vcat(f(p, t), p)
return vcat(f(p, t), __tunable_part(p))
end
return f(p, t)
elseif hasmethod(f, Tuple{T})
Expand All @@ -343,14 +343,23 @@ function __initial_guess(f::F, p::P, t::T; fit_parameters = false) where {F, P,
if fit_parameters
p isa SciMLBase.NullParameters &&
throw(ArgumentError("`fit_parameters` is true but `prob.p` is not set."))
return vcat(f(t), p)
return vcat(f(t), __tunable_part(p))
end
return f(t)
else
throw(ArgumentError("`initial_guess` must be a function of the form `f(p, t)`"))
end
end

function __tunable_part(p)
if SciMLStructures.isscimlstructure(p)
part, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
return part
else
p
end
end

function __get_bcresid_prototype(prob::BVProblem, u)
return __get_bcresid_prototype(prob.problem_type, prob, u)
end
Expand Down
6 changes: 4 additions & 2 deletions lib/BoundaryValueDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[sources.BoundaryValueDiffEqCore]
path = "../BoundaryValueDiffEqCore"
[sources]
BoundaryValueDiffEqCore = {path = "../BoundaryValueDiffEqCore"}

[compat]
ADTypes = "1.14"
Expand Down Expand Up @@ -55,6 +56,7 @@ ReTestItems = "1.23.1"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2"
SciMLBase = "2.130.0"
SciMLStructures = "1.7.0"
Setfield = "1.1.1"
SparseArrays = "1.10"
StaticArrays = "1.9.8"
Expand Down
4 changes: 3 additions & 1 deletion lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __split_kwargs, _sparse_like,
get_dense_ad, __internal_optimization_problem,
__internal_solve, __default_sparsity_detector, __build_cost
__internal_solve, __default_sparsity_detector, __build_cost,
__tunable_part

using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
Expand All @@ -44,6 +45,7 @@ using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, StandardBVProblem, __so
_unwrap_val
using Setfield: @set!, @set
using SparseArrays: sparse
using SciMLStructures: SciMLStructures

const DI = DifferentiationInterface

Expand Down
Loading
Loading