-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No deriative rule found for struct constructor #117
Comments
It's a bug. Yota thinks |
Just to note, I was also seeing this in the latest release. But when I cloned the repository in order to add some |
This seems to work on using Optimisers, ChainRulesCore
"""
    gradient(f, xs...)

Differentiate `f` at `xs` using `Yota.grad`.

Prints a marker line first, then returns the tangents with respect to `xs` only.
The leading entry of the tangent tuple is the tangent for `f` itself, and it is dropped.
"""
function gradient(f, xs...)
    println("Yota gradient!")
    # `Yota.grad` returns (value, (∂f, ∂xs...)); the value is not needed here.
    _, tangents = Yota.grad(f, xs...)
    return tangents[2:end]
end;
# A quick hack, not really tested: a pass-through rule for `convert(T, x)`.
#
# The forward pass returns `x` unchanged, and the pullback hands the incoming
# cotangent `Δ` straight to `x` (no tangent for `convert` or for the target type).
#
# A more robust implementation would convert the cotangent back to the type of `x`:
#     return x, Δ -> (NoTangent(), NoTangent(), convert(typeof(x), Δ))
# but that does not work when `Δ` is `ZeroTangent()`, so `Δ` is passed through as is.
#
# NOTE(review): this adds a `ChainRulesCore.rrule` method for `Base.convert` on
# `DataType`, none of which this code owns (type piracy) — keep it to local experiments.
function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
    convert_pullback(Δ) = (NoTangent(), NoTangent(), Δ)
    return x, convert_pullback
