Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 5, 2022
1 parent 5e8009c commit 923eca0
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
9 changes: 8 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,11 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,


# v0.13 deprecations
@deprecate Maxout(layers::Tuple) Maxout(layers...)
@deprecate Maxout(layers::Tuple) Maxout(layers...)

function Broadcast.broadcasted(f::Recur, args...)
# This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
Re-writing this as a comprehension would be better.""", :broadcasted)
map(f, args...) # map isn't really safe either, but
end
5 changes: 0 additions & 5 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)

# TODO move to ChainRulesCore?
@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
5 changes: 3 additions & 2 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ for mathematical details.
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
return ctc_loss(ŷ, y), ctc_loss_pullback
tmp = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
return tmp.loss, ctc_loss_pullback
end


Expand Down
2 changes: 1 addition & 1 deletion src/losses/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
end

ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # is this good enough?
ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
ChainRulesCore.@scalar_rule xlogx(x) (log(y) + true)

# This can be made an error in Flux v0.13, for now just a warning
Expand Down
6 changes: 4 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -788,8 +788,10 @@ L2 (generic function with 1 method)
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]

@nograd modules
ChainRulesCore.@non_differentiable modules(::Any) # is this correct?
@nograd modules # TODO: is this correct? might fail with explicit parameters.
function ChainRulesCore.rrule(::typeof(modules), m)
modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
end

isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
Expand Down

0 comments on commit 923eca0

Please sign in to comment.