No deriative rule found for struct constructor #117

Closed · mcabbott opened this issue Jul 7, 2022 · 10 comments

@mcabbott (Contributor) commented Jul 7, 2022

I'm surprised by this error, which, if I understand right, comes from this line https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl#L31 constructing a struct which isn't in fact used. Is this the desired behaviour, or could all (default?) constructors be handled automatically somehow?

julia> using Yota, Optimisers, ChainRulesCore

julia> function gradient(f, xs...)
         println("Yota gradient!")
         _, g = Yota.grad(f, xs...)
         g[2:end]
       end;

julia> m1 = collect(1:3.0);

julia> gradient(m -> destructure(m)[1][1], m1)[1]
Yota gradient!
ERROR: No deriative rule found for op %30 = Optimisers.Restructure(%2, %19, %28)::Optimisers.Restructure{Vector{Float64}, Int64}, try defining it using 

	ChainRulesCore.rrule(::UnionAll, ::Vector{Float64}, ::Int64, ::Int64) = ...

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/dev/Yota/src/grad.jl:170
  [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
    @ Yota ~/.julia/dev/Yota/src/grad.jl:211

julia> function ChainRulesCore.rrule(T::Type{<:Optimisers.Restructure}, v, i, j)
         back(tan) = (NoTangent(), tan.model, tan.offsets, tan.length)
         back(z::AbstractZero) = (z,z,z,z)
         T(v,i,j), back
       end

julia> gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
Yota gradient!
true

Xref FluxML/Optimisers.jl#96

@dfdx (Owner) commented Jul 7, 2022

It's a bug. Yota thinks Optimisers.Restructure(...) is a primitive and records it to the tape instead of tracing it down to (:new, T, args...) (represented as __new__(T, args...) on the tape). Yota treats it as a primitive because typeof(Restructure) returns UnionAll, which belongs to module Core, and we don't trace deeper than that. I have an idea of how to fix it and will try to implement it tonight.
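
For illustration, a quick REPL check of that explanation (the exact primitive check inside Yota may differ):

using Optimisers

typeof(Optimisers.Restructure)   # UnionAll -- the parametric constructor itself
parentmodule(UnionAll)           # Core, which is why tracing stops there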

@cscherrer (Contributor) commented

Just to note, I was also seeing this in the latest release. But when I cloned the repository in order to add some @show statements for debugging, the problem went away (well, actually it was replaced by a different bug). @mcabbott, are you seeing the same behavior on master?

@dfdx (Owner) commented Jul 9, 2022

This seems to work on main:

using Yota, Optimisers, ChainRulesCore


function gradient(f, xs...)
    println("Yota gradient!")
    _, g = Yota.grad(f, xs...)
    g[2:end]
end;


# a quick hack, not really tested
function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
    # a more robust implementation would be to do backward conversion:
    #   return x, Δ -> (NoTangent(), NoTangent(), convert(typeof(x), Δ))
    # but it doesn't work for ZeroTangent(), so passing Δ as is
    return x, Δ -> (NoTangent(), NoTangent(), Δ)
end


m1 = collect(1:3.0);
gradient(m -> destructure(m)[1][1], m1)[1]

@mcabbott (Contributor, Author) commented

I get the same error on latest Yota + Umlaut.

The rule for convert silences the error, but it doesn't actually construct the struct requested. That struct isn't used above, but if I change the code to something that does use that part, it fails:

julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
Yota gradient!
ERROR: MethodError: no method matching _rebuild(::Vector{Float64}, ::Int64, ::ZeroTangent, ::Int64; walk::typeof(Optimisers._Tangent_biwalk), prune::NoTangent)

Closest candidates are:
  _rebuild(::Any, ::Any, ::AbstractVector, ::Any; walk, kw...)
   @ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82
  _rebuild(::Any, ::Any, ::AbstractVector) got unsupported keyword arguments "walk", "prune"
   @ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82

Stacktrace:
  [1] (::Optimisers.var"#_flatten_back#18"{Vector{Float64}, Int64, Int64})(::Tangent{Tuple{Vector{Float64}, Int64, Int64}, Tuple{ZeroTangent, NoTangent, NoTangent}})
    @ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:77
  [2] mkcall(fn::Umlaut.Variable, args::Umlaut.Variable; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
  [3] mkcall
    @ ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179 [inlined]
  [4] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:164

cf Zygote:

julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
(nothing, [1.0, 0.0, 0.0])

@dfdx (Owner) commented Jul 10, 2022

It looks like a different error: tracing constructors of UnionAll, which caused the previous error, works correctly this time:

julia> trace((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0]; ctx=GradCtx())
(1.0, Tape{GradCtx}
  inp %1::var"#212#213"
  inp %2::Vector{Float64}
  inp %3::Vector{Float64}
  %5, %6 = [%4] = rrule(YotaRuleConfig(), _flatten, %2)
  %8, %9 = [%7] = rrule(YotaRuleConfig(), indexed_iterate, %5, 1)
  %11, %12 = [%10] = rrule(YotaRuleConfig(), getfield, %8, 1)
  %14, %15 = [%13] = rrule(YotaRuleConfig(), getfield, %8, 2)
  %17, %18 = [%16] = rrule(YotaRuleConfig(), indexed_iterate, %5, 2, %14)
  %20, %21 = [%19] = rrule(YotaRuleConfig(), getfield, %17, 1)
  %23, %24 = [%22] = rrule(YotaRuleConfig(), getfield, %17, 2)
  %26, %27 = [%25] = rrule(YotaRuleConfig(), indexed_iterate, %5, 3, %23)
  %29, %30 = [%28] = rrule(YotaRuleConfig(), getfield, %26, 1)
  %32, %33 = [%31] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
  %35, %36 = [%34] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
  %38, %39 = [%37] = rrule(YotaRuleConfig(), convert, Vector{Float64}, %2)
  %41, %42 = [%40] = rrule(YotaRuleConfig(), convert, Int64, %20)
  %44, %45 = [%43] = rrule(YotaRuleConfig(), fieldtype, %35, 3)
  %47, %48 = [%46] = rrule(YotaRuleConfig(), convert, %44, %29)
  %50, %51 = [%49] = rrule(YotaRuleConfig(), __new__, %35, %38, %41, %47)    # <-- this is the internal constructor of Restructure
  %53, %54 = [%52] = rrule(YotaRuleConfig(), tuple, %11, %50)
  %56, %57 = [%55] = rrule(YotaRuleConfig(), getindex, %53, 2)
  %59, %60 = [%58] = rrule(YotaRuleConfig(), getproperty, %56, model)
  %62, %63 = [%61] = rrule(YotaRuleConfig(), getproperty, %56, offsets)
  %65, %66 = [%64] = rrule(YotaRuleConfig(), getproperty, %56, length)
  %68, %69 = [%67] = rrule(YotaRuleConfig(), _rebuild, %59, %62, %3, %65)
  %71, %72 = [%70] = rrule(YotaRuleConfig(), getindex, %68, 1)
)

_rebuild() doesn't accept ZeroTangent() as a cotangent value. A real question is whether ZeroTangent() is correct here, and if so, why Zygote doesn't hit the same problem. I will need to understand more about Optimisers internals and the generated graph to answer these questions.
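
For reference, a purely hypothetical workaround sketch, with argument names guessed from the MethodError above and not tested against Optimisers internals (it also doesn't answer whether ZeroTangent() is correct there in the first place):

using ChainRulesCore, Optimisers

# Guessed signature: materialize an AbstractZero cotangent as a zero vector of
# length `len`, then forward to the existing AbstractVector method of _rebuild.
Optimisers._rebuild(x, off, flat::ChainRulesCore.AbstractZero, len; kw...) =
    Optimisers._rebuild(x, off, zeros(len), len; kw...)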

@mcabbott (Contributor, Author) commented

It's not impossible that there are bugs in _rebuild; sorry, it's pretty messy. Will take a look at some point.

Trying to find a simpler example of what I thought was the original problem, with a struct from ChainRules' test_helpers:

julia> using Yota, ChainRulesCore

julia> struct Multiplier{T}  # from test_helpers in ChainRules
           x::T
       end

julia> (m::Multiplier)(y) = m.x * y

julia> function ChainRulesCore.rrule(m::Multiplier, y)
           Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
           return m(y), Multiplier_pullback
       end

julia> grad(x -> x(3.0), Multiplier(5.0))  # perfect
(15.0, (ZeroTangent(), Tangent{Multiplier{Float64}}(x = 3.0,)))

julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %3 = Multiplier(%2)::Multiplier{Float64}, try defining it using 

	ChainRulesCore.rrule(::UnionAll, ::Float64) = ...

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:170
 [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:211
...

julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
  inp %1::var"#8#9"
  inp %2::Float64
  %3 = Multiplier(%2)::Multiplier{Float64}
  %5, %6 = [%4] = rrule(Yota.YotaRuleConfig(), %3, 3.0)
)

(jl_lUY4C1) pkg> st Yota
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lUY4C1/Project.toml`
  [cd998857] Yota v0.7.3

That's the tagged version. On master:

(jl_lpE53K) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lpE53K/Project.toml`
  [92992a2b] Umlaut v0.2.5 `https://github.com/dfdx/Umlaut.jl.git#main`
  [cd998857] Yota v0.7.4 `https://github.com/dfdx/Yota.jl.git#main`

julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %9 = convert(%7, %2)::Float64, try defining it using 

	ChainRulesCore.rrule(::typeof(convert), ::DataType, ::Float64) = ...

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:170
 [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:211
 [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
   @ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:222
...

julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
  inp %1::var"#14#15"
  inp %2::Float64
  %4, %5 = [%3] = rrule(Yota.YotaRuleConfig(), apply_type, Multiplier, Float64)
  %7, %8 = [%6] = rrule(Yota.YotaRuleConfig(), fieldtype, %4, 1)
  %9 = convert(%7, %2)::Float64
  %11, %12 = [%10] = rrule(Yota.YotaRuleConfig(), __new__, %4, %9)
  %14, %15 = [%13] = rrule(Yota.YotaRuleConfig(), %11, 3.0)
)

julia> function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
           # Version with re-conversion via ProjectTo? Maybe this is only right for number types...
           # Also, why not convert on the forward pass?
           return x, Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(Δ))
       end

julia> grad(x -> Multiplier(x)(3.0), 5.0)
(15.0, (ZeroTangent(), 3.0))

That's why your rule targets convert: in this trace it's the only call left without an rrule, since apply_type and __new__ now have their own rules. But why is convert being called at all?

@dfdx (Owner) commented Aug 8, 2022

Sorry for the silence - I've been working on some bug fixes and improvements that may affect this question too. In particular, I added lineinfo to call nodes, and here's what it shows:

using ChainRulesCore
using Yota: trace, GradCtx  # bring trace/GradCtx into scope (i.e. Yota.trace, Yota.GradCtx)

struct Multiplier{T}  # from test_helpers in ChainRules
    x::T
end

(m::Multiplier)(y) = m.x * y


function ChainRulesCore.rrule(m::Multiplier, y)
    Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
    return m(y), Multiplier_pullback
end

mult1(x) = x(3.0)
mult2(x) = Multiplier(x)(3.0)

_, tape = trace(mult2, 5.0; ctx=GradCtx())

Result:

(15.0, Tape{GradCtx}
  inp %1::typeof(mult2)
  inp %2::Float64
  %4, %5 = [%3] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64)              # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %7, %8 = [%6] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64)              # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %9 = convert(Float64, %2)::Float64            # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %11, %12 = [%10] = rrule(YotaRuleConfig(), __new__, %7, %9)           # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
  %14, %15 = [%13] = rrule(YotaRuleConfig(), %11, 3.0)          # Main.mult2 at /home/azbs/work/Yota/src/_main3.jl:37
)

So convert(Float64, %2) happens in the object constructor, even though %2 is already a Float64. I tried to trick the compiler into not adding convert(), but it seems to be an intrinsic part of the lowered code. Nevertheless, we can now simplify the rrule for convert to a stricter version:

function ChainRulesCore.rrule(::typeof(convert), ::Type{T}, x::T) where T
    return x, Δ -> (NoTangent(), NoTangent(), Δ)
end
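
A quick check of what the stricter signature does and doesn't cover (illustrative values only; this assumes the rule just above has been defined):

using ChainRulesCore

y, back = ChainRulesCore.rrule(convert, Float64, 2.0)   # matches: 2.0 is already a Float64
@assert y == 2.0
@assert back(1.0) == (NoTangent(), NoTangent(), 1.0)

# convert(Float64, 2) would not hit this method (2 isa Int, not Float64), so a
# conversion that actually changes the value still has to be handled some other way.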

@mcabbott (Contributor, Author) commented

Both the top example and the Multiplier one work on Yota 0.8 and Julia 1.8, which is great.

On Julia nightly, something seems to go wrong, perhaps of interest (and might be why I saw errors in FluxML/Optimisers.jl#105):

julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: Unexpected expression: $(Expr(:static_parameter, 1))
Full IRCode:

 2 1%4 = $(Expr(:static_parameter, 1))::Core.Const(Float64)
  │   %1 = Core.apply_type(Main.Multiplier, %4)::Core.Const(Multiplier{Float64})
  │   %2 = (%1)(_2)::Multiplier{Float64}
  └──      return %2
  
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:333
  [3] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Tuple{UnionAll, Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
  [4] trace_call!(::Umlaut.Tracer{Yota.GradCtx}, ::Type, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:290
  [5] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:315
  [6] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Vector{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
  [7] trace(f::Function, args::Float64; ctx::Yota.GradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:556
  [8] gradtape(f::Function, args::Float64; ctx::Yota.GradCtx, seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:291

@dfdx (Owner) commented Aug 27, 2022

Apparently, Julia 1.9 changes the way static parameters (i.e. type parameters, {T}) are used in IRCode. You can update Umlaut to 0.4.5 to account for this.

(No changes to the version of Yota itself are needed at the moment)
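
For example, with standard Pkg commands (nothing Yota-specific here):

using Pkg
Pkg.add(name="Umlaut", version="0.4.5")   # or, in the Pkg REPL: add Umlaut@0.4.5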

@mcabbott (Contributor, Author) commented

Great, thanks. Then I'll mark this as closed.
