diff --git a/src/rules.jl b/src/rules.jl
index 05d52af..47c2d8c 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -531,14 +531,14 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
 struct AdamW{T1,T2,T3,T4} <: AbstractRule
   eta::T1
   beta::T2
-  epsilon::T3
   lambda::T4
+  epsilon::T3
   couple::Bool
 end
 
 function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  AdamW(η, β, ϵ, λ, couple)
+  AdamW(η, β, λ, ϵ, couple)
 end
 
 AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda= 0.0, epsilon = 1e-8, kw...) =