Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "2.2"
Reactant = "0.2"
NonlinearSolveFirstOrder = "1.11"
NonlinearSolveQuasiNewton = "1.12"
NonlinearSolveSpectralMethods = "1.6"
Expand Down Expand Up @@ -145,6 +146,7 @@ PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
Expand All @@ -161,4 +163,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SafeTestsets", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLLogging"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "Reactant", "ReTestItems", "SafeTestsets", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLLogging"]
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
Expand Down Expand Up @@ -71,6 +72,7 @@ EnzymeCore = "0.8"
ExplicitImports = "1.10.1"
FastClosures = "0.3"
ForwardDiff = "0.10.36, 1"
ReactantCore = "0.1"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.4"
LinearAlgebra = "1.10"
Expand Down
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Compat: @compat
using ConcreteStructs: @concrete
using FastClosures: @closure
using Preferences: @load_preference, @set_preferences!
using ReactantCore: @trace

using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, NoSparsityDetector,
KnownJacobianSparsityDetector
Expand Down
6 changes: 3 additions & 3 deletions lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,19 +293,19 @@ function SciMLBase.__solve(
end

function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
if cache.retcode == ReturnCode.InitialFailure
@trace if cache.retcode == ReturnCode.InitialFailure
return SciMLBase.build_solution(
cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, cache.stats, cache.trace
)
end

while not_terminated(cache)
@trace while not_terminated(cache)
CommonSolve.step!(cache)
end

# The solver might have set a different `retcode`
if cache.retcode == ReturnCode.Default
@trace if cache.retcode == ReturnCode.Default
cache.retcode = ifelse(
cache.nsteps ≥ cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success
)
Expand Down
25 changes: 14 additions & 11 deletions lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ end
function (cache::NonlinearTerminationModeCache)(
mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...
)
if check_convergence(mode, du, u, uprev, abstol, reltol)
@trace if check_convergence(mode, du, u, uprev, abstol, reltol)
cache.retcode = ReturnCode.Success
return true
end
Expand All @@ -164,45 +164,47 @@ function (cache::NonlinearTerminationModeCache)(
end

# Protective Break
if !isfinite(objective)
@trace if !isfinite(objective)
cache.retcode = ReturnCode.Unstable
return true
end

# By default we turn this off since it have potential for false positives
if mode.protective_threshold !== nothing &&
@trace if mode.protective_threshold !== nothing &&
(objective > cache.initial_objective * mode.protective_threshold * length(du))
cache.retcode = ReturnCode.Unstable
return true
end

# Check if it is the best solution
if mode isa AbstractSafeBestNonlinearTerminationMode &&
@trace if mode isa AbstractSafeBestNonlinearTerminationMode &&
objective < cache.best_objective_value
cache.best_objective_value = objective
update_u!!(cache, u)
cache.saved_values !== nothing && length(args) ≥ 1 && (cache.saved_values = args)
end

# Main Termination Criteria
if objective ≤ criteria
@trace if objective ≤ criteria
cache.retcode = ReturnCode.Success
return true
end

# Terminate if we haven't improved for the last `patience_steps`
cache.nsteps += 1
cache.nsteps == 1 && (cache.initial_objective = objective)
@trace if cache.nsteps == 1
cache.initial_objective = objective
end
cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective

if objective ≤ mode.patience_objective_multiplier * criteria &&
@trace if objective ≤ mode.patience_objective_multiplier * criteria &&
cache.nsteps > mode.patience_steps
if cache.nsteps < length(cache.objectives_trace)
min_obj, max_obj = extrema(@view(cache.objectives_trace[1:(cache.nsteps)]))
else
min_obj, max_obj = extrema(cache.objectives_trace)
end
if min_obj < mode.min_max_factor * max_obj
@trace if min_obj < mode.min_max_factor * max_obj
if cache.leastsq
# If least squares, found a local minima thus success
cache.retcode = ReturnCode.StalledSuccess
Expand All @@ -223,15 +225,15 @@ function (cache::NonlinearTerminationModeCache)(
end
du_norm = L2_NORM(cache.u_diff_cache)
cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm
if cache.nsteps > mode.max_stalled_steps
@trace if cache.nsteps > mode.max_stalled_steps
max_step_norm = maximum(cache.step_norm_trace)
if mode isa AbsNormSafeTerminationMode ||
mode isa AbsNormSafeBestTerminationMode
stalled_step = max_step_norm ≤ abstol
else
stalled_step = max_step_norm ≤ reltol * (max_step_norm + cache.u0_norm)
end
if stalled_step
@trace if stalled_step
if cache.leastsq
cache.retcode = ReturnCode.StalledSuccess
else
Expand Down Expand Up @@ -329,11 +331,12 @@ function check_and_update!(cache, fu, u, uprev)
end

function check_and_update!(tc_cache, cache, fu, u, uprev, mode)
return if tc_cache(fu, u, uprev)
@trace if tc_cache(fu, u, uprev)
cache.retcode = tc_cache.retcode
update_from_termination_cache!(tc_cache, cache, mode, u)
cache.force_stop = true
end
return nothing
end

function update_from_termination_cache!(tc_cache, cache, u = get_u(cache))
Expand Down
2 changes: 2 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand Down Expand Up @@ -53,6 +54,7 @@ ExplicitImports = "1.9"
FastClosures = "0.3.2"
FiniteDiff = "2.24"
ForwardDiff = "0.10.36, 1"
ReactantCore = "0.1"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.3"
LinearAlgebra = "1.10"
Expand Down
1 change: 1 addition & 0 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module SimpleNonlinearSolve

using ConcreteStructs: @concrete
using PrecompileTools: @compile_workload, @setup_workload
using ReactantCore: @trace
using Reexport: @reexport
using Setfield: @set!

Expand Down
9 changes: 6 additions & 3 deletions lib/SimpleNonlinearSolve/src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ function SciMLBase.__solve(
fx = NLBUtils.evaluate_f(prob, x)
T = promote_type(eltype(fx), eltype(x))

iszero(fx) &&
@trace if iszero(fx)
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
end

@bb xo = copy(x)
@bb δx = similar(x)
Expand Down Expand Up @@ -67,7 +68,7 @@ function SciMLBase.__solve(
ls_cache = nothing
end

for _ in 1:maxiters
@trace for _ in 1:maxiters
@bb δx = J⁻¹ × vec(fprev)
@bb δx .*= -1

Expand All @@ -84,7 +85,9 @@ function SciMLBase.__solve(

# Termination Checks
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
@trace if solved
return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

@bb J⁻¹δf = J⁻¹ × vec(δf)
d = dot(δx, J⁻¹δf)
Expand Down
16 changes: 11 additions & 5 deletions lib/SimpleNonlinearSolve/src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function SciMLBase.__solve(
@bb δf = copy(fx)

k = 0
while k < maxiters
@trace while k < maxiters
# Spectral parameter range check
σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)

Expand All @@ -121,16 +121,20 @@ function SciMLBase.__solve(
fx = NLBUtils.evaluate_f!!(prob, fx, x_cache)
fx_norm_new = L2_NORM(fx)^nexp

while k < maxiters
(fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm)) && break
@trace while k < maxiters
@trace if fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm)
break
end

α_tp = α_p^2 * fx_norm / (fx_norm_new + (T(2) * α_p - T(1)) * fx_norm)
@bb @. x_cache = x - α_m * d

fx = NLBUtils.evaluate_f!!(prob, fx, x_cache)
fx_norm_new = L2_NORM(fx)^nexp

(fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)) && break
@trace if fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)
break
end

α_tm = α_m^2 * fx_norm / (fx_norm_new + (T(2) * α_m - T(1)) * fx_norm)
α_p = clamp(α_tp, τ_min * α_p, τ_max * α_p)
Expand All @@ -146,7 +150,9 @@ function SciMLBase.__solve(
@bb copyto!(x, x_cache)

solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
@trace if solved
return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

# Update spectral parameter
@bb @. δx = x - xo
Expand Down
17 changes: 11 additions & 6 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ function SciMLBase.__solve(
fx = NLBUtils.evaluate_f(prob, x)
T = promote_type(eltype(fx), eltype(x))

iszero(fx) &&
@trace if iszero(fx)
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
end

abstol, reltol,
tc_cache = NonlinearSolveBase.init_termination_cache(
Expand All @@ -64,17 +65,19 @@ function SciMLBase.__solve(
end

J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
for _ in 1:maxiters
@trace for _ in 1:maxiters
NLBUtils.can_setindex(x) || (A = J)

# Factorize Once and Reuse
J_fact = if J isa Number
J
else
fact = LinearAlgebra.lu(J; check = false)
!LinearAlgebra.issuccess(fact) && return SciMLBase.build_solution(
prob, alg, x, fx; retcode = ReturnCode.Unstable
)
@trace if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_solution(
prob, alg, x, fx; retcode = ReturnCode.Unstable
)
end
fact
end

Expand All @@ -89,7 +92,9 @@ function SciMLBase.__solve(
cᵢ = NLBUtils.restructure(cᵢ, cᵢ_)

solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
@trace if solved
return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

@bb @. x += cᵢ
@bb copyto!(xo, x)
Expand Down
10 changes: 7 additions & 3 deletions lib/SimpleNonlinearSolve/src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ function SciMLBase.__solve(
J = one.(x)
@bb δx² = similar(x)

for _ in 1:maxiters
any(iszero, J) && (J = Utils.identity_jacobian!!(J))
@trace for _ in 1:maxiters
@trace if any(iszero, J)
J = Utils.identity_jacobian!!(J)
end

@bb @. δx = fprev / J

Expand All @@ -38,7 +40,9 @@ function SciMLBase.__solve(

# Termination Checks
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
@trace if solved
return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

@bb δx .*= -1
@bb @. δx² = δx^2 * J^2
Expand Down
6 changes: 4 additions & 2 deletions lib/SimpleNonlinearSolve/src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end
ls_cache = nothing
end

for i in 1:maxiters
@trace for i in 1:maxiters
if ls_cache === nothing
α = true
else
Expand All @@ -127,7 +127,9 @@ end

# Termination Checks
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
@trace if solved
return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

Uₚ = selectdim(U, 2, 1:min(η, i - 1))
Vᵀₚ = selectdim(Vᵀ, 1, 1:min(η, i - 1))
Expand Down
9 changes: 6 additions & 3 deletions lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ function SciMLBase.__solve(
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
fx = NLBUtils.evaluate_f(prob, x)

iszero(fx) &&
@trace if iszero(fx)
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
end

abstol, reltol,
tc_cache = NonlinearSolveBase.init_termination_cache(
Expand All @@ -59,13 +60,15 @@ function SciMLBase.__solve(
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

for _ in 1:maxiters
@trace for _ in 1:maxiters
@bb copyto!(xo, x)
δx = NLBUtils.restructure(x, J \ NLBUtils.safe_vec(fx))
@bb x .-= δx

solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
@trace if solved
return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

fx = NLBUtils.evaluate_f!!(prob, fx, x)
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
Expand Down
Loading
Loading