Skip to content

Weird remake issue when trying to solve OptimizationProblem involving simulating MTK ODE model #4234

@TorkelE

Description

@TorkelE

I'm not sure what is happening, but this is something that has appeared in the Catalyst docs.

using ModelingToolkitBase
using ModelingToolkitBase: t_nounits as t, D_nounits as D
using OrdinaryDiffEqRosenbrock
using Optimization
using OptimizationOptimisers
using SciMLSensitivity
using Random

# Declare model (the Brusselator).
@variables X(t) Y(t)
@parameters A B
eqs = [
    D(X) ~ A - (B + 1) * X + X^2 * Y,
    D(Y) ~ B * X - X^2 * Y
]
@mtkcompile brusselator = System(eqs, t)

# Create Base ODEProblem with the "true" parameter values used to generate the data.
p_real = [:A => 1.0, :B => 2.0]
u0 = [:X => 1.0, :Y => 1.0]
tend = 30.0
prob_real = ODEProblem(brusselator, u0, tend, p_real)

# Create synthetic data: sample the true solution at `sample_times` and apply a
# multiplicative noise factor drawn uniformly from [0.95, 1.05). A fixed-seed RNG keeps
# the data (and hence the reproducer) deterministic across runs.
rng = Xoshiro(1234)
sample_times = range(0.0; stop = tend, length = 100)
sol_real = solve(prob_real, Rosenbrock23(); tstops = sample_times)
sample_vals = Array(sol_real(sample_times))
sample_vals .*= (1 .+ 0.1 * rand(rng, Float64, size(sample_vals)) .- 0.05)

# Set OptimizationProblem Parameters.
# `prob_base` reuses the same initial conditions and starting parameter values as
# `prob_real`, but only spans t ∈ [0, 10] (the data spans [0, tend]); the loss slices
# the data down to the number of saved solution points.
prob_base = ODEProblem(brusselator, u0, 10.0, p_real)
# Out-of-place setter: returns a new parameter object with A and B replaced.
set_p = ModelingToolkitBase.setp_oop(prob_base, [:A, :B])
loss_ps = (set_p, prob_base, sample_times, sample_vals)

# Define loss function.
"""
    loss(p, (set_p, prob_base, sample_times, sample_vals))

Return the sum of squared differences between `sample_vals` and a simulation of
`prob_base` whose A/B parameters are replaced by `p` (via the out-of-place setter
`set_p`), saved at `sample_times`. Only as many data columns as the solution returned
are compared.
"""
function loss(p, (set_p, prob_base, sample_times, sample_vals))
    # Build a fresh parameter object holding the candidate values.
    new_params = set_p(prob_base, p)
    newprob = remake(prob_base; p = new_params) # Error happens here.
    simulated = Array(solve(newprob, Rosenbrock23();
        saveat = sample_times, verbose = false, maxiters = 10000))
    # The solution can hold fewer time points than the data (presumably because of a
    # shorter tspan or early termination), so compare only the leading columns.
    n_saved = size(simulated, 2)
    residual = simulated .- sample_vals[:, 1:n_saved]
    return sum(abs2, residual)
end

# Create OptimizationProblem; gradients of `loss` are computed with Zygote.
optf = OptimizationFunction(loss, Optimization.AutoZygote())
ps_init = [5.0, 5.0]  # initial guesses for [A, B], in the order given to `set_p`
optprob = OptimizationProblem(optf, ps_init, loss_ps)

# Attempts to solve it (this is where the reported error is thrown).
# `Adam` is the current Optimisers.jl name; `ADAM` is only a deprecated alias for it.
sol = solve(optprob, Adam(0.1); maxiters = 100)

Running this yields:

