Description
Is your feature request related to a problem? Please describe.
If you pass a tuple (rather than an array) as the parameters of an `ODEProblem` and then try to optimize over them (e.g. using `NNODE(...; param_estim=true)`), you get an error that sends you on a wild-goose chase for an accidental `Float64`:

```
ERROR: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!
```

even though the actual problem has nothing to do with mixed element types.
Describe the solution you’d like
We should explicitly detect a tuple and throw an error asking the user to pass the parameters as an array if they want to optimize over them.
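A minimal sketch of what such a check could look like, written as a standalone Julia function. The name `check_estimable_params` and the error wording are hypothetical and not part of NeuralPDE's API; where the check would actually be called (for example, before the optimization problem is built when `param_estim=true`) is an assumption.

```julia
# Hypothetical helper illustrating the proposed check; the name and message
# are illustrative only and are not part of NeuralPDE's API.
function check_estimable_params(p)
    if p isa Tuple
        throw(ArgumentError(
            "Problem parameters were passed as a Tuple ($(typeof(p))). " *
            "To estimate them (e.g. with `param_estim = true`), pass them " *
            "as an array instead, e.g. `collect(p)`."))
    end
    return p
end

# Usage: an array passes through unchanged, a tuple throws an ArgumentError.
check_estimable_params([1f-3, 1f0])
try
    check_estimable_params((1f-3, 1f0))
catch err
    println(err)  # prints the ArgumentError with the message above
end
```

Failing early with a message that names the tuple points the user straight at the parameter container instead of at a non-concrete `Array` that they never created themselves.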
Additional context
MWE to trigger the error:
```julia
import OptimizationOptimJL: BFGS
using DifferentialEquations
using NeuralPDE
using Random
using Lux

dudt(u, (r, K), t) = r * u * (K - u)
prob = ODEProblem(dudt, 1f0, (0f0, 1f0), (1f-3, 1f0))  # parameters passed as a Tuple

u_model = Chain(
    Dense(1 => 16, tanh),
    Dense(16 => 16, tanh),
    Dense(16 => 1)
)
ps, state = Lux.setup(MersenneTwister(0), u_model)
loss(sol, ps) = 0f0

alg = NNODE(u_model, BFGS(), ps; autodiff=true, param_estim=true, additional_loss=loss)
sol = solve(prob, alg; maxiters=100)
```

Misleading error:
```
ERROR: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!
If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the `ArrayPartition`
from RecursiveArrayTools.jl. For example:
using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!
Element type:
Any
Stacktrace:
 [1] init(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{…})
   @ OptimizationBase ~/.julia/packages/OptimizationBase/ivotG/src/solve.jl:191
 [2] solve(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{…})
   @ OptimizationBase ~/.julia/packages/OptimizationBase/ivotG/src/solve.jl:119
 [3] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
   @ NeuralPDE ~/.julia/packages/NeuralPDE/2Czof/src/ode_solve.jl:451
 [4] __solve
   @ ~/.julia/packages/NeuralPDE/2Czof/src/ode_solve.jl:340 [inlined]
 [5] #solve_call#22
   @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:172 [inlined]
 [6] solve_call
   @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:137 [inlined]
 [7] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Float32, p::Tuple{…}, args::NNODE{…}; originator::SciMLBase.ChainRulesOriginator, kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:610
 [8] solve_up
   @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:599 [inlined]
 [9] #solve#28
   @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:583 [inlined]
 [10] top-level scope
   @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
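Until such a check exists, the workaround implied by the proposal is to pass the parameters as an array, e.g. `ODEProblem(dudt, 1f0, (0f0, 1f0), [1f-3, 1f0])`, assuming the tuple is indeed what triggers the error. The standalone snippet below does not load NeuralPDE; it only checks the plain-Julia facts this workaround relies on: `collect` turns the tuple into a `Vector` with a concrete element type, and the destructuring signature `dudt(u, (r, K), t)` accepts either container.

```julia
# Same right-hand side as in the MWE; argument destructuring `(r, K)` works
# for any iterable, so both a Tuple and a Vector can be passed as parameters.
dudt(u, (r, K), t) = r * u * (K - u)

p_tuple = (1f-3, 1f0)
p_vec = collect(p_tuple)                  # a Vector holding the same two Float32 values

println(eltype(p_vec))                    # Float32
println(isconcretetype(eltype(p_vec)))    # true

# The ODE right-hand side gives the same value with either parameter container.
println(dudt(0.5f0, p_tuple, 0f0) == dudt(0.5f0, p_vec, 0f0))  # true
```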