Skip to content

Commit

Permalink
Replace @adjoint with rrule (#1863)
Browse files Browse the repository at this point in the history
* replace at-adjoint with rrule

* fixup

* onecold was missing

* rm comment

Co-authored-by: Brian Chen <[email protected]>

Co-authored-by: Brian Chen <[email protected]>
  • Loading branch information
mcabbott and ToucheSir authored Feb 24, 2022
1 parent 57ef5c0 commit 525b645
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 39 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.13.0-DEV"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Adapt = "3.0"
ArrayInterface = "3.1, 4"
CUDA = "3"
ChainRulesCore = "1.12"
Functors = "0.2.1"
MLUtils = "0.1.4"
MacroTools = "0.5"
Expand All @@ -35,7 +37,7 @@ ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Zygote = "0.6"
Zygote = "0.6.34"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using MLUtils

using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient
using ChainRulesCore

export Chain, Dense, Maxout, SkipConnection, Parallel,
RNN, LSTM, GRU, GRUv3,
Expand Down
3 changes: 1 addition & 2 deletions src/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ module CUDAint
using ..CUDA

import ..Flux: Flux
import Zygote
using Zygote: @adjoint
using ChainRulesCore
import NNlib, NNlibCUDA

include("cudnn.jl")
Expand Down
5 changes: 3 additions & 2 deletions src/cuda/cudnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
training=Flux._isactive(BN)))
end

@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
function batchnorm_pullback(Δ)
∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
(NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
end
y, batchnorm_pullback
end
7 changes: 7 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,

# v0.13 deprecations

function Broadcast.broadcasted(f::Recur, args...)
# This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
Re-writing this as a comprehension would be better.""", :broadcasted)
map(f, args...) # map isn't really safe either, but
end

@deprecate frequencies(xs) group_counts(xs)

# Channel notation: Changed to match Conv, but very softly deprecated!
Expand Down
10 changes: 5 additions & 5 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x

Zygote.@adjoint function Array(x::CUDA.CuArray)
Array(x), d -> (CUDA.cu(d),)
function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
Array(x), d -> (NoTangent(), CUDA.cu(d),)
end

Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),)
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
end

# CPU/GPU movement conveniences
Expand Down Expand Up @@ -204,7 +204,7 @@ function check_use_cuda()
end
end
end
Zygote.@nograd check_use_cuda
ChainRulesCore.@non_differentiable check_use_cuda()

# Precision

Expand Down
3 changes: 1 addition & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
)
end

# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
@nograd conv_transpose_dims
ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
Expand Down
13 changes: 4 additions & 9 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
istraining() = false

@adjoint istraining() = true, _ -> nothing
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

_isactive(m) = isnothing(m.active) ? istraining() : m.active

Expand Down Expand Up @@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
end
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
active || return x, Δ -> (Δ, nothing)
y = dropout_mask(rng, x, p, dims=dims)
return x .* y, Δ -> (nothing, Δ .* y, nothing)
end

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
Expand All @@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
end

# TODO move this to NNlib
Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

"""
Dropout(p; dims=:, rng = rng_from_array())
Expand Down Expand Up @@ -234,7 +228,8 @@ function _track_stats!(
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
return nothing
end
Zygote.@nograd _track_stats!

ChainRulesCore.@non_differentiable _track_stats!(::Any...)

"""
BatchNorm(channels::Integer, λ=identity;
Expand Down
16 changes: 6 additions & 10 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
# AD-friendly helper for dividing monolithic RNN params into equally sized gates
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)

@adjoint function multigate(x::AbstractArray, h, c)
function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
function multigate_pullback(dy)
dx = Zygote._zero(x, eltype(x))
map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ));
dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
dyᵢ isa AbstractZero && return
@. dxᵢ += dyᵢ
end
return (dx, nothing, nothing)
return (NoTangent(), dx, NoTangent(), NoTangent())
end
return multigate(x, h, c), multigate_pullback
end
Expand Down Expand Up @@ -379,8 +380,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)


@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
1 change: 1 addition & 0 deletions src/losses/Losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Losses
using Statistics
using Zygote
using Zygote: @adjoint
using ChainRulesCore
using ..Flux: ofeltype, epseltype
using CUDA
using NNlib: logsoftmax, logσ
Expand Down
8 changes: 4 additions & 4 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

@adjoint function ctc_loss(ŷ, y)
out = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) =.* ∇ctc_loss(ŷ, y, out), nothing)
return out.loss, ctc_loss_pullback
function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
tmp = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
return tmp.loss, ctc_loss_pullback
end


Expand Down
5 changes: 4 additions & 1 deletion src/losses/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ end
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
end

ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
ChainRulesCore.@scalar_rule xlogx(x) (log(y) + true)

# This can be made an error in Flux v0.13, for now just a warning
function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
for d in 1:max(ndims(ŷ), ndims(y))
Expand All @@ -33,4 +36,4 @@ function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1

Zygote.@nograd _check_sizes
ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
6 changes: 5 additions & 1 deletion src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ function _fast_argmax(x::OneHotLike)
end
end

@nograd OneHotArray, onecold, onehot, onehotbatch
ChainRulesCore.@non_differentiable onehot(::Any...)
ChainRulesCore.@non_differentiable onehotbatch(::Any...)
ChainRulesCore.@non_differentiable onecold(::Any...)

ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)

function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
Expand Down
7 changes: 5 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ function _restructure(m, xs)
return
end

@adjoint function _restructure(m, xs)
@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
Expand Down Expand Up @@ -603,7 +603,10 @@ true
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]

@nograd modules
@nograd modules # TODO: is this correct? might fail with explicit parameters.
function ChainRulesCore.rrule(::typeof(modules), m)
modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
end

isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
Expand Down

0 comments on commit 525b645

Please sign in to comment.