Improve error message when tuple is passed in for parameter estimation #1011

@mchitre

Description

Is your feature request related to a problem? Please describe.

If you pass a tuple (rather than an array) to ODEProblem for the problem parameters, and then try to optimize over those parameters (e.g. using NNODE(...; param_estim=true)), you get an error that sends you on a wild-goose chase hunting for an accidental Float64:

ERROR: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

when the actual problem has nothing to do with mixed element types.

Describe the solution you’d like

We should explicitly detect the tuple and raise an error asking the user to pass the parameters as an array if they want to optimize over them.
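One possible shape for such a check, as a sketch only (the function name and call site are hypothetical, not existing NeuralPDE API; the actual guard would live wherever NNODE builds the OptimizationProblem when param_estim=true):

```julia
# Hypothetical guard: reject non-array parameter containers up front when
# parameter estimation is requested, with an actionable message.
function check_param_estim_params(p, param_estim::Bool)
    if param_estim && p isa Tuple
        throw(ArgumentError(
            "`param_estim = true` requires the problem parameters to be an " *
            "`AbstractVector` (e.g. `[1f-3, 1f0]`), not a `Tuple`. " *
            "Pass the parameters to `ODEProblem` as an array to optimize over them."))
    end
    return p
end
```

This fails fast with a message naming the actual mistake, instead of letting the tuple reach the optimizer and surface as a non-concrete-eltype error.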

Additional context

MWE to trigger the error:

import OptimizationOptimJL: BFGS
using DifferentialEquations
using NeuralPDE
using Random
using Lux

dudt(u, (r, K), t) = r * u * (K - u)
prob = ODEProblem(dudt, 1f0, (0f0, 1f0), (1f-3, 1f0))
u_model = Chain(
  Dense(1 => 16, tanh),
  Dense(16 => 16, tanh),
  Dense(16 => 1)
)
ps, state = Lux.setup(MersenneTwister(0), u_model)
loss(sol, ps) = 0f0
alg = NNODE(u_model, BFGS(), ps; autodiff=true, param_estim=true, additional_loss=loss)
sol = solve(prob, alg; maxiters=100)

Misleading error:

ERROR: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the `ArrayPartition`
from RecursiveArrayTools.jl. For example:

using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!

Element type:
Any
Stacktrace:
  [1] init(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{…})
    @ OptimizationBase ~/.julia/packages/OptimizationBase/ivotG/src/solve.jl:191
  [2] solve(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{…})
    @ OptimizationBase ~/.julia/packages/OptimizationBase/ivotG/src/solve.jl:119
  [3] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
    @ NeuralPDE ~/.julia/packages/NeuralPDE/2Czof/src/ode_solve.jl:451
  [4] __solve
    @ ~/.julia/packages/NeuralPDE/2Czof/src/ode_solve.jl:340 [inlined]
  [5] #solve_call#22
    @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:172 [inlined]
  [6] solve_call
    @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:137 [inlined]
  [7] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Float32, p::Tuple{…}, args::NNODE{…}; originator::SciMLBase.ChainRulesOriginator, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:610
  [8] solve_up
    @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:599 [inlined]
  [9] #solve#28
    @ ~/.julia/packages/DiffEqBase/949EN/src/solve.jl:583 [inlined]
 [10] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.
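For completeness, the user-side workaround is just to pass the parameters as a vector; the destructured signature in the MWE accepts a vector unchanged, since argument destructuring works on any iterable (whether the full solve then succeeds is assumed here, not verified):

```julia
# Same right-hand side as the MWE, unchanged.
dudt(u, (r, K), t) = r * u * (K - u)

# Parameters as a Float32 vector instead of the tuple (1f-3, 1f0).
p = Float32[1f-3, 1f0]

# The destructuring form still works with a vector.
dudt(1f0, p, 0f0)
```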
