Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace @adjoint with rrule #1863

Merged
merged 4 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.13.0-DEV"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Adapt = "3.0"
ArrayInterface = "3.1, 4"
CUDA = "3"
ChainRulesCore = "1.12"
Functors = "0.2.1"
MLUtils = "0.1.4"
MacroTools = "0.5"
Expand All @@ -35,7 +37,7 @@ ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Zygote = "0.6"
Zygote = "0.6.34"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using MLUtils

using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient
using ChainRulesCore

export Chain, Dense, Maxout, SkipConnection, Parallel,
RNN, LSTM, GRU, GRUv3,
Expand Down
5 changes: 3 additions & 2 deletions src/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module CUDAint
using ..CUDA

import ..Flux: Flux
import Zygote
using Zygote: @adjoint
# import Zygote
# using Zygote: @adjoint
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
using ChainRulesCore
import NNlib, NNlibCUDA

include("cudnn.jl")
Expand Down
5 changes: 3 additions & 2 deletions src/cuda/cudnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
training=Flux._isactive(BN)))
end

@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
function batchnorm_pullback(Δ)
∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
(NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
end
y, batchnorm_pullback
end
7 changes: 7 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,

# v0.13 deprecations

function Broadcast.broadcasted(f::Recur, args...)
# This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
Re-writing this as a comprehension would be better.""", :broadcasted)
map(f, args...) # map isn't really safe either, but
end

@deprecate frequencies(xs) group_counts(xs)

# Channel notation: Changed to match Conv, but very softly deprecated!
Expand Down
10 changes: 5 additions & 5 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x

Zygote.@adjoint function Array(x::CUDA.CuArray)
Array(x), d -> (CUDA.cu(d),)
function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
Array(x), d -> (NoTangent(), CUDA.cu(d),)
end

Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),)
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
end

# CPU/GPU movement conveniences
Expand Down Expand Up @@ -204,7 +204,7 @@ function check_use_cuda()
end
end
end
Zygote.@nograd check_use_cuda
ChainRulesCore.@non_differentiable check_use_cuda()

# Precision

Expand Down
3 changes: 1 addition & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
)
end

# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
@nograd conv_transpose_dims
ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
Expand Down
13 changes: 4 additions & 9 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
istraining() = false

@adjoint istraining() = true, _ -> nothing
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised there isn't an equivalent for this in ChainRules already.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhere I was writing a function like CRC.order().back > 0... would be good to have.


_isactive(m) = isnothing(m.active) ? istraining() : m.active

Expand Down Expand Up @@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
end
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
active || return x, Δ -> (Δ, nothing)
y = dropout_mask(rng, x, p, dims=dims)
return x .* y, Δ -> (nothing, Δ .* y, nothing)
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
Expand All @@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
end

# TODO move this to NNlib
Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

"""
Dropout(p; dims=:, rng = rng_from_array())
Expand Down Expand Up @@ -234,7 +228,8 @@ function _track_stats!(
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
return nothing
end
Zygote.@nograd _track_stats!

ChainRulesCore.@non_differentiable _track_stats!(::Any...)

"""
BatchNorm(channels::Integer, λ=identity;
Expand Down
16 changes: 6 additions & 10 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
# AD-friendly helper for dividing monolithic RNN params into equally sized gates
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)

@adjoint function multigate(x::AbstractArray, h, c)
function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
function multigate_pullback(dy)
dx = Zygote._zero(x, eltype(x))
map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ));
dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
dyᵢ isa AbstractZero && return
@. dxᵢ += dyᵢ
end
return (dx, nothing, nothing)
return (NoTangent(), dx, NoTangent(), NoTangent())
end
return multigate(x, h, c), multigate_pullback
end
Expand Down Expand Up @@ -379,8 +380,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)


@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
1 change: 1 addition & 0 deletions src/losses/Losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Losses
using Statistics
using Zygote
using Zygote: @adjoint
using ChainRulesCore
using ..Flux: ofeltype, epseltype
using CUDA
using NNlib: logsoftmax, logσ
Expand Down
8 changes: 4 additions & 4 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

@adjoint function ctc_loss(ŷ, y)
out = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing)
return out.loss, ctc_loss_pullback
function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
tmp = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
return tmp.loss, ctc_loss_pullback
end


Expand Down
5 changes: 4 additions & 1 deletion src/losses/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ end
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
end

ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
ChainRulesCore.@scalar_rule xlogx(x) (log(y) + true)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't literally translate broadcasted(::typeof(xlogy) rule to a Zygote-free world, as unbroadcast (which sums as necessary for mismatched shapes) belongs to Zygote.

I hope that Diffractor's broadcasting will work via @scalar_rule. But the rule as written is slightly different, as it doesn't treat Δ==0 as a strong zero, when y==0. Does that matter?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flux could switch to those. It has branches not ifelse, and different NaN behaviour, not sure if that matters:

https://github.com/JuliaStats/LogExpFunctions.jl/blob/584442d9bd4c4abadfb5daed86cefa5fabfff645/src/basicfuns.jl#L17-L30

And 5 dependencies.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But for now perhaps it's evidence that the scalar rules are ok?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you looking to do some testing soon with this and Diffractor/not Zygote? Otherwise I think it would be cleaner to have a separate PR that removes all of the code above in favour of https://github.com/FluxML/Zygote.jl/blob/master/src/lib/logexpfunctions.jl and the @scalar_rules in LogExpFunctions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can remove these rules for now if you prefer. The functions ought to be differentiable without special rules, mostly. The PR just wants to translate as many things as possible over for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I said:

as unbroadcast (which sums as necessary for mismatched shapes)

This is wrong, because _check_sizes demands equal size, simplifying the broadcast:

https://github.com/FluxML/Flux.jl/blob/master/src/losses/utils.jl#L27

While I guess these broadcasts aren't so performance-sensitive (since there will only be one, for the whole model) it would be nice if all loss functions were all second-differentiable. Whether that already works, or needs to be done by fiddling with broadcasting, or rules for the loss functions themselves, I don't know.


# This can be made an error in Flux v0.13, for now just a warning
function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
for d in 1:max(ndims(ŷ), ndims(y))
Expand All @@ -33,4 +36,4 @@ function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1

Zygote.@nograd _check_sizes
ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
6 changes: 5 additions & 1 deletion src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ function _fast_argmax(x::OneHotLike)
end
end

@nograd OneHotArray, onecold, onehot, onehotbatch
ChainRulesCore.@non_differentiable onehot(::Any...)
ChainRulesCore.@non_differentiable onehotbatch(::Any...)
ChainRulesCore.@non_differentiable onecold(::Any...)

ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)

function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
Expand Down
7 changes: 5 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ function _restructure(m, xs)
return m̄
end

@adjoint function _restructure(m, xs)
@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
Expand Down Expand Up @@ -603,7 +603,10 @@ true
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]

@nograd modules
@nograd modules # TODO: is this correct? might fail with explicit parameters.
function ChainRulesCore.rrule(::typeof(modules), m)
modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
end

isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
Expand Down