end
# Example: gradient of the first element of `destructure(m)[1]` with respect to `m1`.
# NOTE(review): the trailing `|` on the next line appears to be a page-scrape
# artifact, not Julia syntax — remove it before running.
m1 = collect(1:3.0);
gradient(m -> destructure(m)[1][1], m1)[1] |
I get the same error on the latest Yota + Umlaut. The rule for julia> gradient((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0])
Yota gradient!
ERROR: MethodError: no method matching _rebuild(::Vector{Float64}, ::Int64, ::ZeroTangent, ::Int64; walk::typeof(Optimisers._Tangent_biwalk), prune::NoTangent)
Closest candidates are:
_rebuild(::Any, ::Any, ::AbstractVector, ::Any; walk, kw...)
@ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82
_rebuild(::Any, ::Any, ::AbstractVector) got unsupported keyword arguments "walk", "prune"
@ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:82
Stacktrace:
[1] (::Optimisers.var"#_flatten_back#18"{Vector{Float64}, Int64, Int64})(::Tangent{Tuple{Vector{Float64}, Int64, Int64}, Tuple{ZeroTangent, NoTangent, NoTangent}})
@ Optimisers ~/.julia/packages/Optimisers/AqvxP/src/destructure.jl:77
[2] mkcall(fn::Umlaut.Variable, args::Umlaut.Variable; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
[3] mkcall
@ ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179 [inlined]
[4] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:164 cf Zygote:
|
It looks like a different error, tracing constructors of julia> trace((m,v) -> destructure(m)[2](v)[1], m1, [1,2,3.0]; ctx=GradCtx())
(1.0, Tape{GradCtx}
inp %1::var"#212#213"
inp %2::Vector{Float64}
inp %3::Vector{Float64}
%5, %6 = [%4] = rrule(YotaRuleConfig(), _flatten, %2)
%8, %9 = [%7] = rrule(YotaRuleConfig(), indexed_iterate, %5, 1)
%11, %12 = [%10] = rrule(YotaRuleConfig(), getfield, %8, 1)
%14, %15 = [%13] = rrule(YotaRuleConfig(), getfield, %8, 2)
%17, %18 = [%16] = rrule(YotaRuleConfig(), indexed_iterate, %5, 2, %14)
%20, %21 = [%19] = rrule(YotaRuleConfig(), getfield, %17, 1)
%23, %24 = [%22] = rrule(YotaRuleConfig(), getfield, %17, 2)
%26, %27 = [%25] = rrule(YotaRuleConfig(), indexed_iterate, %5, 3, %23)
%29, %30 = [%28] = rrule(YotaRuleConfig(), getfield, %26, 1)
%32, %33 = [%31] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
%35, %36 = [%34] = rrule(YotaRuleConfig(), apply_type, Optimisers.Restructure, Vector{Float64}, Int64)
%38, %39 = [%37] = rrule(YotaRuleConfig(), convert, Vector{Float64}, %2)
%41, %42 = [%40] = rrule(YotaRuleConfig(), convert, Int64, %20)
%44, %45 = [%43] = rrule(YotaRuleConfig(), fieldtype, %35, 3)
%47, %48 = [%46] = rrule(YotaRuleConfig(), convert, %44, %29)
%50, %51 = [%49] = rrule(YotaRuleConfig(), __new__, %35, %38, %41, %47) # <-- this is the internal constructor of Restructure
%53, %54 = [%52] = rrule(YotaRuleConfig(), tuple, %11, %50)
%56, %57 = [%55] = rrule(YotaRuleConfig(), getindex, %53, 2)
%59, %60 = [%58] = rrule(YotaRuleConfig(), getproperty, %56, model)
%62, %63 = [%61] = rrule(YotaRuleConfig(), getproperty, %56, offsets)
%65, %66 = [%64] = rrule(YotaRuleConfig(), getproperty, %56, length)
%68, %69 = [%67] = rrule(YotaRuleConfig(), _rebuild, %59, %62, %3, %65)
%71, %72 = [%70] = rrule(YotaRuleConfig(), getindex, %68, 1)
)
|
It's not impossible there are bugs in Trying to find a simpler example of what I thought was the original problem, with a struct from here: julia> using Yota, ChainRulesCore
julia> struct Multiplier{T} # from test_helpers in ChainRules
x::T
end
julia> (m::Multiplier)(y) = m.x * y
julia> function ChainRulesCore.rrule(m::Multiplier, y)
Multiplier_pullback(dΩ) = (Tangent{typeof(m)}(; x = dΩ * y'), m.x' * dΩ)
return m(y), Multiplier_pullback
end
julia> grad(x -> x(3.0), Multiplier(5.0)) # perfect
(15.0, (ZeroTangent(), Tangent{Multiplier{Float64}}(x = 3.0,)))
julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %3 = Multiplier(%2)::Multiplier{Float64}, try defining it using
ChainRulesCore.rrule(::UnionAll, ::Float64) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:170
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:211
...
julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
inp %1::var"#8#9"
inp %2::Float64
%3 = Multiplier(%2)::Multiplier{Float64}
%5, %6 = [%4] = rrule(Yota.YotaRuleConfig(), %3, 3.0)
)
(jl_lUY4C1) pkg> st Yota
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lUY4C1/Project.toml`
[cd998857] Yota v0.7.3 That's the tagged version. On master: (jl_lpE53K) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_lpE53K/Project.toml`
[92992a2b] Umlaut v0.2.5 `https://github.com/dfdx/Umlaut.jl.git#main`
[cd998857] Yota v0.7.4 `https://github.com/dfdx/Yota.jl.git#main`
julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: No deriative rule found for op %9 = convert(%7, %2)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(convert), ::DataType, ::Float64) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:170
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:211
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/98gNT/src/grad.jl:222
...
julia> Yota.trace(x -> Multiplier(x)(3.0), 5.0; ctx=Yota.GradCtx())
(15.0, Tape{Yota.GradCtx}
inp %1::var"#14#15"
inp %2::Float64
%4, %5 = [%3] = rrule(Yota.YotaRuleConfig(), apply_type, Multiplier, Float64)
%7, %8 = [%6] = rrule(Yota.YotaRuleConfig(), fieldtype, %4, 1)
%9 = convert(%7, %2)::Float64
%11, %12 = [%10] = rrule(Yota.YotaRuleConfig(), __new__, %4, %9)
%14, %15 = [%13] = rrule(Yota.YotaRuleConfig(), %11, 3.0)
)
julia> function ChainRulesCore.rrule(::typeof(convert), ::DataType, x)
# Version with re-conversion via ProjectTo? Maybe this is only right for number types...
# Also, why not convert on the forward pass?
return x, Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(Δ))
end
julia> grad(x -> Multiplier(x)(3.0), 5.0)
(15.0, (ZeroTangent(), 3.0)) That's why your rule targets |
Sorry for the silence — I've been working on some bug fixes and improvements that may affect this question too. In particular, I added lineinfo to call nodes, and here's what it shows: struct Multiplier{T} # from test_helpers in ChainRules
x::T
end
# Make a `Multiplier` callable: multiply its stored field `x` by the argument `y`.
function (m::Multiplier)(y)
    return m.x * y
