derivative of derivative does not work #135

@MariusDrulea

Taking the derivative of a derivative (a nested `grad` call) with Yota raises an error. See the following MWE.

using Yota

# the function of interest
f(x, y, z) = x*y*z

# yota gradients
fx(x, y, z) = grad(x->f(x, y, z), x)[2][2]
fxy(x, y, z) = grad(y->fx(x, y, z), y)[2][2]
fxyz(x, y, z) = grad(z->fxy(x, y, z), z)

fx(1.f0, 2.f0, 3.f0)   # this works, output is 6.0f0
fxy(1.f0, 2.f0, 3.f0)  # error (see below)
fxyz(1.f0, 2.f0, 3.f0) # not reached; also nests grad, so it goes through fxy
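
For reference, the values the nested gradients should produce can be worked out by hand, since f is a plain product. The sketch below uses no AD at all; the `*_expected` names are hypothetical helpers introduced only for this comparison and are not part of Yota.

```julia
# Hand-derived partial derivatives of f(x, y, z) = x*y*z (no AD involved).
f(x, y, z) = x * y * z

fx_expected(x, y, z)   = y * z   # ∂f/∂x
fxy_expected(x, y, z)  = z       # ∂²f/∂x∂y
fxyz_expected(x, y, z) = one(z)  # ∂³f/∂x∂y∂z, constant 1 in the type of z

fx_expected(1.0f0, 2.0f0, 3.0f0)    # 6.0f0, matches the working fx call
fxy_expected(1.0f0, 2.0f0, 3.0f0)   # 3.0f0
fxyz_expected(1.0f0, 2.0f0, 3.0f0)  # 1.0f0
```

If nested differentiation worked, `fxy(1.f0, 2.f0, 3.f0)` would be expected to return 3.0f0.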

Error message:

ERROR: No derivative rule found for op %69 = %62(%59, %63, %50, %20, %4)::Tuple{Float32, Tuple{ChainRulesCore.Tangent{var"#13#14"{Float32, Float32}, NamedTuple{(:y, :z), Tuple{Float32, Float32}}}, Float32}} , try defining it using

        ChainRulesCore.rrule(::Base.var"#invokelatest##kw", ::NamedTuple{(:seed,), Tuple{Int64}}, ::typeof(Base.invokelatest), ::Yota.var"###tape_#13#316", ::var"#13#14"{Float32, Float32}, ::Float32) = ...

Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:178
 [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:220
 [4] #gradtape!#77
   @ D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:245 [inlined]
 [5] gradtape(f::Function, args::Float32; ctx::Yota.GradCtx, seed::Int64)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:264
 [6] grad(f::Function, args::Float32; seed::Int64)
   @ Yota D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:356
 [7] grad
   @ D:\z_installed_programs\julia-depot\packages\Yota\G3nBA\src\grad.jl:348 [inlined]
 [8] fxy(x::Float32, y::Float32, z::Float32) grad_grad_yota.jl:8
 [9] top-level scope grad_grad_yota.jl:12

I also noticed that the return type of fx is inferred as Any instead of Float32, because the result of the grad call itself is inferred as Any:

@code_warntype fx(1.f0, 2.f0, 3.f0)

MethodInstance for fx(::Float32, ::Float32, ::Float32)
  from fx(x, y, z) in grad_grad_yota.jl:7
Arguments
  #self#::Core.Const(fx)
  x::Float32
  y::Float32
  z::Float32
Locals
  #23::var"#23#24"{Float32, Float32}
Body::Any
1 ─ %1 = Main.:(var"#23#24")::Core.Const(var"#23#24")
│   %2 = Core.typeof(y)::Core.Const(Float32)
│   %3 = Core.typeof(z)::Core.Const(Float32)
│   %4 = Core.apply_type(%1, %2, %3)::Core.Const(var"#23#24"{Float32, Float32})
│        (#23 = %new(%4, y, z))
│   %6 = #23::var"#23#24"{Float32, Float32}
│   %7 = Main.grad(%6, x)::Any
│   %8 = Base.getindex(%7, 2)::Any
│   %9 = Base.getindex(%8, 2)::Any
└──      return %9
