Skip to content

Commit bd1d2d3

Browse files
SebastianM-Cclaude
andcommitted
Update the __build_cost to take into account parameter updates when using fit_parameters
Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent dedabd6 commit bd1d2d3

File tree

3 files changed

+71
-55
lines changed

3 files changed

+71
-55
lines changed

lib/BoundaryValueDiffEqCore/src/internal_problems.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
@inline __default_cost(::Nothing) = (x, p) -> 0.0
22
@inline __default_cost(f) = f
3-
@inline __build_cost(::Nothing, _, _, _) = (x, p) -> 0.0
4-
@inline function __build_cost(fun, cache, mesh, M)
5-
cost_fun = function (u, p)
6-
# simple recursive unflatten
7-
newy = [u[i:(i + M - 1)] for i in 1:M:(length(u) - M + 1)]
8-
eval_sol = EvalSol(newy, mesh, cache)
9-
return fun(eval_sol, p)
3+
@inline __build_cost(::Nothing, cache, mesh, M; kwargs...) = (x, p) -> 0.0
4+
@inline function __build_cost(fun, cache, mesh, M; fit_parameters = false, p = nothing)
5+
if fit_parameters && p !== nothing
6+
# When fit_parameters=true, the state vector is augmented with tunable params
7+
# Extract them and use SciMLStructures.replace to update p for the cost function
8+
tunable_part, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
9+
l_params = length(tunable_part)
10+
length_u = M - l_params
11+
cost_fun = @views function (u, p_orig)
12+
newy = [u[i:(i + M - 1)] for i in 1:M:(length(u) - M + 1)]
13+
# Extract tunable params from first mesh point (same at all points)
14+
params_from_u = u[(length_u + 1):M]
15+
new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p_orig, params_from_u)
16+
eval_sol = EvalSol(newy, mesh, cache)
17+
return fun(eval_sol, new_p)
18+
end
19+
else
20+
cost_fun = @views function (u, p)
21+
# simple recursive unflatten
22+
newy = [u[i:(i + M - 1)] for i in 1:M:(length(u) - M + 1)]
23+
eval_sol = EvalSol(newy, mesh, cache)
24+
return fun(eval_sol, p)
25+
end
1026
end
1127
return cost_fun
1228
end

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,8 @@ function __construct_problem(
645645
end
646646

647647
function __construct_problem(
648-
cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C,
649-
loss::LF, ::StandardBVProblem, ::Val{true}
650-
) where {iip, BC, C, LF}
648+
cache::FIRKCacheExpand{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
649+
loss::LF, ::StandardBVProblem, ::Val{true}) where {iip, T, DC, fit_parameters, BC, C, LF}
651650
(; prob, alg, stage, bcresid_prototype, f_prototype) = cache
652651
(; jac_alg) = alg
653652
(; bc_diffmode) = jac_alg
@@ -714,7 +713,8 @@ function __construct_problem(
714713
)
715714
end
716715

717-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
716+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
717+
fit_parameters, p = cache.p)
718718