end
# Reverse rule for calling a `Multiplier` on `y`.
# The pullback returns two tangents:
# - a structural `Tangent` for the callable `m` (its field `x`);
# - the cotangent for the argument `y`.
function ChainRulesCore.rrule(m::Multiplier, y)
    function Multiplier_pullback(dΩ)
        dm = Tangent{typeof(m)}(; x = dΩ * y')
        dy = m.x' * dΩ
        return (dm, dy)
    end
    return m(y), Multiplier_pullback
end
# Call an already-constructed callable `x` (e.g. a `Multiplier`) on 3.0.
function mult1(x)
    return x(3.0)
end

# Construct a `Multiplier` from `x` inside the function, then call it on 3.0.
function mult2(x)
    return Multiplier(x)(3.0)
end
_, tape = trace(mult2, 5.0; ctx=GradCtx()) Result: (15.0, Tape{GradCtx}
inp %1::typeof(mult2)
inp %2::Float64
%4, %5 = [%3] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64) # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%7, %8 = [%6] = rrule(YotaRuleConfig(), apply_type, Multiplier, Float64) # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%9 = convert(Float64, %2)::Float64 # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%11, %12 = [%10] = rrule(YotaRuleConfig(), __new__, %7, %9) # Main.Multiplier at /home/azbs/work/Yota/src/_main3.jl:9
%14, %15 = [%13] = rrule(YotaRuleConfig(), %11, 3.0) # Main.mult2 at /home/azbs/work/Yota/src/_main3.jl:37
) So function ChainRulesCore.rrule(::typeof(convert), ::Type{T}, x::T) where T
return x, Δ -> (NoTangent(), NoTangent(), Δ)
end |
Both the top example and the Multiplier one work on Yota 0.8 and Julia 1.8, which is great. On Julia nightly, something seems to go wrong, perhaps of interest (and might be why I saw errors in FluxML/Optimisers.jl#105): julia> grad(x -> Multiplier(x)(3.0), 5.0)
ERROR: Unexpected expression: $(Expr(:static_parameter, 1))
Full IRCode:
2 1 ─ %4 = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│ %1 = Core.apply_type(Main.Multiplier, %4)::Core.Const(Multiplier{Float64})
│ %2 = (%1)(_2)::Multiplier{Float64}
└── return %2
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:333
[3] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Tuple{UnionAll, Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
[4] trace_call!(::Umlaut.Tracer{Yota.GradCtx}, ::Type, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:290
[5] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:315
[6] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Vector{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:439
[7] trace(f::Function, args::Float64; ctx::Yota.GradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/trace.jl:556
[8] gradtape(f::Function, args::Float64; ctx::Yota.GradCtx, seed::Int64)
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:291 |
Apparently, Julia 1.9 changes the way static parameters (i.e. type parameters, (No changes to the version of Yota itself are needed at the moment) |
Great, thanks. Then I'll mark this as closed. |
I'm surprised by this error, which, if I understand correctly, comes from this line: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl#L31 — it constructs a struct that isn't actually used. Is this the desired behaviour, or can all (default?) constructors be handled automatically somehow?
Xref FluxML/Optimisers.jl#96
The text was updated successfully, but these errors were encountered: