Spurious InexactError #492

Open
seadra opened this issue Feb 17, 2021 · 2 comments

@seadra

seadra commented Feb 17, 2021

Running the following code:

using DiffEqFlux, OrdinaryDiffEq, Flux, LinearAlgebra

const T = 10.0;
const ω = π/T;

ann = FastChain(FastDense(1,32,tanh), FastDense(32,32,tanh), FastDense(32,1))
ip = initial_params(ann);

function f_nn(u, p, t)
    a = ann([t],p)[1];
    A = [1.0 a; a -1.0];
    return -im*A*u;
end


u0 = [Complex{Float64}(1) 0; 0 1];

tspan = (0.0, T)


prob_ode = ODEProblem(f_nn, u0, tspan, ip);
sol_ode = solve(prob_ode, Tsit5());


utarget = [Complex{Float64}(0) im; im 0];

function predict_adjoint(p)
    return solve(prob_ode, Tsit5(), p=p, abstol=1e-12, reltol=1e-12)
end

function loss_adjoint(p)
    prediction = predict_adjoint(p[1:end-1])
    usol = last(prediction)
    x = p[end]
    r = [cos(x) sin(x); -sin(x) cos(x)];
    loss = 1.0 - abs(tr(r*usol*utarget')/2)^2
    return loss
end

DiffEqFlux.sciml_train(loss_adjoint, [ip;0.0], ADAM(0.1), maxiters = 100)

results in the following error:

InexactError: Float64(-0.03248179857510965 + 1.7347234759768067e-19im)

Stacktrace:
 [1] Real at ./complex.jl:37 [inlined]
 [2] convert at ./number.jl:7 [inlined]
 [3] setindex! at ./array.jl:847 [inlined]
 [4] macro expansion at ./broadcast.jl:932 [inlined]
 [5] macro expansion at ./simdloop.jl:77 [inlined]
 [6] copyto! at ./broadcast.jl:931 [inlined]
 [7] copyto! at ./broadcast.jl:886 [inlined]
 [8] materialize! at ./broadcast.jl:848 [inlined]
 [9] materialize!(::Array{Float64,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{Float64,Array{Float64,1}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(-),Tuple{Int64,Float64}},Array{Complex{Float64},1}}}}}) at ./broadcast.jl:845
 [10] apply!(::ADAM, ::Array{Float64,1}, ::Array{Complex{Float64},1}) at /home/user/.julia/packages/Flux/sY3yx/src/optimise/optimisers.jl:175
 [11] update!(::ADAM, ::Array{Float64,1}, ::Array{Complex{Float64},1}) at /home/user/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:19
 [12] update!(::ADAM, ::Params, ::Zygote.Grads) at /home/user/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:29
 [13] macro expansion at /home/user/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:131 [inlined]
 [14] macro expansion at /home/user/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [15] (::DiffEqFlux.var"#73#78"{DiffEqFlux.var"#77#82",Int64,Bool,Bool,typeof(loss_adjoint),Array{Float64,1},Params})() at /home/user/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:64
 [16] with_logstate(::Function, ::Any) at ./logging.jl:408
 [17] with_logger at ./logging.jl:514 [inlined]
 [18] maybe_with_logger(::DiffEqFlux.var"#73#78"{DiffEqFlux.var"#77#82",Int64,Bool,Bool,typeof(loss_adjoint),Array{Float64,1},Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{ConsoleProgressMonitor.ProgressLogger,DiffEqFlux.var"#68#70"},LoggingExtras.EarlyFilteredLogger{Base.CoreLogging.SimpleLogger,DiffEqFlux.var"#69#71"}}}) at /home/user/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:39
 [19] sciml_train(::Function, ::Array{Float64,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/user/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:63
 [20] top-level scope at In[9]:38
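Frames [9] and [10] of the trace show where the error is raised: ADAM's `apply!` receives a `Float64` parameter vector together with an `Array{Complex{Float64},1}` gradient and updates a real-valued state buffer in place, so the tiny imaginary part (1.7347234759768067e-19im) cannot be converted back to `Float64`. The sketch below reproduces only that last step with plain Julia arrays; the buffer `mt`, the value of `β` and the one-element gradient are illustrative assumptions, not Flux internals.

```julia
# Sketch (assumption): mirrors the broadcast in frame [9],
# mt .= β * mt .+ (1 - β) .* Δ, with a real buffer and a complex gradient.
mt = zeros(Float64, 1)                                    # real-valued optimizer state
Δ  = [-0.03248179857510965 + 1.7347234759768067e-19im]   # gradient from the error message
β  = 0.9

caught = try
    # Storing a Complex into a Float64 slot calls Float64(z), which throws
    # InexactError whenever imag(z) != 0.
    mt .= β .* mt .+ (1 - β) .* Δ
    nothing
catch err
    err
end
println(caught isa InexactError)   # true

# Dropping the imaginary part before the update avoids the conversion error.
mt .= β .* mt .+ (1 - β) .* real.(Δ)
println(mt)
```

The failed broadcast throws before anything is stored, so `mt` still holds zeros when the real-valued update runs.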

I tracked the error down to this portion:

    x = p[end]
    r = [cos(x) sin(x); -sin(x) cos(x)];
    loss = 1.0 - abs(tr(r*usol*utarget')/2)^2

If we remove r from the loss, so that loss = 1.0 - abs(tr(usol*utarget')/2)^2, the problem disappears. The problem also disappears if we eliminate either sin(x) or cos(x) from the definition of r, e.g. r = [cos(x) 1; -1 cos(x)].
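Because the loss is real-valued, its exact derivative with respect to `x = p[end]` is real; an imaginary part as small as 1.7347234759768067e-19im is most plausibly rounding in the complex arithmetic of the reverse pass. The sketch below checks this for the `r`-dependent part of `loss_adjoint` using only `LinearAlgebra`, comparing a hand-derived derivative with a central finite difference. It assumes a fixed unitary matrix built from a made-up Hamiltonian `H` in place of `usol`; `rot`, `drot`, `loss_r` and `dloss_r` are illustrative names.

```julia
using LinearAlgebra

# Assumption: a fixed unitary matrix stands in for `usol`, the ODE solution at T.
const H = [1.0 0.3; 0.3 -1.0]
const usol = exp(-im * 10.0 * H)
const utarget = [Complex{Float64}(0) im; im 0]
const M = usol * utarget'

rot(x)  = [cos(x) sin(x); -sin(x) cos(x)]
drot(x) = [-sin(x) cos(x); -cos(x) -sin(x)]   # elementwise d/dx of rot(x)

# Same form as the r-dependent part of loss_adjoint.
loss_r(x) = 1.0 - abs(tr(rot(x) * M) / 2)^2

# With z(x) = tr(rot(x) * M) / 2:  d/dx (1 - |z|^2) = -2 * real(conj(z) * z'(x)),
# where z'(x) = tr(drot(x) * M) / 2 because the trace is linear.
function dloss_r(x)
    z  = tr(rot(x) * M) / 2
    dz = tr(drot(x) * M) / 2
    return -2 * real(conj(z) * dz)
end

x  = 0.3
fd = (loss_r(x + 1e-6) - loss_r(x - 1e-6)) / 2e-6   # central finite difference
println((dloss_r(x), fd))
```

Both numbers are plain `Float64` values and should agree closely, which is consistent with the true gradient having no imaginary part at all.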

@ChrisRackauckas
Member

Seems like a Zygote issue? Is there a way to reproduce it without the ODE solver in there?

@seadra
Author

seadra commented Feb 20, 2021

I couldn't reproduce it without the ODE.

When I change the loss function to the following, which skips the ODE solve,

function loss_adjoint(p)
    x = p[end]
    r = [cos(x) sin(x); -sin(x) cos(x)];
    loss = 1.0 - abs(tr(r*utarget')/2)^2
    return loss
end

it runs successfully. Puzzlingly, it also runs if I change the definition of x to abs(p[end]).
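One possible stopgap, not proposed in this thread and not part of Flux, Zygote or DiffEqFlux, is to project the gradient onto the reals before the optimizer sees it, since the parameters here are real. A minimal sketch of such a hypothetical `project_real` helper, which also warns if the discarded imaginary part is not negligible:

```julia
# Hypothetical helper (not a Flux/Zygote/DiffEqFlux API): map a gradient for
# real parameters back onto the reals before an optimizer step.
function project_real(Δ::AbstractArray{<:Complex}; atol = 1e-10)
    worst = isempty(Δ) ? 0.0 : maximum(abs, imag.(Δ))
    # Only rounding-level imaginary parts are expected; flag anything larger.
    worst <= atol || @warn "Gradient has a non-negligible imaginary part" worst
    return real.(Δ)
end

# Already-real gradients pass through unchanged.
project_real(Δ::AbstractArray{<:Real}; atol = 1e-10) = Δ

g  = [-0.03248179857510965 + 1.7347234759768067e-19im, 0.5 + 0.0im]
gr = project_real(g)
println(eltype(gr))   # Float64
```

With `sciml_train` the gradient is computed internally, so using this would require a hand-written loop, e.g. `g = Zygote.gradient(loss_adjoint, p)[1]` followed by `Flux.Optimise.update!(opt, p, project_real(g))`; this has not been checked against the package versions in the report.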
