Test with Yota, too #105

Merged: 8 commits, Dec 8, 2022


mcabbott
Member

Does not close #96; in fact, this surely makes the tests slower. But perhaps it's good to get something besides Zygote running?

@mcabbott (Member Author) left a comment

Paging @dfdx about these errors.

test/destructure.jl: 2 review threads (outdated, resolved)
Comment on lines 92 to 100
Unfortunately this example doesn't actually run right now. This is the error:
```
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
           sum(m(x))
       end;
┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
```
mcabbott (Member Author)

Maybe this should stay WIP for a bit.

dfdx

Thanks for pinging me! I'll be able to check out these errors during the weekend.

@mcabbott marked this pull request as draft on August 19, 2022, 20:28
dfdx commented Aug 21, 2022

Some of the broken tests are already fixed on main, others need some adjustment (e.g. ZeroTangent() vs. NoTangent()), but I think I'll be able to fix them in a couple of days.

One thing I seem to be missing is why destructure/Restructure needs to be differentiable. I'd expect a training loop to look like this:

```julia
using Optimisers
using Zygote: gradient  # assumed AD backend for this sketch

model = MyModel()
state = Optimisers.setup(Optimisers.Adam(), model)
input = ...
loss = ...
for i = 1:N
    grad = gradient(loss, model, input)[1]                # differentiable part; [1] is the gradient w.r.t. model
    state, model = Optimisers.update(state, model, grad)  # at every step; not necessarily differentiable
end
```

state points to the trainable parameters of MyModel() and lets us update them, but never steps into gradient calculation. Yet, you test things like Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1], so my picture of the world is obviously incomplete.

mcabbott (Member Author)

Sounds good. I have no idea whether the tests have ZeroTangent() and NoTangent() the wrong way around; it's fine to adjust the tests to whatever is produced.

I broke Flux at some point because it turned out half the SciML universe rested on the gradient of destructure, and there were exactly zero tests... it was so rudimentary I assumed it was only for saving to CSV or similar. But I think it gets used as an interface between things which don't like nested structures (like calling some package for LBFGS) and models which do, or for obtaining a Hessian with respect to the parameters of some Flux model.
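To make the LBFGS-style use concrete: an optimiser that only sees a flat vector v needs the gradient of v -> loss(re(v)), i.e. a gradient that passes back through the rebuild step. This sketch uses hypothetical plain-Julia stand-ins (toy_destructure is not Optimisers.destructure, and finite differences replace AD) and only checks that, for a rebuild that merely reshapes, the flat gradient is the nested model gradient flattened in the same order.

```julia
# Hypothetical stand-in for Optimisers.destructure: a "model" is a NamedTuple
# of arrays, flattened into one Vector in field order.
function toy_destructure(m::NamedTuple)
    flat = reduce(vcat, [vec(x) for x in values(m)])
    shapes = map(size, m)              # NamedTuple of array sizes
    lens = map(prod, shapes)           # NamedTuple of array lengths
    stops = cumsum(collect(lens))      # end offset of each array inside `flat`
    # rebuild(v) cuts a flat vector back into arrays of the original shapes
    rebuild(v::AbstractVector) = NamedTuple{keys(m)}(
        ntuple(i -> reshape(v[stops[i]-lens[i]+1:stops[i]], shapes[i]), length(m)))
    return flat, rebuild
end

model = (W = [1.0 2.0; 3.0 4.0], b = [0.5, -0.5])
loss(m) = sum(abs2, m.W) + sum(m.b)
dmodel = (W = 2 .* model.W, b = ones(2))   # hand-written gradient w.r.t. the nested model

flat, re = toy_destructure(model)

# Central finite differences of v -> loss(re(v)), standing in for AD
function fd_grad(f, v; h = 1e-6)
    g = similar(v)
    for i in eachindex(v)
        e = zeros(length(v)); e[i] = h
        g[i] = (f(v .+ e) - f(v .- e)) / (2h)
    end
    return g
end
gflat = fd_grad(v -> loss(re(v)), flat)

@assert re(flat) == model                                             # round trip
@assert isapprox(gflat, first(toy_destructure(dmodel)); atol = 1e-6)  # flat grad = flattened model grad
```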

dfdx commented Aug 24, 2022

I have a question regarding tests like this:

@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])

Currently, Yota returns Tangent{Tuple{Vector{Float64}, Vector{Float64}}}([0.0, 1.0, 0.0], [0.0, 0.0, 0.0]), so one can do, e.g.:

g = Yota_gradient(m -> destructure(m)[1][2], m2)[1]
g + [0, 0, 0]     # =>     [0.0, 1.0, 0.0]

Is this what Optimisers.jl expects, or should I rather return a plain tuple as in the test?

mcabbott (Member Author)

I think a Tangent is fine; this term is the gradient with respect to a Tuple.

The test should be changed to allow for this, or perhaps the Yota_gradient function should convert, since its job is to make tests look the same.

dfdx commented Aug 26, 2022

I fixed the hardest issues in the tests, but after several days of investigation I still can't solve two remaining problems:

  • ZeroTangent vs NoTangent. Honestly, I still don't have a clear understanding of the difference. For example, function arguments may be generalized to callable structs, so it makes sense to return ZeroTangent() for them, yet most examples in ChainRules return NoTangent(). Another case is the function ChainRules.var"#fieldtype_pullback#422", which, when applied to ZeroTangent(), returns (NoTangent(), NoTangent(), NoTangent()). So I wasn't able to adjust Yota's behavior to the tests (which reflect the behavior of Zygote, right?), but I'm open to suggestions.
  • The gradients seem to be packed and unpacked differently. For example, to account for the Tangent{Tuple} vs Tuple case above, I tried to modify Yota_gradient to this:
```julia
using ChainRulesCore, Yota

unpack(x::Tangent) = x.backing
unpack(x) = x

function Yota_gradient(f, xs...)
    g = Base.tail(Yota.grad(f, xs...)[2])  # drop the gradient w.r.t. f itself
    return map(unpack, g)
end
```

It helped with some tests, but broke others. Structurally, the results seem to be correct, but I don't quite understand what needs to be adjusted: Yota, Yota_gradient, or the tests themselves.
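For reference, the ChainRulesCore convention as I understand it: NoTangent() says an input has no tangent space at all (an Int index, a Symbol, a function with no fields), while ZeroTangent() says the input is differentiable but its derivative happens to be zero here. A minimal sketch with a hypothetical function pick (not part of Yota, Optimisers or ChainRules):

```julia
using ChainRulesCore

# Hypothetical function: `b` is a differentiable input that the result ignores
pick(x::Vector{Float64}, i::Int, a::Float64, b::Float64) = a * x[i]

function ChainRulesCore.rrule(::typeof(pick), x::Vector{Float64}, i::Int, a::Float64, b::Float64)
    y = pick(x, i, a, b)
    function pick_pullback(dy)
        dy = unthunk(dy)
        dx = zeros(length(x))
        dx[i] = a * dy
        return (NoTangent(),    # `pick` is a function with no fields: nothing to differentiate
                dx,             # gradient with respect to x
                NoTangent(),    # an Int index has no tangent space
                x[i] * dy,      # gradient with respect to a
                ZeroTangent())  # b is differentiable, but y does not depend on it
    end
    return y, pick_pullback
end

y, back = rrule(pick, [1.0, 2.0, 3.0], 2, 10.0, 5.0)
grads = back(1.0)
# Both flavours are AbstractZero, which is why a test helper can treat them alike
@assert all(g -> g isa AbstractZero, (grads[1], grads[3], grads[5]))
```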

I'm going to proceed with testing Yota on Flux models + Optimisers, which should uncover more inconsistencies, but if you want to make another pass on these tests, please try Yota 0.8 and share your thoughts!

@mcabbott (Member Author) left a comment

Had a quick go with 0.8, and I still see many errors? But I will update a few things in the meantime.

test/runtests.jl (outdated)
@@ -13,6 +13,8 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
@mcabbott (Member Author), Aug 27, 2022

Maybe this is a better rough translation function, much like the suggestion above:

Suggested change:
```diff
-Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
+Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]));
+y2z(::AbstractZero) = nothing # we don't care about different flavours
+y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))
+y2z(x) = x
```

The only goal is to have as few changes as possible between tests using Zygote and the same with Yota. I don't think we care at all about the different kinds of special Zero.

Well, we care internally that all should be accepted. But when testing what's returned, we are happy if we get any one of them.
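A quick check of the suggested translation, assuming ChainRulesCore is loaded (the test file would need it for AbstractZero, Tangent and canonicalize): the Tangent{Tuple} that Yota returns becomes a plain Tuple, a struct Tangent becomes a NamedTuple with nothing for untouched fields, and every zero flavour becomes nothing. The struct Point is hypothetical, for illustration only.

```julia
using ChainRulesCore

# Translation helper as suggested in the review comment
y2z(::AbstractZero) = nothing  # we don't care about different flavours
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))
y2z(x) = x

struct Point  # hypothetical struct, for illustration only
    x::Float64
    y::Float64
end

# The Tangent{Tuple} that Yota returned in the destructure test
t = Tangent{Tuple{Vector{Float64}, Vector{Float64}}}([0.0, 1.0, 0.0], [0.0, 0.0, 0.0])
@assert y2z(t) == ([0, 1, 0], [0, 0, 0])   # now matches the Zygote-style test

# A struct tangent with an untouched field: canonicalize fills it with ZeroTangent()
@assert y2z(Tangent{Point}(x = 1.0)) == (x = 1.0, y = nothing)

@assert y2z(NoTangent()) === y2z(ZeroTangent()) === nothing
```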

Project.toml (outdated, resolved)
docs/src/index.md (outdated, resolved)
test/destructure.jl: 3 review threads (outdated, resolved)
dfdx commented Aug 27, 2022

I can successfully run tests in this PR on Julia nightly with this rule added:

```julia
using ChainRulesCore
import ChainRulesCore: rrule

function rrule(::typeof(getfield), s, f::Symbol)
    y = getfield(s, f)
    function getfield_pullback(dy)
        dy = unthunk(dy)
        T = typeof(s)
        nt = NamedTuple{(f,)}((dy,))  # tangent only for the accessed field
        return NoTangent(), Tangent{T}(; nt...), ZeroTangent()
    end
    return y, getfield_pullback
end
```

Yota contains the same rule for getproperty, which usually is enough but doesn't work in this particular case. If the code above is an acceptable solution, I can add this rule to Yota or create a PR to ChainRules.

mcabbott (Member Author) commented Aug 27, 2022

It's possible that this package and Functors.jl should think more about whether to call getfield vs getproperty. The weird @functor macro https://github.com/FluxML/Functors.jl/blob/master/src/functor.jl#L11-L20 goes by fieldnames(T) and getproperty I think.

But looking at the errors on CI, maybe it's from somewhere deeper inside, involving Core.Box because everything is type-unstable?

 Error During Test at /home/runner/work/Optimisers.jl/Optimisers.jl/test/destructure.jl:205
  Test threw exception
  Expression: (Yota_gradient((x->(sum(abs2, (re9(x)).c[1]);)), 1:7))[1] == [0, 0, 0, 8, 10, 12, 14]
  No deriative rule found for op %7 = getfield(%3, :contents)::Optimisers.Restructure{NamedTuple{(:a, :b, :c), Tuple{Vector{Float64}, Matrix{Float32}, Vector{Array}}}, NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Vector{Int64}}}} , try defining it using 
  	ChainRulesCore.rrule(::typeof(getfield), ::Core.Box, ::Symbol) = ...
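A Core.Box like this appears when a closure captures a local variable that is assigned more than once: the closure's field holds a Core.Box, so a tracer sees getfield(closure, :name) followed by getfield(box, :contents), which is the op in the error. A plain-Julia sketch of the mechanism (make_reader is hypothetical; whether this is exactly what happens with re9 in the test file is an assumption):

```julia
# A closure capturing a local that is assigned twice gets a Core.Box field
function make_reader()
    re = sum            # first assignment
    reader = x -> re(x)
    re = prod           # reassignment after capture forces boxing
    return reader
end

reader = make_reader()
@assert fieldtype(typeof(reader), :re) === Core.Box
# Reading the captured value takes two getfield calls: closure -> Box -> contents
@assert getfield(getfield(reader, :re), :contents) === prod
@assert reader([2, 3, 4]) == 24
```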

That said, having a rule for getfield sounds fine to me. I think it should probably call ProjectTo, since this will sometimes turn e.g. Tangent{Complex} back into a number. And perhaps care is needed about whether closing over a Symbol and a type works well, or whether it needs Val to help out?

```julia
using ChainRulesCore, InteractiveUtils
import ChainRulesCore: rrule

function rrule(::typeof(getfield), x::T, f::Symbol) where T
    y = getfield(x, f)
    proj = ProjectTo(x)
    # valT = Val(T)  # perhaps more stable inside closure?
    function getfield_pullback(dy)
        nt = NamedTuple{(f,)}((unthunk(dy),))
        # not really sure whether this ought to unthunk or not, maybe ProjectTo will anyway, in which case best to be explicit?
        return NoTangent(), proj(Tangent{T}(; nt...)), ZeroTangent()
    end
    return y, getfield_pullback
end

# These print lots in red:
@code_warntype rrule(getfield, (x=1, y=2.0), :x)
@code_warntype rrule(getfield, (x=1, y=2.0), :x)[2](3)

# But these are OK
@code_warntype (nt -> rrule(getfield, nt, :x))((x=1, y=2.0))
@code_warntype (nt -> rrule(getfield, nt, :x)[2](3.0))((x=1, y=2.0))
```
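On the Val question: when the field name is only a runtime Symbol, inference cannot pick a single field type, while lifting the name into a type parameter lets it. That is presumably also why the closure-wrapped @code_warntype calls look fine, since the literal :x can be constant-propagated there. A small sketch using Base.return_types (field_sym and field_val are hypothetical helpers):

```julia
nt = (x = 1, y = 2.0)

field_sym(t, f::Symbol) = getfield(t, f)           # field name as a runtime value
field_val(t, ::Val{f}) where {f} = getfield(t, f)  # field name as a type parameter

T_sym = only(Base.return_types(field_sym, Tuple{typeof(nt), Symbol}))
T_val = only(Base.return_types(field_val, Tuple{typeof(nt), Val{:x}}))

@assert !isconcretetype(T_sym)   # a Union of the field types, or wider
@assert T_val === Int
@assert field_sym(nt, :x) === field_val(nt, Val(:x)) === 1
```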

mcabbott (Member Author)

It's not in the tests here, but running the Metalhead example in the docs I still get this error (with or without the getfield rule, on Julia 1.8 and 1.9):

julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
                  sum(m(x))
                   end;
ERROR: BoundsError: attempt to access Nothing at index [1]
Stacktrace:
  [1] _getfield(value::Nothing, fld::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/helpers.jl:40
  [2] mkcall(::Function, ::Umlaut.Variable, ::Vararg{Any}; val::Missing, line::Nothing, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/vGy3v/src/tape.jl:192
  [3] mkcall
    @ ~/.julia/packages/Umlaut/vGy3v/src/tape.jl:174 [inlined]
  [4] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:183
  [5] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:268
  [6] gradtape(::Function, ::ResNet, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:297
  [7] grad(::Function, ::ResNet, ::Vararg{Any}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:367
  [8] grad(::Function, ::ResNet, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:359

dfdx commented Aug 27, 2022

> But looking at the errors on CI, maybe it's from somewhere deeper inside, involving Core.Box because everything is type-unstable?

It's even curiouser! Running a random test in REPL works fine:

julia> re1 = destructure(m1)[2]
Restructure(Array, ..., 3)

julia> @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
Test Passed

But wrap it in a @testset and it fails!

julia> @testset "using Yota" begin
              re1 = destructure(m1)[2]
             @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] 
          end
          
using Yota: Error During Test at REPL[69]:3
  Test threw exception
  Expression: (Yota_gradient((x->((re1(x))[1];)), rand(3)))[1] == [1, 0, 0]
  No deriative rule found for op %3 = getfield(%1, :re1)::Optimisers.Restructure{Vector{Float64}, Int64} , try defining it using 
  
        ChainRulesCore.rrule(::typeof(getfield), ::var"#95#96"{Optimisers.Restructure{Vector{Float64}, Int64}}, ::Symbol) = ...
...

Perhaps @testset captures the module it's running in and uses something like getfield(captured_data, :global_var_name). Essentially, this is the same issue as dfdx/Yota.jl#112.
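One way to see how @testset changes things (a sketch, run at global scope as in a script or the REPL; whether it is the whole story for Yota is an assumption): at global scope x -> re1(x)[1] refers to the global re1, so the closure has no fields, while inside a local scope, such as the let block here standing in for a @testset body, re1 is stored as a field of the closure, and tracing it produces getfield(closure, :re1), matching the op in the error.

```julia
re1 = v -> 2 .* v                  # global binding, standing in for a Restructure
f_global = x -> re1(x)[1]          # refers to the global, so the closure has no fields

g_local = let re1 = v -> 2 .* v    # local binding, as inside a @testset body
    x -> re1(x)[1]                 # captures re1 as a field of the closure
end

@assert fieldnames(typeof(f_global)) == ()
@assert fieldnames(typeof(g_local)) == (:re1,)
@assert f_global([1.0, 2.0]) == g_local([1.0, 2.0]) == 2.0
```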

> I think it should probably call ProjectTo [...]

Yes, it makes sense. Regarding type stability, I'm going to include your definition in Yota as is for now to keep the focus on correctness, and come back to performance later.

> It's not in the tests here, but running the Metalhead example in the docs I still get this error (with or without getfield rule, 1.8 and 1.9):

I'm looking at it.

dfdx commented Aug 28, 2022

I may have spotted one of the bugs related to the failures on the Metalhead example, but I must make sure first. In this piece of code in generic broadcasting:

    ys3, backs = unzip_broadcast(args...) do a...
        rrule_via_ad(cfg, f, a...)
    end

does f refer to the function being broadcast, or to Broadcast.broadcasted itself? For example, in this case:

f = x -> identity(x)
args = (rand(3),)
rrule(cfg, broadcasted, f, args...)

which of the following is invoked:

rrule_via_ad(cfg, broadcasted, f, args...)

or

rrule_via_ad(cfg, f, args...)

?

mcabbott (Member Author)

I don't see an obvious mistake. The intention is for rrule(cfg, broadcasted, BroadcastStyle, sqrt, [1,2,3]) to call y, bk = rrule(cfg, sqrt, 2), i.e. acting on the elements, with no broadcasting. That gives y=1.414, and then the unzip gives ys3 = [1, 1.41, 1.73] and an array of functions.

This split_bc_pullbacks(cfg, f, args...) is never passed the broadcasted, BroadcastStyle parts, since it doesn't need them. So its second argument should be sqrt.
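A plain-Julia sketch of the intended computation, with a hypothetical hand-written sqrt_rrule standing in for rrule_via_ad(cfg, sqrt, x): the per-element rule is broadcast over the array, and the resulting array of (value, pullback) tuples is unzipped into an array of values and an array of pullbacks of the same shape (map stands in for the StructArray-based unzip).

```julia
# Hypothetical stand-in for rrule_via_ad(cfg, sqrt, x): acts on a single element
function sqrt_rrule(x::Real)
    y = sqrt(x)
    sqrt_pullback(dy) = dy / (2y)
    return y, sqrt_pullback
end

xs = [1.0, 2.0, 3.0]
results = sqrt_rrule.(xs)          # Vector of (value, pullback) tuples, one per element
ys3 = map(first, results)          # values, approximately [1.0, 1.414, 1.732]
backs = map(last, results)         # pullbacks, one closure per element
dxs = map((bk, dy) -> bk(dy), backs, ones(3))  # each pullback gets its own cotangent

@assert size(ys3) == size(backs) == size(xs)
@assert ys3 ≈ sqrt.(xs)
@assert dxs ≈ 1 ./ (2 .* sqrt.(xs))
```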

dfdx commented Aug 28, 2022

Oh, I don't think it's a mistake in the generic broadcasting, but rather in Yota.rrule_via_ad()! The example I'm currently testing is this:

using Flux, Yota

model = Dense(28*28, 1024, x -> identity(x))
x = rand(Float32, 28*28, 4)
grad((model, x) -> sum(model(x)), model, x)

which produces this nice stacktrace:

ERROR: all field arrays must have same shape
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::StructArrays.var"#6#7"{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})(ci::Vector{Function})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:21
  [3] map
    @ ./tuple.jl:221 [inlined]
  [4] (StructArrays.StructArray{Tuple{Float32, Function}, 2, Tuple{Matrix{Float32}, Vector{Function}}})(c::Tuple{Matrix{Float32}, Vector{Function}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:20
  [5] (StructArrays.StructArray{Tuple{Float32, Function}})(c::Tuple{Matrix{Float32}, Vector{Function}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:97
  [6] _widenstructarray(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, Function}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:118
  [7] widen_from_type(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, var"#24#26"}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:109
  [8] widen_from_instance(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, i::Int64, el::Tuple{Float32, var"#24#26"})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:105
  [9] collect_to_structarray!(dest::StructArrays.StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}, Int64}, itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, offs::Int64, st::Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:77
 [10] _collect_structarray!
    @ ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:59 [inlined]
 [11] _collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, elem::Tuple{Tuple{Float32, var"#25#27"}, Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}}}, ax::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:54
 [12] collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:40
 [13] StructArrays.StructArray(v::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1705#1707"{YotaRuleConfig, var"#141#142"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; unwrap::typeof(StructArrays.alwaysfalse))
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:261
 [14] StructArray
    @ ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:260 [inlined]
 [15] unzip_broadcast
    @ ~/.julia/packages/ChainRules/DUopG/src/unzipped.jl:39 [inlined]
 [16] split_bc_pullbacks(cfg::YotaRuleConfig, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:127
 [17] rrule(cfg::YotaRuleConfig, #unused#::typeof(Base.Broadcast.broadcasted), #unused#::Base.Broadcast.DefaultArrayStyle{2}, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:44
 [18] mkcall(::Function, ::YotaRuleConfig, ::Vararg{Any}; val::Missing, line::Core.LineInfoNode, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/vGy3v/src/tape.jl:192
 [19] chainrules_transform!(tape::Tape{GradCtx})
    @ Main ~/work/Yota/src/grad.jl:184
 [20] gradtape!(tape::Tape{GradCtx}; seed::Int64)
    @ Main ~/work/Yota/src/grad.jl:271
 [21] gradtape(::Function, ::Dense{var"#141#142", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; ctx::GradCtx, seed::Int64)
    @ Main ~/work/Yota/src/grad.jl:300
 [22] grad(::Function, ::Dense{var"#141#142", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; seed::Int64)
    @ Main ~/work/Yota/src/grad.jl:370
 [23] grad(::Function, ::Dense{var"#141#142", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any})
    @ Main ~/work/Yota/src/grad.jl:362
 [24] top-level scope
    @ REPL[26]:1

From the stacktrace I infer that rrule_via_ad() returns not what unzip_broadcast() expects. I made a guess that split_bc_pullbacks() calls rrule_via_ad() on broadcasted itself, e.g.:

julia> y, bk = rrule_via_ad(YotaRuleConfig(), broadcasted, sqrt, [1.0, 2, 3])
...
julia> y
3-element Vector{Float64}:
 1.0
 1.4142135623730951
 1.7320508075688772

julia> bk
#24 (generic function with 1 method)

and that unzip_broadcast() expects y and bk to have the same length to pack them into StructArray. But if you say rrule_via_ad() is never invoked that way, then I'm going to get a good night's sleep before the next iteration of debugging 😄

mcabbott (Member Author)
Quite the stacktrace! These lines look correct to me: The same function f is being passed through, and it acts on arg which is the result of lazy broadcasting +. The broadcasted, DefaultArrayStyle arguments are marked unused:

 [16] split_bc_pullbacks(cfg::YotaRuleConfig, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:127
 [17] rrule(cfg::YotaRuleConfig, #unused#::typeof(Base.Broadcast.broadcasted), #unused#::Base.Broadcast.DefaultArrayStyle{2}, f::var"#141#142", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:44

To get to line 39 https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/unzipped.jl#L39, rrule_via_ad must have been inferred to return a Tuple, which is correct. I don't see how the shape can then come out wrong, but...
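Reading frames [6] and [7] of the stacktrace, the shapes seem to diverge during widening: the destination starts as StructArray{Tuple{Float32, var"#25#27"}, 2, Tuple{Matrix{Float32}, Matrix{var"#25#27"}}}, a later element carries a pullback of a different closure type (var"#24#26"), and the widened Tuple{Float32, Function} array is built from a Matrix{Float32} and a Vector{Function}, which the same-shape check rejects. A plain-Julia sketch of what the unzip should produce in that situation, using a hypothetical per-element rule with two pullback closure types and map in place of StructArray:

```julia
# Hypothetical per-element rule whose pullback closure type depends on a branch
function relu_rrule(x::Float32)
    if x > 0
        return x, dy -> dy
    else
        return zero(x), dy -> zero(dy)
    end
end

xs = Float32[1 -2 3; -4 5 -6]          # 2×3 matrix
results = relu_rrule.(xs)              # 2×3 Matrix of (value, pullback) tuples
ys = map(first, results)               # 2×3 Matrix{Float32}
backs = map(last, results)             # 2×3 Matrix holding two different closure types

# The invariant StructArray enforces: every component has the same axes
@assert axes(ys) == axes(backs) == axes(xs)
@assert ys == Float32[1 0 3; 0 5 0]
@assert map((bk, dy) -> bk(dy), backs, ones(Float32, 2, 3)) == Float32[1 0 3; 0 5 0]
```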

dfdx commented Sep 3, 2022

Here's an interesting observation. If I run the same example as is:

using Flux, Yota, ChainRules


myid = x -> identity(x)
model = Dense(5, 3, myid)
x = rand(Float32, 5, 1);
val, g = grad((model, x) -> sum(model(x)), model, x)
@show val
@show g

I get the same stacktrace as posted above, complaining about "ERROR: all field arrays must have same shape". However, if I slightly modify unzip_broadcast() and just add Broadcast.materialize(bc):

function unzip_broadcast(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    if isconcretetype(T)
        T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    bcs = Broadcast.BroadcastStyle(typeof(bc))
    if bcs isa AbstractGPUArrayStyle
        # This is a crude way to allow GPU arrays, not currently tested, TODO.
        # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
        return unzip(broadcast(f, args...))
    elseif bcs isa Broadcast.AbstractArrayStyle
        Broadcast.materialize(bc)                                             # <-- this line added
        return StructArrays.components(StructArray(bc))
    else
        return unzip(broadcast(f, args...))  # e.g. tuples
    end
    # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
end

The error disappears!

The only hypothesis I have is that materialization of a broadcasted variable changes something in the global Julia state that makes it more friendly to StructArray, but I can't find any relevant information.


(ChainRules) pkg> st
Project ChainRules v1.44.5
Status `~/work/ChainRules.jl/Project.toml`
  [79e6a3ab] Adapt v3.4.0
  [d360d2e6] ChainRulesCore v1.15.3
  [34da2185] Compat v4.2.0
  [46192b85] GPUArraysCore v0.1.2
  [92d709cd] IrrationalConstants v0.1.1
  [c1ae055f] RealDot v0.1.0
  [09ab397b] StructArrays v0.6.12
  [cd998857] Yota v0.8.0 `https://github.com/dfdx/Yota.jl.git#fix-broadcast`
  [8ba89e20] Distributed
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random
  [2f01184e] SparseArrays
  [10745b16] Statistics
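Since the added Broadcast.materialize(bc) line discards its result, it can presumably only matter through side effects such as compiling the broadcast kernel for f; the Broadcasted object itself is an immutable struct that materializing does not change. A small plain-Julia check of that last point:

```julia
xs = [1.0, 2.0, 3.0]
bc = Broadcast.instantiate(Broadcast.broadcasted(x -> (x, 2x), xs))

@assert !ismutable(bc)                 # Broadcasted is an immutable struct
args_before = bc.args
out = Broadcast.materialize(bc)        # computes the 3-element Vector of tuples
@assert out == [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
@assert bc.args === args_before && xs == [1.0, 2.0, 3.0]   # inputs untouched
```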

mcabbott (Member Author) commented Sep 4, 2022

That is pretty odd.

I can reproduce this by evaluating @eval ChainRules function unzip_broadcast(f::F, args...) where {F} ... with your definition in the REPL. What's strange is that if I then @eval the old code again (or the old code with some printout), it still works. (The signature has not changed, and the replacement code is run.) Does this mean it's some world-age problem or something?

Edit: I've pasted in a complete session below. This @eval has exactly the same code as the source, and somehow fixes the problem. Running it before grad seems to have no effect.

julia> using Flux, Yota, ChainRules

julia> ENV["JULIA_DEBUG"] = ChainRules;

julia> begin
        myid = x -> identity(x)
        model = Dense(5, 3, myid)
        x = rand(Float32, 5, 1)
       end;

julia> val, g = grad((model, x) -> sum(model(x)), model, x)
┌ Debug: broadcasting: plus
│   length(xs) = 2
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:161
┌ Debug: split broadcasting generic
│   f = #7 (generic function with 1 method)
│   N = 1
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:126
ERROR: all field arrays must have same shape
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::StructArrays.var"#6#7"{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})(ci::Vector{Function})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:21
  [3] map
    @ ./tuple.jl:273 [inlined]
  [4] (StructArrays.StructArray{Tuple{Float32, Function}, 2, Tuple{Matrix{Float32}, Vector{Function}}})(c::Tuple{Matrix{Float32}, Vector{Function}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:20
  [5] (StructArrays.StructArray{Tuple{Float32, Function}})(c::Tuple{Matrix{Float32}, Vector{Function}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:97
  [6] _widenstructarray(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, Function}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:118
  [7] widen_from_type(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, i::Int64, #unused#::Type{Tuple{Float32, Yota.var"#20#22"}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:109
  [8] widen_from_instance(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, i::Int64, el::Tuple{Float32, Yota.var"#20#22"})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:105
  [9] collect_to_structarray!(dest::StructArrays.StructArray{Tuple{Float32, Yota.var"#21#23"}, 2, Tuple{Matrix{Float32}, Matrix{Yota.var"#21#23"}}, Int64}, itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, offs::Int64, st::Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:77
 [10] _collect_structarray!
    @ ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:59 [inlined]
 [11] _collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}, elem::Tuple{Tuple{Float32, Yota.var"#21#23"}, Tuple{CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CartesianIndex{2}}}, ax::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:54
 [12] collect_structarray(itr::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/collect.jl:40
 [13] StructArrays.StructArray(v::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, ChainRules.var"#1707#1709"{Yota.YotaRuleConfig, var"#7#8"}, Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}}}}; unwrap::typeof(StructArrays.alwaysfalse))
    @ StructArrays ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:261
 [14] StructArray
    @ ~/.julia/packages/StructArrays/w2GaP/src/structarray.jl:260 [inlined]
 [15] unzip_broadcast
    @ ~/.julia/packages/ChainRules/fgVxV/src/unzipped.jl:39 [inlined]
 [16] split_bc_pullbacks(cfg::Yota.YotaRuleConfig, f::var"#7#8", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
    @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:127
 [17] rrule(cfg::Yota.YotaRuleConfig, #unused#::typeof(Base.Broadcast.broadcasted), #unused#::Base.Broadcast.DefaultArrayStyle{2}, f::var"#7#8", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})
    @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:44
 [18] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, line::Core.LineInfoNode, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/vGy3v/src/tape.jl:192
 [19] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:181
 [20] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:268
 [21] gradtape(::Function, ::Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:297
 [22] grad(::Function, ::Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:367
 [23] grad(::Function, ::Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:359
 [24] top-level scope
    @ REPL[4]:1
 [25] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

julia> @eval ChainRules function unzip_broadcast(f::F, args...) where {F}
           T = Broadcast.combine_eltypes(f, args)
           if isconcretetype(T)
               T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
                   but f = $(sprint(show, f)) returns type T = $T"""))
           end
           bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
           bcs = Broadcast.BroadcastStyle(typeof(bc))
           if bcs isa AbstractGPUArrayStyle
               # This is a crude way to allow GPU arrays, not currently tested, TODO.
               # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
               return unzip(broadcast(f, args...))
           elseif bcs isa Broadcast.AbstractArrayStyle
               # Broadcast.materialize(bc)   # <-- this line added  # <-- now removed, identical to original
               return StructArrays.components(StructArray(bc))
           else
               return unzip(broadcast(f, args...))  # e.g. tuples
           end
           # TODO maybe this if-else can be replaced by methods of `unzip(::Broadcast.Broadcasted)`?
       end
unzip_broadcast (generic function with 1 method)

julia> val, g = grad((model, x) -> sum(model(x)), model, x)
┌ Debug: broadcasting: plus
│   length(xs) = 2
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:161
┌ Debug: split broadcasting generic
│   f = #7 (generic function with 1 method)
│   N = 1
└ @ ChainRules ~/.julia/packages/ChainRules/fgVxV/src/rulesets/Base/broadcast.jl:126
(2.6951299f0, (ChainRulesCore.ZeroTangent(), Tangent{Dense{var"#7#8", Matrix{Float32}, Vector{Float32}}}= ChainRulesCore.ZeroTangent(), weight = Float32[0.88211715 0.71158904  0.74754727 0.49648; 0.88211715 0.71158904  0.74754727 0.49648; 0.88211715 0.71158904  0.74754727 0.49648], bias = Float32[1.0, 1.0, 1.0]), Float32[1.0559639; 1.8083295;  ; 0.78016365; -0.7226729;;]))

@dfdx

dfdx commented Sep 4, 2022

Here's a hypothesis for the world age problem:

  • bc contains a reference to a lazy function definition
  • materialize(bc) triggers the definition and adds a new method to an existing method table
  • StructArray(bc) with and without a prior call to materialize(bc) thus goes down different dispatch paths and hits different versions of the same function

But:

  • if I put Base.get_world_counter() before and after materialize(bc), I see the same world number
  • I don't see a function to which a new method is added, but only completely new functions that shouldn't change the dispatch path
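For readers less familiar with world age, the rule this hypothesis relies on can be shown with Base alone. A method created by `eval` is invisible to code that is already running in an older world, which is why the AD-free reproducer later in this thread wraps its calls in `Base.invokelatest`. This is only a minimal sketch; `define_and_call` is a hypothetical name.

```julia
# Minimal sketch of Julia's world-age rule (Base only; names are hypothetical).
function define_and_call()
    name = gensym(:generated)
    # Define a brand-new generic function at runtime, similar in spirit to generating an rrule.
    f = Base.eval(@__MODULE__, :($name(x) = x + 1))
    direct = try
        f(1)  # the caller still runs in the world that existed when it was called
    catch err
        err isa MethodError || rethrow()
        :method_too_new
    end
    latest = Base.invokelatest(f, 1)  # dispatches in the newest world, so the method is visible
    return direct, latest
end

define_and_call()  # (:method_too_new, 2)
```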

@mcabbott
Member Author

mcabbott commented Sep 4, 2022

How are you adding this materialize line? By editing the source used for a fresh session, or by loading something while running?

@dfdx

dfdx commented Sep 4, 2022

I have a file called _main.jl inside Yota/src directory with contents similar to this:

include("core.jl")       # in its turn, core.jl includes all the files from Yota, so now Main ~ Yota

using Flux

# I think these imports are not needed anymore, but I'm just copy-pasting them
import ChainRules: unzip_broadcast, RCR, TRI_NO, AbstractGPUArrayStyle, StructArrays
import ChainRules.StructArrays: StructArray


@eval ChainRules function unzip_broadcast(f::F, args...) where {F}
    global BC_STATE = (f, args)
    T = Broadcast.combine_eltypes(f, args)
    if isconcretetype(T)
        T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    # bc - rrule_via_ad's wrapper broadcasted to all arguments
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    bcs = Broadcast.BroadcastStyle(typeof(bc))
    if bcs isa AbstractGPUArrayStyle
        # This is a crude way to allow GPU arrays, not currently tested, TODO.
        # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
        return unzip(broadcast(f, args...))
    elseif bcs isa Broadcast.AbstractArrayStyle
        println("World age before materialize(bc): $(Base.get_world_counter())")
        # Broadcast.materialize(bc)
        println("World age after materialize(bc): $(Base.get_world_counter())")
        # global BC = bc
        return StructArrays.components(StructArray(bc))
    else
        return unzip(broadcast(f, args...))  # e.g. tuples
    end
    # TODO maybe this if-else can be replaced by methods of `unzip(::Broadcast.Broadcasted)`?
end


function bc_test()
    myid = x -> identity(x)
    model = Dense(5, 3, myid)
    x = rand(Float32, 5, 1);
    grad((model, x) -> sum(model(x)), model, x)
end

Whenever I make a change, I include the whole file, thus updating all definitions from Yota as well as ChainRules.unzip_broadcast.


I also noticed that the problem is fixed if I replace rrule_via_ad() with a dummy implementation that doesn't generate new functions:

function ChainRulesCore.rrule_via_ad(cfg::YotaRuleConfig, f, args...)
    return 1.0, dy -> (ZeroTangent(), [ZeroTangent() for _ in args]...)
end

In theory, I could make rrule_via_ad() always work as an interpreter and never compile code, but that would come at a high price...

@mcabbott
Member Author

mcabbott commented Sep 4, 2022

Focusing on these lines:

  [4] (StructArrays.StructArray{Tuple{Float32, Function}, 2, Tuple{Matrix{Float32}, Vector{Function}}})(c::Tuple{Matrix{Float32}, Vector{Function}})

 [16] split_bc_pullbacks(cfg::Yota.YotaRuleConfig, f::var"#7#8", args::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float32}, Vector{Float32}}})

here's a smaller reproducer:

julia> using ChainRules, Yota

julia> y, bk = ChainRules.split_bc_pullbacks(Yota.YotaRuleConfig(), identity, Broadcast.broadcasted(+, [1 2; 3 4], [5, 6]));

julia> bk([7 8; 9 0])  # with identity it works fine, also sqrt
(ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), [7 8; 9 0])

julia> y, bk = ChainRules.split_bc_pullbacks(Yota.YotaRuleConfig(), x -> identity(x), Broadcast.broadcasted(+, [1 2; 3 4], [5, 6]));
ERROR: all field arrays must have same shape

(@v1.9) pkg> st Yota ChainRules
Status `~/.julia/environments/v1.9/Project.toml`
  [082447d4] ChainRules v1.44.5
  [cd998857] Yota v0.8.0

This seems to avoid my order-of-loading weirdness above. If I @eval your method with materialize, then it starts working. If I @eval the old one without it again, it stops working.

@mcabbott
Member Author

mcabbott commented Sep 4, 2022

Some half-way steps:

julia> using ChainRules, Yota

# Easy case

julia> broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
         ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), identity, x)
       end
2×2 Matrix{Tuple{Int64, ChainRules.var"#identity_pullback#1201"}}:
 (6, identity_pullback)  (7, identity_pullback)
 (9, identity_pullback)  (10, identity_pullback)

julia> ChainRules.unzip(ans)
([6 7; 9 10], [ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"(); ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"()])

julia> broadcast(|>, [7 8; 9 0], ans[2])
2×2 Matrix{Tuple{ChainRulesCore.NoTangent, Int64}}:
 (NoTangent(), 7)  (NoTangent(), 8)
 (NoTangent(), 9)  (NoTangent(), 0)

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
         ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), identity, x)
       end
([6 7; 9 10], [ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"(); ChainRules.var"#identity_pullback#1201"() ChainRules.var"#identity_pullback#1201"()])

julia> broadcast(|>, [7 8; 9 0], ans[2])
2×2 Matrix{Tuple{ChainRulesCore.NoTangent, Int64}}:
 (NoTangent(), 7)  (NoTangent(), 8)
 (NoTangent(), 9)  (NoTangent(), 0)

# Now try with y -> identity(y)

julia> broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
         ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), y -> identity(y), x)
       end
2×2 Matrix{Tuple{Int64, Function}}:  ## <-- notice Function, abstract type
 (6, #21)  (7, #20)
 (9, #20)  (10, #20)

julia> ChainRules.unzip(ans)  ## notice Core.Box
([6 7; 9 10], Function[Yota.var"#21#23"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))); Yota.var"#20#22"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_#72#328#86"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"())))])

julia> broadcast(|>, [7 8; 9 0], ans[2])
2×2 Matrix{Tuple{ChainRulesCore.ZeroTangent, Int64}}:
 (ZeroTangent(), 7)  (ZeroTangent(), 8)
 (ZeroTangent(), 9)  (ZeroTangent(), 0)

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
         ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), y -> identity(y), x)
       end
ERROR: all field arrays must have same shape

# Name the function:

julia> myid(x) = x;

julia> broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
         ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), myid, x)
       end
2×2 Matrix{Tuple{Int64, Function}}:  ## <-- looks as bad
 (6, #21)  (7, #20)
 (9, #20)  (10, #20)

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
         ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), myid, x)   ## now this works, with Core.Box
       end
([6 7; 9 10], Yota.var"#20#22"[Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))); Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"()))) Yota.var"#20#22"(Core.Box(Yota.var"##pullback_myid#334#89"{ChainRules.var"#identity_pullback#1201"}(ChainRules.var"#identity_pullback#1201"())))])
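The abstract `Function` element type flagged above is ordinary broadcast behaviour whenever the broadcast function returns closures of more than one concrete type: Base starts from the first element's type and widens the element type (via `promote_typejoin`) as soon as a different type appears. A minimal Base-only sketch, where `make_pullback` is a hypothetical stand-in for a pullback factory:

```julia
# Hypothetical pullback factory: the two branches produce two distinct closure types.
make_pullback(x) = isodd(x) ? (dy -> dy + x) : (dy -> dy - x)

pbs = broadcast(make_pullback, [1 2; 3 4])
eltype(pbs)       # Function, the common abstract supertype of both closure types
pbs[1, 1](10)     # 11, because x = 1 takes the `dy + x` branch
pbs[1, 2](10)     # 8, because x = 2 takes the `dy - x` branch
```

Any container that collects such results, StructArrays included, therefore has to cope with the element type changing partway through.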

@dfdx

dfdx commented Sep 4, 2022

Apparently, in the last example there's no error because the rrule() for myid had already been compiled in the previous broadcast() call. Running only the last statement gives the same error:

julia> ChainRules.unzip_broadcast(Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])) do x
        ChainRules.rrule_via_ad(Yota.YotaRuleConfig(), myid, x)
end
...
ERROR: all field arrays must have same shape
...

@dfdx

dfdx commented Sep 5, 2022

My current understanding is as follows:

  • Yota.rrule_via_ad() generates a new rrule() and shifts the world age forward
  • unzip_broadcast() doesn't immediately call rrule_via_ad(), but instead creates a lazy broadcasting object bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)), where f is the wrapper created here:
unzip_broadcast(args...) do a...
    rrule_via_ad(cfg, f, a...)
end
  • this broadcasted object is actually materialized only when StructArray(bc) is called
  • when rrule_via_ad() itself acts on broadcasted objects, the combination of double broadcasting, new function generation and something in the StructArray constructor leads to incorrect results

Removing any of these factors solves the problem. Also, if in unzip_broadcast() I replace:

bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))

with this:

bc = broadcast(f, args...)

it forces earlier evaluation and fixes the issue too. A few lines later, we materialize bc in all 3 branches anyway:

        if bcs isa AbstractGPUArrayStyle
            # This is a crude way to allow GPU arrays, not currently tested, TODO.
            # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
            return unzip(broadcast(f, args...))
        elseif bcs isa Broadcast.AbstractArrayStyle        
            return StructArrays.components(StructArray(bc))
        else
            return unzip(broadcast(f, args...))  # e.g. tuples
        end

I wonder why we need bc to be lazily broadcast here at all?
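The lazy/eager distinction discussed here can be seen with Base alone. In this sketch `counted_inc` and `CALLS` are hypothetical names; a side-effect counter shows that building and instantiating a `Broadcasted` object evaluates nothing, and that every call happens inside `Broadcast.materialize` (or immediately, with `broadcast`).

```julia
# Hypothetical names; the counter makes every evaluation of `counted_inc` visible.
const CALLS = Ref(0)
counted_inc(x) = (CALLS[] += 1; x + 1)

function lazy_vs_eager()
    CALLS[] = 0
    bc = Broadcast.instantiate(Broadcast.broadcasted(counted_inc, [1 2; 3 4]))
    after_build = CALLS[]                    # 0: building `bc` evaluates nothing
    A = Broadcast.materialize(bc)            # all 4 calls happen here
    after_materialize = CALLS[]
    B = broadcast(counted_inc, [1 2; 3 4])   # eager: 4 more calls right away
    return after_build, after_materialize, CALLS[], A == B
end

lazy_vs_eager()   # (0, 4, 8, true)
```

With a lazy `bc`, anything that runs inside `f` (such as generating new methods) is deferred until whichever consumer finally iterates over it.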

@mcabbott
Member Author

mcabbott commented Sep 5, 2022

in all 3 branches we materialize bc anyway:

But one of them materialises two arrays directly, instead of first allocating an array of tuples. This path is the entire reason for this function, and for depending on StructArrays.
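To illustrate the difference described here, a Base-only sketch contrasting the two strategies. Both `unzip_naive` and `unzip_direct` are hypothetical helpers, not ChainRules' `unzip` or the StructArrays code path, and `unzip_direct` assumes every element has the same component types as the first one, so unlike StructArrays it never widens.

```julia
# Hypothetical helpers for illustration; not ChainRules' `unzip` or StructArrays internals.

# Naive: the caller materialises an array of tuples first, which is then split in two.
unzip_naive(A::AbstractArray{<:Tuple{Any,Any}}) = (map(first, A), map(last, A))

# Direct: walk a lazy Broadcasted once and write each component into its own array.
# Assumes every element has the same component types as the first (no widening).
function unzip_direct(bc::Broadcast.Broadcasted)
    bc = Broadcast.instantiate(bc)
    idx = eachindex(bc)
    a1, b1 = bc[first(idx)]            # evaluated once more in the loop below, for simplicity
    dims = map(length, axes(bc))
    A = Array{typeof(a1)}(undef, dims)
    B = Array{typeof(b1)}(undef, dims)
    for I in idx
        a, b = bc[I]
        A[I] = a
        B[I] = b
    end
    return A, B
end

bc = Broadcast.broadcasted((x, y) -> (x + y, x * y), [1 2; 3 4], [5, 6])
A, B = unzip_direct(bc)                # A == [6 7; 9 10], B == [5 10; 18 24]
unzip_naive(Broadcast.materialize(bc)) == (A, B)   # true, but allocates a tuple array first
```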

Cc @piever in case this weird error rings any bells. (I wonder if it's possible to hit it without AD being involved?)

@piever

piever commented Sep 5, 2022

Not sure if this is helpful, but here are some thoughts that could be useful.

  1. One of the reasons for the mismatches could be that for StructArrays with no fields, the size of the array is hard to preserve (it could be useful to test this with the "throw if no fields" branch, JuliaArrays/StructArrays.jl#235, though given the eval & co. issues I imagine this isn't the problem).
  2. StructArrays collection mechanism only uses inference when the iterator is empty.
  3. There is a PR to improve broadcast for StructArrays, "Generalize StructArray's broadcast" (JuliaArrays/StructArrays.jl#215); maybe it helps here (ideally, it would unify the GPU and CPU implementations of unzip_broadcast).

I am definitely puzzled as to why this is happening, though. It looks like the collection mechanism StructArray(bc) is failing on a lazy broadcasted object (indeed, this is a different code path from StructArray(::AbstractArray)). An "AD-free" reproducer would definitely help narrow this down.
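Base's own `collect` behaves in the way point 2 describes: element types come from the values actually produced, and inference is only consulted when there are no values at all. A quick Base-only check (an illustration of Base behaviour, not of StructArrays' implementation):

```julia
# Empty input: there are no values to inspect, so the element type comes from inference.
empty_result = [i for i in 1:0]
eltype(empty_result)                         # Int (Int64 on 64-bit systems)

# Non-empty input: start from the first value's type and widen when a new type appears.
mixed = [isodd(x) ? x : 1.0 for x in [1 2; 3 4]]
eltype(mixed)                                # Real, the join of Int64 and Float64
size(mixed)                                  # (2, 2): widening keeps the shape
```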

@dfdx

dfdx commented Sep 5, 2022

Here's a reproducible example without Yota and ChainRules:

import StructArrays
import StructArrays: StructArray

# eval a new function similar to rrule()
function make_rrule(f, args...)
    name = gensym()
    ex = :(function $name(f, args...)
        y = sum(args)
        pullback(dy) = dy + y
        return y, pullback
    end)
    return Base.eval(@__MODULE__, ex)
end

# wrap rrule-like function with required number of invokelatest()
function rrule_via_ad(f, args...)
    rr = make_rrule(f, args...)
    res = Base.invokelatest(rr, f, args...)
    y, pb_ = res
    pb = dy -> Base.invokelatest(pb_, dy)
    return y, pb
end

# original split_bc_pullbacks stripped to the bones
function split_bc_pullbacks(f::F, args::Vararg{Any,N}) where {F,N}
    wf = (a...) -> rrule_via_ad(f, a...)
    # comment/uncomment the next 2 lines to make the example fail/work
    bc = Broadcast.instantiate(Broadcast.broadcasted(wf, args...))
    # bc = broadcast(wf, args...)
    return StructArrays.components(StructArray(bc))
end

bce() = Broadcast.broadcasted(+, [1 2; 3 4], [5, 6])
split_bc_pullbacks(x -> identity(x), bce())

Since make_rrule() generates a new function on every call, this code can be run multiple times in the same REPL, and it fails each time. I think the pullback stuff can also be removed, but I want to try the aforementioned branches from StructArrays first.

@mcabbott
Member Author

mcabbott commented Sep 5, 2022

Still happens on the PR's branch. You can simplify a bit further; note that acting on a vector is OK, but higher ndims fail:

julia> [4,5,6] .|> split_bc_pullbacks(x -> identity(x), [1,2,3])[2]
3-element Vector{Int64}:
 5
 7
 9

julia> split_bc_pullbacks(x -> identity(x), [1 2; 3 4])[2]
ERROR: all field arrays must have same shape

Trying to pick bits out of the stack trace, is this correct?

julia> mat = [(; i, f) for i in 1:3, f in (sin, sin)] |> StructArray
3×2 StructArray(::Matrix{Int64}, ::Matrix{typeof(sin)}) with eltype NamedTuple{(:i, :f), Tuple{Int64, typeof(sin)}}:
 (i = 1, f = sin)  (i = 1, f = sin)
 (i = 2, f = sin)  (i = 2, f = sin)
 (i = 3, f = sin)  (i = 3, f = sin)

julia> StructArrays._widenstructarray(mat, 2, Tuple{Int, Function})
6-element StructArray(::Vector{Int64}, ::Vector{Function}) with eltype Tuple{Int64, Function}:
    (1, sin)
 #undef
 #undef
 #undef
 #undef
 #undef

@piever

piever commented Sep 6, 2022

Trying to pick bits out of the stack trace, is this correct?

julia> mat = [(; i, f) for i in 1:3, f in (sin, sin)] |> StructArray
3×2 StructArray(::Matrix{Int64}, ::Matrix{typeof(sin)}) with eltype NamedTuple{(:i, :f), Tuple{Int64, typeof(sin)}}:
 (i = 1, f = sin)  (i = 1, f = sin)
 (i = 2, f = sin)  (i = 2, f = sin)
 (i = 3, f = sin)  (i = 3, f = sin)

julia> StructArrays._widenstructarray(mat, 2, Tuple{Int, Function})
6-element StructArray(::Vector{Int64}, ::Vector{Function}) with eltype Tuple{Int64, Function}:
    (1, sin)
 #undef
 #undef
 #undef
 #undef
 #undef

Agh, no it isn't, well spotted! Somehow the widening mechanism was not updated to support arrays of arbitrary shape and only worked for 1D things... JuliaArrays/StructArrays.jl#246 should hopefully fix it!
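The actual fix lives in StructArrays, but the principle is easy to show in Base: when collection has to switch to a wider element type, the replacement buffer should be allocated with the destination's own axes, not with its length. `widen_keep_shape` is a hypothetical helper for illustration, not StructArrays' `_widenstructarray`.

```julia
# Hypothetical helper: swap to a wider element type partway through collecting into `dest`.
function widen_keep_shape(dest::AbstractArray, ::Type{T}, n_filled::Integer) where {T}
    wider = similar(dest, T)              # same axes as `dest`: a 3×2 input stays 3×2
    copyto!(wider, 1, dest, 1, n_filled)  # copy the elements written so far, in linear order
    return wider
end

dest = Matrix{Int}(undef, 3, 2)
dest[1] = 1                               # one element collected before a wider type showed up
wider = widen_keep_shape(dest, Real, 1)
size(wider)                               # (3, 2)

# Allocating by length instead flattens the array, like the 6-element result above:
size(similar(dest, Real, length(dest)))   # (6,)
```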

@dfdx

dfdx commented Sep 8, 2022

I can confirm JuliaArrays/StructArrays.jl#246 fixes all the issues up to my first reproducer using Flux and Yota. Thanks for the quick fix!

The Metalhead example still fails, though that's another story, which I'm looking at now.


loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x))
end;
Member Author

I rebased this and tests pass!

This example does not; it fails with a seemingly simple error:

julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
          sum(m(x))
        end;


loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model)

ERROR: No derivative rule found for op %454 = ntuple(%452, 4)::NTuple{4, Int64} , try defining it using

	ChainRulesCore.rrule(::typeof(ntuple), ::Flux.var"#336#337"{4, Array{Float32, 4}}, ::Int64) = ...

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/KJQ6n/src/grad.jl:219

That was on tagged Yota; on the latest versions of everything it instead seems to take forever, and when interrupted it stops here:

julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
          sum(m(x))
        end;

^CERROR: InterruptException:
Stacktrace:
   [1] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
     @ Base ./array.jl:792
   [2] todo_list(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
     @ Yota ~/.julia/packages/Yota/5CVY7/src/grad.jl:113
   [3] #68
     @ ./none:0 [inlined]
   [4] iterate
     @ ./generator.jl:47 [inlined]
   [5] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
     @ Base ./array.jl:787
   [6] todo_list(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
     @ Yota ~/.julia/packages/Yota/5CVY7/src/grad.jl:113
   [7] #68
     @ ./array.jl:0 [inlined]
   [8] iterate
     @ ./generator.jl:47 [inlined]
   [9] collect_to!(dest::Vector{Vector{Umlaut.Variable}}, itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}}, offs::Int64, st::Int64)
     @ Base ./array.jl:845
  [10] collect_to_with_first!(dest::Vector{Vector{Umlaut.Variable}}, v1::Vector{Umlaut.Variable}, itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}}, st::Int64)
     @ Base ./array.jl:823
  [11] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
     @ Base ./array.jl:797
--- the last 10 lines are repeated 2 more times ---

(jl_aZPcXz) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_aZPcXz/Project.toml`
  [dbeba491] Metalhead v0.8.0-DEV `https://github.com/FluxML/Metalhead.jl.git#master`
  [3bd65402] Optimisers v0.2.10 `~/.julia/dev/Optimisers`
  [09ab397b] StructArrays v0.6.13 `https://github.com/JuliaArrays/StructArrays.jl.git#master`
  [cd998857] Yota v0.8.1 `https://github.com/dfdx/Yota.jl.git#main`


Hmm, I was indeed investigating the incredibly long processing time, but the profiler blamed type inference / the abstract interpreter, so I started a long search for a better way to trace functions (e.g. see my recent post on Discourse). However, your stacktrace implies the problem may actually appear after tracing. I will try to investigate this possibility too, closer to the end of the week.


FYI: I opened an issue to track this.


Fixed. The ResNet(18) example now compiles and runs in 61 seconds (compared to 47 seconds with Zygote). Subsequent calls take ~0.4 seconds on my CPU.

Member Author

Great, I see something similar locally on 0.8.2.

@mcabbott mcabbott marked this pull request as ready for review October 31, 2022 00:27
@ToucheSir ToucheSir closed this Nov 3, 2022
@ToucheSir ToucheSir reopened this Nov 3, 2022
@ToucheSir
Member

Are the failures on nightly easy to resolve?

@dfdx

dfdx commented Nov 3, 2022

It's a failure in CompilerPluginTools.jl, which apparently has not been adapted for Julia 1.9 yet. I opened JuliaCompilerPlugins/CompilerPluginTools.jl#8 to track it.

@mcabbott
Member Author

Should we just skip tests on nightly, so that this can go in?

@dfdx do you know whether 1.9 works?

@dfdx

dfdx commented Nov 28, 2022

It looks like there's more work to do in CompilerPluginTools.jl to make it work on Julia 1.9, so I don't think it will happen in the near future. Skipping Yota tests on Julia 1.9 seems like the most efficient solution for now.

Note that Julia nightly now points to Julia 1.10, so perhaps we need a separate entry for 1.9.

@mcabbott
Member Author

mcabbott commented Dec 8, 2022

Tests with Yota are now skipped for 1.9 & up.

Should be ready to go. Can someone approve?
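One way to express such a version gate in a test file is sketched below. This is an assumption about the approach, not the code merged in this PR, and the helper name `run_yota_tests` is hypothetical. The bound `v"1.9-"` sorts below every 1.9 prerelease, so 1.9 DEV builds and release candidates are excluded along with 1.9 and later (including the 1.10 nightly).

```julia
# Hypothetical helper; the PR's actual condition may differ.
# v"1.9-" is below every 1.9 prerelease (1.9.0-DEV, -alpha, -rc...), so this excludes all of 1.9+.
run_yota_tests(v::VersionNumber = VERSION) = v < v"1.9-"

if run_yota_tests()
    # e.g. `using Yota` and the Yota-based test files would be loaded here
    @info "Running Yota tests on Julia $VERSION"
else
    @info "Skipping Yota tests on Julia $VERSION"
end
```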

@mcabbott mcabbott merged commit 79269be into FluxML:master Dec 8, 2022
@mcabbott mcabbott deleted the yota branch December 8, 2022 03:44
Development

Successfully merging this pull request may close these issues.

Investigate using a different AD for tests
4 participants