
AdamW optimizer implemented incorrectly - weight decay does not incorporate learning rate #182

Closed as not planned
@BioTurboNick


In Optimisers.jl, AdamW is implemented as an OptimiserChain of Adam and WeightDecay:

Optimisers.jl/src/rules.jl, lines 510 to 514 at c2ae321:

AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))

WeightDecay here simply adds λ * x to the (already Adam-transformed) gradient, with no learning-rate scaling:

Optimisers.jl/src/rules.jl, lines 569 to 574 at c2ae321:

function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
  λ = T(o.lambda)
  dx′ = @lazy dx + λ * x
  return state, dx′
end
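
The practical consequence is easy to see: because WeightDecay is applied after Adam has already folded in η, the decay contribution to each step is λ * x, independent of the learning rate. A minimal sketch (using a zero gradient, so only the decay term moves the parameter):

using Optimisers

x = ones(4)
g = zero(x)            # zero gradient isolates the weight-decay contribution

η, λ = 0.001, 0.1      # λ chosen only for illustration
st = Optimisers.setup(Optimisers.AdamW(η, (0.9, 0.999), λ), x)
st, x′ = Optimisers.update(st, x, g)

x .- x′                # ≈ 0.1 per element, i.e. λ, not η * λ = 1e-4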

In the AdamW algorithm, and likewise in PyTorch's implementation of it, the weight-decay term must also be multiplied by the learning rate:
[Figure: AdamW update rule (Algorithm 2) from https://arxiv.org/pdf/1711.05101 — θ_t ← θ_{t−1} − η_t (α·m̂_t / (√v̂_t + ϵ) + λ·θ_{t−1}), where the schedule multiplier η_t scales both the Adam step and the weight-decay term λ·θ_{t−1}.]
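
For reference, here is a sketch of one decoupled AdamW step following Algorithm 2 of the paper (m, v, t are the usual Adam state; α is the Adam step size and η the schedule multiplier; the default values are only for illustration):

function adamw_step(θ, g, m, v, t; α = 0.001, β = (0.9, 0.999), λ = 0.01, ϵ = 1e-8, η = 1.0)
  m = β[1] .* m .+ (1 - β[1]) .* g
  v = β[2] .* v .+ (1 - β[2]) .* g .^ 2
  m̂ = m ./ (1 - β[1]^t)
  v̂ = v ./ (1 - β[2]^t)
  # The λ .* θ term sits inside η .* (...): the decay is scaled by the
  # learning-rate schedule, unlike the current OptimiserChain formulation.
  θ = θ .- η .* (α .* m̂ ./ (sqrt.(v̂) .+ ϵ) .+ λ .* θ)
  return θ, m, v
end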

This was the source of considerable frustration for me: a model I have been porting from PyTorch misbehaved badly, and the problem traced back to this discrepancy.

The following optimiser produces the correct behavior:


using Optimisers

# Like WeightDecay, but scales the decay by the learning rate, as decoupled AdamW requires.
Optimisers.@def struct LearningWeightDecay <: Optimisers.AbstractRule
  lambda = 5e-4
  eta = 0.001
end

Optimisers.init(o::LearningWeightDecay, x::AbstractArray) = nothing

function Optimisers.apply!(o::LearningWeightDecay, state, x::AbstractArray{T}, dx) where T
  λ, η = T(o.lambda), T(o.eta)
  # Add η * λ * x (not just λ * x) to the gradient.
  dx′ = Optimisers.@lazy dx + η * λ * x
  return state, dx′
end

CorrectAdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
  Optimisers.OptimiserChain(Optimisers.Adam(η, β, ϵ), LearningWeightDecay(λ, η))
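
Usage sketch (assuming the definitions above are in scope; λ = 0.1 is again only for illustration) — with a zero gradient, the per-step shrinkage is now η * λ, matching PyTorch:

x = ones(4)
g = zero(x)

st = Optimisers.setup(CorrectAdamW(0.001, (0.9, 0.999), 0.1), x)
st, x′ = Optimisers.update(st, x, g)

x .- x′    # ≈ 0.001 * 0.1 = 1e-4 per element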
