Open
Description
Taking the derivative of a derivative (nested differentiation) with Yota throws an error, as the following MWE (minimal working example) shows.
using Yota
# the function of interest
f(x, y, z) = x*y*z
# nested Yota gradients: grad returns (value, (∂closure, ∂arg)), so [2][2] is the derivative w.r.t. the argument
fx(x, y, z) = grad(x->f(x, y, z), x)[2][2]
fxy(x, y, z) = grad(y->fx(x, y, z), y)[2][2]
fxyz(x, y, z) = grad(z->fxy(x, y, z), z)
fx(1.f0, 2.f0, 3.f0) # this works, output is 6.f0
fxy(1.f0, 2.f0, 3.f0) # error
fxyz(1.f0, 2.f0, 3.f0) # not reached when run as a script, since the line above already throws
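For reference, the values these functions should produce at (1, 2, 3) follow from the analytic derivatives of f = x*y*z: ∂f/∂x = y*z = 6, ∂²f/∂x∂y = z = 3 and ∂³f/∂x∂y∂z = 1 (fxyz as written returns the whole grad tuple, so the 1 would be its [2][2] entry). The minimal sketch below cross-checks these numbers with nested central differences in plain Julia, without Yota; the names fd_x, fd_xy, fd_xyz and the step size h are illustrative choices, not part of Yota's API.

```julia
# Reference values for the MWE via nested central differences (no AD involved).
# Because f is linear in each argument, central differences are exact
# up to floating-point rounding, so the step size is not critical.
f(x, y, z) = x * y * z

const h = 1e-2  # step size; an arbitrary choice for this sketch

# One central difference per variable, nested to get mixed partials
fd_x(x, y, z)   = (f(x + h, y, z) - f(x - h, y, z)) / (2h)           # ∂f/∂x = y*z
fd_xy(x, y, z)  = (fd_x(x, y + h, z) - fd_x(x, y - h, z)) / (2h)     # ∂²f/∂x∂y = z
fd_xyz(x, y, z) = (fd_xy(x, y, z + h) - fd_xy(x, y, z - h)) / (2h)   # ∂³f/∂x∂y∂z = 1

println(fd_x(1.0, 2.0, 3.0))    # ≈ 6.0
println(fd_xy(1.0, 2.0, 3.0))   # ≈ 3.0
println(fd_xyz(1.0, 2.0, 3.0))  # ≈ 1.0
```

Comparing the output of fx, fxy and fxyz against these numbers makes it easy to confirm a fix once nested differentiation works.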
Stacktrace:
ERROR: No derivative rule found for op %69 = %62(%59, %63, %50, %20, %4)::Tuple{Float32, Tuple{ChainRulesCore.Tangent{var"#13#14"{Float32, Float32}, NamedTuple{(:y, :z), Tuple{Float32, Float32}}}, Float32}} , try defining it using
ChainRulesCore.rrule(::Base.var"#invokelatest##kw", ::NamedTuple{(:seed,), Tuple{Int64}}, ::typeof(Base.invokelatest), ::Yota.var"###tape_#13#316", ::var"#13#14"{Float32, Float32}, ::Float32) = ...
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:178
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:220
[4] #gradtape!#77
@ D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:245 [inlined]
[5] gradtape(f::Function, args::Float32; ctx::Yota.GradCtx, seed::Int64)
@ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:264
[6] grad(f::Function, args::Float32; seed::Int64)
@ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:356
[7] grad
@ D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:348 [inlined]
[8] fxy(x::Float32, y::Float32, z::Float32) grad_grad_yota.jl:8
[9] top-level scope grad_grad_yota.jl:12
I also noticed that the return type of fx is inferred as Any instead of Float32, even though it returns 6.f0 at runtime:
@code_warntype fx(1.f0, 2.f0, 3.f0)
MethodInstance for fx(::Float32, ::Float32, ::Float32)
from fx(x, y, z) in grad_grad_yota.jl:7
Arguments
#self#::Core.Const(fx)
x::Float32
y::Float32
z::Float32
Locals
#23::var"#23#24"{Float32, Float32}
Body::Any
1 ─ %1 = Main.:(var"#23#24")::Core.Const(var"#23#24")
│ %2 = Core.typeof(y)::Core.Const(Float32)
│ %3 = Core.typeof(z)::Core.Const(Float32)
│ %4 = Core.apply_type(%1, %2, %3)::Core.Const(var"#23#24"{Float32, Float32})
│ (#23 = %new(%4, y, z))
│ %6 = #23::var"#23#24"{Float32, Float32}
│ %7 = Main.grad(%6, x)::Any
│ %8 = Base.getindex(%7, 2)::Any
│ %9 = Base.getindex(%8, 2)::Any
└── return %9
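The Any comes from the call to Main.grad (%7), whose result type inference cannot pin down; the two getindex calls then propagate Any to the return value. The generic sketch below, which is not Yota-specific, reproduces that pattern with a hypothetical stand-in function `opaque` and shows that a type assertion at the call site narrows the inferred return type. Whether adding such an assertion to fx (e.g. `grad(...)[2][2]::typeof(x)`) changes anything about the nested-gradient error above is untested here; it only affects inference in the caller.

```julia
# Hypothetical stand-in for an uninferable call such as Main.grad:
# the element type of a Vector{Any} is Any, so inference returns Any.
opaque(x) = Any[2x][1]

unstable(x) = opaque(x)             # inferred return type: Any
stable(x)   = opaque(x)::typeof(x)  # type assertion narrows it to typeof(x)

# Inspect what inference concludes for a Float32 argument
println(Base.return_types(unstable, (Float32,)))
println(Base.return_types(stable, (Float32,)))
```

Running `@code_warntype stable(1.f0)` should likewise report Float32 for the body, in contrast to the Body::Any shown for fx.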