Zygote doesn't seem to interact well with LazyArrays.jl. For example:
julia> using Zygote, LazyArrays
julia> f(x) = sum(BroadcastArray(exp, x))
f (generic function with 1 method)
julia> Zygote.gradient(f, randn(10))
ERROR: type Array has no field f
Stacktrace:
[1] adjoint
@ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:229 [inlined]
[2] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[3] _pullback
@ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:50 [inlined]
[4] _pullback(::Zygote.Context{false}, ::typeof(LazyArrays.call), ::ArrayLayouts.DenseColumnMajor, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
[5] _pullback
@ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:52 [inlined]
[6] _pullback
@ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:82 [inlined]
[7] _pullback
@ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:57 [inlined]
[8] _pullback(::Zygote.Context{false}, ::Type{BroadcastArray}, ::typeof(exp), ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
[9] _pullback
@ ./REPL[48]:1 [inlined]
[10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
[11] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
[12] pullback
@ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
[14] top-level scope
@ REPL[50]:1
julia> g(x) = sum(LazyArray(@~ exp.(x)))
g (generic function with 1 method)
julia> Zygote.gradient(g, randn(10))
ERROR: MethodError: no method matching LazyArray(::Vector{Float64})
Closest candidates are:
LazyArray(::Base.Broadcast.Broadcasted) at ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:35
LazyArray(::Applied) at ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:193
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(ctx::Zygote.Context{false}, f::Type{LazyArray}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:9
[3] _pullback
@ ./REPL[53]:1 [inlined]
[4] _pullback(ctx::Zygote.Context{false}, f::typeof(g), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
[5] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
[6] pullback
@ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
[7] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
[8] top-level scope
@ REPL[54]:1
The first error can be "fixed" (I'm not entirely certain if this is the right way to go about it) by defining a chain rule:
julia> using ChainRulesCore
julia> function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
           return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
       end
julia> Zygote.refresh()
julia> Zygote.gradient(f, randn(10))
([0.24117702568683322, 2.478340448616497, 2.433266795642693, 1.6163793920298133, 1.8859252985478665, 3.9539878829654223, 1.2578105524502685, 0.48545348574922, 0.8710494256114425, 3.0853524634917076],)
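As a sanity check on output like the above: the gradient of `sum(exp.(x))` with respect to `x` is `exp.(x)` elementwise, so the result of the workaround can be compared against the eager broadcast. A minimal sketch, assuming only that Zygote is installed; the names `x` and `g_ref` are illustrative and not part of the original report.

```julia
using Zygote

# Illustrative input; any real vector works.
x = randn(10)

# Reference gradient using the ordinary (eager) broadcast instead of BroadcastArray.
g_ref = first(Zygote.gradient(v -> sum(exp.(v)), x))

# d/dx_i sum(exp.(x)) = exp(x_i), so the gradient should equal exp.(x).
@assert g_ref ≈ exp.(x)
```

With the rule above defined, `first(Zygote.gradient(f, x)) ≈ g_ref` would be the corresponding check for the lazy version.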
Maybe the second error can be addressed this way too.
Would rules defined via ChainRulesCore be welcome here?