719719
resid_prototype = vcat(resid_bc, resid_collocation)
720720
return __construct_internal_problem(
@@ -725,9 +725,8 @@ function __construct_problem(
725725
end
726726

727727
function __construct_problem(
728-
cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C,
729-
loss::LF, ::StandardBVProblem, ::Val{false}
730-
) where {iip, BC, C, LF}
728+
cache::FIRKCacheExpand{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
729+
loss::LF, ::StandardBVProblem, ::Val{false}) where {iip, T, DC, fit_parameters, BC, C, LF}
731730
(; prob, alg, stage, bcresid_prototype, f_prototype) = cache
732731
(; jac_alg) = alg
733732
(; bc_diffmode) = jac_alg
@@ -822,7 +821,8 @@ function __construct_problem(
822821
)
823822
end
824823

825-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
824+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
825+
fit_parameters, p = cache.p)
826826

827827
resid_prototype = vcat(resid_bc, resid_collocation)
828828
return __construct_internal_problem(
@@ -833,9 +833,8 @@ function __construct_problem(
833833
end
834834

835835
function __construct_problem(
836-
cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C,
837-
loss::LF, ::TwoPointBVProblem, ::Val{true}
838-
) where {iip, BC, C, LF}
836+
cache::FIRKCacheExpand{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
837+
loss::LF, ::TwoPointBVProblem, ::Val{true}) where {iip, T, DC, fit_parameters, BC, C, LF}
839838
(; jac_alg) = cache.alg
840839
(; stage, bcresid_prototype, f_prototype) = cache
841840
N = length(cache.mesh)
@@ -884,7 +883,8 @@ function __construct_problem(
884883
)
885884
end
886885

887-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
886+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
887+
fit_parameters, p = cache.p)
888888

889889
resid_prototype = copy(resid)
890890
return __construct_internal_problem(
@@ -895,9 +895,8 @@ function __construct_problem(
895895
end
896896

897897
function __construct_problem(
898-
cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C,
899-
loss::LF, ::TwoPointBVProblem, ::Val{false}
900-
) where {iip, BC, C, LF}
898+
cache::FIRKCacheExpand{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
899+
loss::LF, ::TwoPointBVProblem, ::Val{false}) where {iip, T, DC, fit_parameters, BC, C, LF}
901900
(; jac_alg) = cache.alg
902901
(; stage, bcresid_prototype, f_prototype, prob) = cache
903902
N = length(cache.mesh)
@@ -961,7 +960,8 @@ function __construct_problem(
961960
)
962961
end
963962

964-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
963+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
964+
fit_parameters, p = cache.p)
965965

966966
resid_prototype = copy(resid)
967967
return __construct_internal_problem(
@@ -972,9 +972,8 @@ function __construct_problem(
972972
end
973973

974974
function __construct_problem(
975-
cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C,
976-
loss::LF, ::StandardBVProblem, ::Val{true}
977-
) where {iip, BC, C, LF}
975+
cache::FIRKCacheNested{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
976+
loss::LF, ::StandardBVProblem, ::Val{true}) where {iip, T, DC, fit_parameters, BC, C, LF}
978977
(; jac_alg) = cache.alg
979978
(; bc_diffmode) = jac_alg
980979
(; bcresid_prototype, f_prototype) = cache
@@ -1038,7 +1037,8 @@ function __construct_problem(
10381037
)
10391038
end
10401039

1041-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
1040+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
1041+
fit_parameters, p = cache.p)
10421042

10431043
resid_prototype = vcat(resid_bc, resid_collocation)
10441044
return __construct_internal_problem(
@@ -1048,9 +1048,8 @@ function __construct_problem(
10481048
end
10491049

10501050
function __construct_problem(
1051-
cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C,
1052-
loss::LF, ::StandardBVProblem, ::Val{false}
1053-
) where {iip, BC, C, LF}
1051+
cache::FIRKCacheNested{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
1052+
loss::LF, ::StandardBVProblem, ::Val{false}) where {iip, T, DC, fit_parameters, BC, C, LF}
10541053
(; jac_alg) = cache.alg
10551054
(; bc_diffmode) = jac_alg
10561055
(; bcresid_prototype, f_prototype, prob) = cache
@@ -1139,7 +1138,8 @@ function __construct_problem(
11391138
)
11401139
end
11411140

1142-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
1141+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
1142+
fit_parameters, p = cache.p)
11431143

11441144
resid_prototype = vcat(resid_bc, resid_collocation)
11451145
return __construct_internal_problem(
@@ -1149,9 +1149,8 @@ function __construct_problem(
11491149
end
11501150

11511151
function __construct_problem(
1152-
cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C,
1153-
loss::LF, ::TwoPointBVProblem, ::Val{true}
1154-
) where {iip, BC, C, LF}
1152+
cache::FIRKCacheNested{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
1153+
loss::LF, ::TwoPointBVProblem, ::Val{true}) where {iip, T, DC, fit_parameters, BC, C, LF}
11551154
(; jac_alg) = cache.alg
11561155
(; bcresid_prototype, f_prototype) = cache
11571156
N = length(cache.mesh)
@@ -1198,7 +1197,8 @@ function __construct_problem(
11981197
)
11991198
end
12001199

1201-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
1200+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
1201+
fit_parameters, p = cache.p)
12021202

12031203
resid_prototype = copy(resid)
12041204
return __construct_internal_problem(
@@ -1208,9 +1208,8 @@ function __construct_problem(
12081208
end
12091209

12101210
function __construct_problem(
1211-
cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C,
1212-
loss::LF, ::TwoPointBVProblem, ::Val{false}
1213-
) where {iip, BC, C, LF}
1211+
cache::FIRKCacheNested{iip, T, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C,
1212+
loss::LF, ::TwoPointBVProblem, ::Val{false}) where {iip, T, DC, fit_parameters, BC, C, LF}
12141213
(; jac_alg) = cache.alg
12151214
(; bcresid_prototype, f_prototype, prob) = cache
12161215
N = length(cache.mesh)
@@ -1263,7 +1262,8 @@ function __construct_problem(
12631262
)
12641263
end
12651264

1266-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
1265+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
1266+
fit_parameters, p = cache.p)
12671267

12681268
resid_prototype = copy(resid)
12691269
return __construct_internal_problem(

lib/BoundaryValueDiffEqMIRK/src/mirk.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,8 @@ end
552552
end
553553

554554
function __construct_problem(
555-
cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF,
556-
::StandardBVProblem, constraint::Val{true}
557-
) where {iip, BC, C, LF}
555+
cache::MIRKCache{iip, T, UB, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C, loss::LF,
556+
::StandardBVProblem, constraint::Val{true}) where {iip, T, UB, DC, fit_parameters, BC, C, LF}
558557
(; jac_alg) = cache.alg
559558
(; f_prototype, bcresid_prototype, prob) = cache
560559
(; bc_diffmode) = jac_alg
@@ -620,7 +619,8 @@ function __construct_problem(
620619
)
621620
end
622621

623-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
622+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
623+
fit_parameters, p = cache.p)
624624

625625
resid_prototype = vcat(resid_bc, resid_collocation)
626626
return __construct_internal_problem(
@@ -631,9 +631,8 @@ end
631631

632632
# Dispatch for problems with constraints
633633
function __construct_problem(
634-
cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF,
635-
::StandardBVProblem, constraint::Val{false}
636-
) where {iip, BC, C, LF}
634+
cache::MIRKCache{iip, T, UB, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C, loss::LF,
635+
::StandardBVProblem, constraint::Val{false}) where {iip, T, UB, DC, fit_parameters, BC, C, LF}
637636
(; jac_alg) = cache.alg
638637
(; f_prototype, bcresid_prototype, prob) = cache
639638
(; bc_diffmode) = jac_alg
@@ -725,7 +724,8 @@ function __construct_problem(
725724
)
726725
end
727726

728-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
727+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
728+
fit_parameters, p = cache.p)
729729

730730
return __construct_internal_problem(
731731
prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, resid_prototype,
@@ -788,9 +788,8 @@ function __mirk_mpoint_jacobian(
788788
end
789789

790790
function __construct_problem(
791-
cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF,
792-
::TwoPointBVProblem, constraint::Val{true}
793-
) where {iip, BC, C, LF}
791+
cache::MIRKCache{iip, T, UB, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C, loss::LF,
792+
::TwoPointBVProblem, constraint::Val{true}) where {iip, T, UB, DC, fit_parameters, BC, C, LF}
794793
(; jac_alg) = cache.alg
795794
(; f_prototype, bcresid_prototype, prob) = cache
796795
N = length(cache.mesh)
@@ -834,7 +833,8 @@ function __construct_problem(
834833
) -> __mirk_2point_jacobian(u, jac_prototype, diffmode, diffcache, loss, p)
835834
end
836835

837-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
836+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
837+
fit_parameters, p = cache.p)
838838

839839
resid_prototype = copy(resid)
840840
return __construct_internal_problem(
@@ -844,9 +844,8 @@ function __construct_problem(
844844
end
845845

846846
function __construct_problem(
847-
cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF,
848-
::TwoPointBVProblem, constraint::Val{false}
849-
) where {iip, BC, C, LF}
847+
cache::MIRKCache{iip, T, UB, DC, fit_parameters}, y, loss_bc::BC, loss_collocation::C, loss::LF,
848+
::TwoPointBVProblem, constraint::Val{false}) where {iip, T, UB, DC, fit_parameters, BC, C, LF}
850849
(; jac_alg) = cache.alg
851850
(; f_prototype, bcresid_prototype, prob) = cache
852851
N = length(cache.mesh)
@@ -894,7 +893,8 @@ function __construct_problem(
894893
) -> __mirk_2point_jacobian(u, jac_prototype, diffmode, diffcache, loss, p)
895894
end
896895

897-
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M)
896+
cost_fun = __build_cost(prob.f.cost, cache, cache.mesh, cache.M;
897+
fit_parameters, p = cache.p)
898898

899899
resid_prototype = copy(resid)
900900
return __construct_internal_problem(

0 commit comments

Comments
 (0)