ERROR: Tuple field type cannot be Union{}
Stacktrace:
  [1] may_bc_derivatives(::Type{Union{}}, f::ModelingToolkitBase.PConstructorApplicator{typeof(identity)}, args::Tuple{})
    @ ChainRules C:\Users\Torkel\.julia\packages\ChainRules\14CDN\src\rulesets\Base\broadcast.jl:51
  [2] rrule(cfg::Zygote.ZygoteRuleConfig{…}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.Style{…}, f::ModelingToolkitBase.PConstructorApplicator{…}, args::Tuple{})
    @ ChainRules C:\Users\Torkel\.julia\packages\ChainRules\14CDN\src\rulesets\Base\broadcast.jl:36
  [3] chain_rrule
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\chainrules.jl:234 [inlined]
  [4] macro expansion
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:-1 [inlined]
  [5] _pullback(::Zygote.Context{…}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.Style{…}, ::ModelingToolkitBase.PConstructorApplicator{…}, ::Tuple{})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
  [6] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:1019
  [7] adjoint
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\lib.jl:211 [inlined]
  [8] _pullback
    @ C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
  [9] broadcasted
    @ .\broadcast.jl:1346 [inlined]
 [10] adjoint
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\broadcast.jl:246 [inlined]
 [11] _pullback
    @ C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
 [12] fallback_Fix1
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\base.jl:237 [inlined]
 [13] _pullback(ctx::Zygote.Context{…}, f::Zygote.var"#fallback_Fix1#fallback_Fix1##0"{…}, args::Tuple{})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
 [14] adjoint
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\base.jl:238 [inlined]
 [15] _pullback(__context__::Zygote.Context{…}, g::Base.Fix1{…}, y::Tuple{})
    @ Zygote C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67
 [16] _pullback(::Zygote.Context{…}, ::Base.var"##_#54", ::@Kwargs{}, ::ComposedFunction{…}, ::SymbolicIndexingInterface.ProblemState{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
 [17] _apply
    @ .\boot.jl:1019 [inlined]
 [18] adjoint
    @ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\lib.jl:211 [inlined]
 [19] _pullback
    @ C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
 [20] ComposedFunction
    @ .\operators.jl:1096 [inlined]
 [21] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::SymbolicIndexingInterface.ProblemState{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
 [22] _getter
    @ C:\Users\Torkel\.julia\packages\ModelingToolkitBase\B3XVN\src\systems\problem_utils.jl:822 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::ModelingToolkitBase.var"#_getter#772"{…}, ::SymbolicIndexingInterface.ProblemState{…}, ::NonlinearProblem{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
 [24] _pullback(::Zygote.Context{…}, ::ModelingToolkitBase.ReconstructInitializeprob{…}, ::SymbolicIndexingInterface.ProblemState{…}, ::NonlinearProblem{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
 [25] _pullback(::Zygote.Context{…}, ::typeof(SciMLBase.remake_initialization_data), ::System, ::ODEFunction{…}, ::Missing, ::Float64, ::MTKParameters{…}, ::Vector{…}, ::MTKParameters{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
 [26] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##remake#869", ::Missing, ::Missing, ::Missing, ::MTKParameters{…}, ::Missing, ::Bool, ::Type{…}, ::Bool, ::Nothing, ::@Kwargs{}, ::typeof(remake), ::ODEProblem{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
 [27] remake
    @ C:\Users\Torkel\.julia\packages\SciMLBase\tmJhj\src\remake.jl:222 [inlined]
 [28] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(remake), ::ODEProblem{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
 [29] loss
    @ .\Untitled-1:37 [inlined]
 [30] _pullback(::Zygote.Context{…}, ::typeof(loss), ::Vector{…}, ::Tuple{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
 [31] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface.jl:96
 [32] pullback(::Function, ::Vector{…}, ::Tuple{…})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface.jl:94
 [33] withgradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface.jl:211
 [34] value_and_gradient
    @ C:\Users\Torkel\.julia\packages\DifferentiationInterface\M8gIf\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:115 [inlined]
 [35] value_and_gradient!(f::Function, grad::Vector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::Vector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt C:\Users\Torkel\.julia\packages\DifferentiationInterface\M8gIf\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:131
 [36] (::OptimizationZygoteExt.var"#fg!#19"{…})(res::Vector{…}, θ::Vector{…})
    @ OptimizationZygoteExt C:\Users\Torkel\.julia\packages\OptimizationBase\mYxHK\ext\OptimizationZygoteExt.jl:57
 [37] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers C:\Users\Torkel\.julia\packages\OptimizationOptimisers\atr4L\src\OptimizationOptimisers.jl:83
 [38] solve!(cache::OptimizationCache{…})
    @ OptimizationBase C:\Users\Torkel\.julia\packages\OptimizationBase\mYxHK\src\solve.jl:216
 [39] solve(::OptimizationProblem{…}, ::Adam{…}; kwargs::@Kwargs{…})
    @ OptimizationBase C:\Users\Torkel\.julia\packages\OptimizationBase\mYxHK\src\solve.jl:98
 [40] top-level scope
    @ Untitled-1:50
Some type information was truncated. Use `show(err)` to see complete types.

Notably, all of these work (i.e., evaluating the loss without differentiating through it succeeds):

# Direct (non-differentiated) evaluations of the loss. Per the report these all succeed,
# so the failure is specific to the Zygote gradient path through `remake`.
loss(ps_init, loss_ps)
optf(ps_init, loss_ps)
optprob.f(ps_init, loss_ps)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions