-
-
Notifications
You must be signed in to change notification settings - Fork 241
Open
Labels
bug (Something isn't working)
Description
I'm not sure what is happening, but this is something that has appeared in the Catalyst docs.
using ModelingToolkitBase
using ModelingToolkitBase: t_nounits as t, D_nounits as D
using OrdinaryDiffEqRosenbrock
using Optimization
using OptimizationOptimisers
using SciMLSensitivity
# Declare the Brusselator model: X and Y are the state variables, A and B the parameters.
@variables X(t) Y(t)
@parameters A B
brusselator_eqs = [
    D(X) ~ A - (B + 1) * X + X^2 * Y,
    D(Y) ~ B * X - X^2 * Y,
]
@mtkcompile brusselator = System(brusselator_eqs, t)

# Ground-truth problem, ending at `tend`, used only to generate synthetic data.
u0 = [:X => 1.0, :Y => 1.0]
p_real = [:A => 1.0, :B => 2.0]
tend = 30.0
prob_real = ODEProblem(brusselator, u0, tend, p_real)

# Synthetic data: evaluate the true solution at 100 evenly spaced times, then scale
# each entry by a random factor in [0.95, 1.05). No RNG seed is set here, so the
# noise differs between runs.
sample_times = range(0.0; stop = tend, length = 100)
true_sol = solve(prob_real, Rosenbrock23(); tstops = sample_times)
sample_vals = Array(true_sol(sample_times))
sample_vals .*= 1 .+ 0.1 .* rand(Float64, size(sample_vals)) .- 0.05

# Problem to be fitted. Its time span ends at 10.0 (< tend), so `loss` only compares
# against the leading columns of `sample_vals` that the solver actually saved.
prob_base = ODEProblem(brusselator, u0, 10.0, p_real)
# Setter for the parameters (A, B); `loss` calls it as `set_p(prob_base, p)`.
set_p = ModelingToolkitBase.setp_oop(prob_base, [:A, :B])
# Constant context tuple handed to `loss` through the OptimizationProblem.
loss_ps = (set_p, prob_base, sample_times, sample_vals)
"""
    loss(p, (set_p, prob_base, sample_times, sample_vals))

Return the sum of squared differences between the simulated trajectory for the
parameter vector `p` and the noisy samples in `sample_vals`.

`set_p(prob_base, p)` builds the parameter values that are passed to `remake`.
The solution is saved at `sample_times`. Only as many sample columns as the solve
returned are compared, so a `prob_base` with a shorter time span fits a prefix
of the data.
"""
function loss(p, (set_p, prob_base, sample_times, sample_vals))
    new_ps = set_p(prob_base, p)
    newprob = remake(prob_base; p = new_ps) # Error happens here.
    sim = Array(solve(newprob, Rosenbrock23(); saveat = sample_times, verbose = false,
                      maxiters = 10000))
    residual = sim .- sample_vals[:, 1:size(sim, 2)]
    return sum(abs2, residual)
end
# Initial guess for (A, B); it differs from the values A = 1.0, B = 2.0 used to
# generate the data.
ps_init = [5.0, 5.0]
# Wrap `loss` as an OptimizationFunction whose gradients come from Zygote.
adtype = Optimization.AutoZygote()
optf = OptimizationFunction(loss, adtype)
optprob = OptimizationProblem(optf, ps_init, loss_ps)
# Attempt to solve it.
sol = solve(optprob, ADAM(0.1); maxiters = 100)

This yields:
ERROR: Tuple field type cannot be Union{}
Stacktrace:
[1] may_bc_derivatives(::Type{Union{}}, f::ModelingToolkitBase.PConstructorApplicator{typeof(identity)}, args::Tuple{})
@ ChainRules C:\Users\Torkel\.julia\packages\ChainRules\14CDN\src\rulesets\Base\broadcast.jl:51
[2] rrule(cfg::Zygote.ZygoteRuleConfig{…}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.Style{…}, f::ModelingToolkitBase.PConstructorApplicator{…}, args::Tuple{})
@ ChainRules C:\Users\Torkel\.julia\packages\ChainRules\14CDN\src\rulesets\Base\broadcast.jl:36
[3] chain_rrule
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\chainrules.jl:234 [inlined]
[4] macro expansion
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:-1 [inlined]
[5] _pullback(::Zygote.Context{…}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.Style{…}, ::ModelingToolkitBase.PConstructorApplicator{…}, ::Tuple{})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
[6] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:1019
[7] adjoint
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\lib.jl:211 [inlined]
[8] _pullback
@ C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[9] broadcasted
@ .\broadcast.jl:1346 [inlined]
[10] adjoint
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\broadcast.jl:246 [inlined]
[11] _pullback
@ C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[12] fallback_Fix1
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\base.jl:237 [inlined]
[13] _pullback(ctx::Zygote.Context{…}, f::Zygote.var"#fallback_Fix1#fallback_Fix1##0"{…}, args::Tuple{})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
[14] adjoint
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\base.jl:238 [inlined]
[15] _pullback(__context__::Zygote.Context{…}, g::Base.Fix1{…}, y::Tuple{})
@ Zygote C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67
[16] _pullback(::Zygote.Context{…}, ::Base.var"##_#54", ::@Kwargs{}, ::ComposedFunction{…}, ::SymbolicIndexingInterface.ProblemState{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
[17] _apply
@ .\boot.jl:1019 [inlined]
[18] adjoint
@ C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\lib\lib.jl:211 [inlined]
[19] _pullback
@ C:\Users\Torkel\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[20] ComposedFunction
@ .\operators.jl:1096 [inlined]
[21] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::SymbolicIndexingInterface.ProblemState{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
[22] _getter
@ C:\Users\Torkel\.julia\packages\ModelingToolkitBase\B3XVN\src\systems\problem_utils.jl:822 [inlined]
[23] _pullback(::Zygote.Context{…}, ::ModelingToolkitBase.var"#_getter#772"{…}, ::SymbolicIndexingInterface.ProblemState{…}, ::NonlinearProblem{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
[24] _pullback(::Zygote.Context{…}, ::ModelingToolkitBase.ReconstructInitializeprob{…}, ::SymbolicIndexingInterface.ProblemState{…}, ::NonlinearProblem{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
[25] _pullback(::Zygote.Context{…}, ::typeof(SciMLBase.remake_initialization_data), ::System, ::ODEFunction{…}, ::Missing, ::Float64, ::MTKParameters{…}, ::Vector{…}, ::MTKParameters{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
[26] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##remake#869", ::Missing, ::Missing, ::Missing, ::MTKParameters{…}, ::Missing, ::Bool, ::Type{…}, ::Bool, ::Nothing, ::@Kwargs{}, ::typeof(remake), ::ODEProblem{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:81
[27] remake
@ C:\Users\Torkel\.julia\packages\SciMLBase\tmJhj\src\remake.jl:222 [inlined]
[28] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(remake), ::ODEProblem{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
[29] loss
@ .\Untitled-1:37 [inlined]
[30] _pullback(::Zygote.Context{…}, ::typeof(loss), ::Vector{…}, ::Tuple{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface2.jl:0
[31] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface.jl:96
[32] pullback(::Function, ::Vector{…}, ::Tuple{…})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface.jl:94
[33] withgradient(::Function, ::Vector{Float64}, ::Vararg{Any})
@ Zygote C:\Users\Torkel\.julia\packages\Zygote\55SqB\src\compiler\interface.jl:211
[34] value_and_gradient
@ C:\Users\Torkel\.julia\packages\DifferentiationInterface\M8gIf\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:115 [inlined]
[35] value_and_gradient!(f::Function, grad::Vector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::Vector{…}, contexts::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceZygoteExt C:\Users\Torkel\.julia\packages\DifferentiationInterface\M8gIf\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:131
[36] (::OptimizationZygoteExt.var"#fg!#19"{…})(res::Vector{…}, θ::Vector{…})
@ OptimizationZygoteExt C:\Users\Torkel\.julia\packages\OptimizationBase\mYxHK\ext\OptimizationZygoteExt.jl:57
[37] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers C:\Users\Torkel\.julia\packages\OptimizationOptimisers\atr4L\src\OptimizationOptimisers.jl:83
[38] solve!(cache::OptimizationCache{…})
@ OptimizationBase C:\Users\Torkel\.julia\packages\OptimizationBase\mYxHK\src\solve.jl:216
[39] solve(::OptimizationProblem{…}, ::Adam{…}; kwargs::@Kwargs{…})
@ OptimizationBase C:\Users\Torkel\.julia\packages\OptimizationBase\mYxHK\src\solve.jl:98
[40] top-level scope
@ Untitled-1:50
Some type information was truncated. Use `show(err)` to see complete types.
Notably, all of these work:
loss(ps_init, loss_ps)
optf(ps_init, loss_ps)
optprob.f(ps_init, loss_ps)
Metadata
Assignees
Labels
bug (Something isn't working)