
Commit d9637c6

Give Apollo its own eta for adjust, and use sqrt(#params) for GradNormGrowthLimiter
1 parent 6aa32c1

File tree: 1 file changed (+19, -15 lines)


src/rules.jl (+19, -15)
@@ -603,10 +603,10 @@ end
 """
     GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true)

-Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270).
-With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary.
-This also introduces `m` a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero.
-This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`).
+Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
+with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum that the gradient norm can grow
+from one step to the next. This implementation also introduces `m`, a hard minimum on the gradient-norm threshold, below which gradients are never rescaled, preventing
+a tensor from getting "trapped" near zero. This can be a fixed minimum, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`).
 """
 struct GradNormGrowthLimiter <: AbstractRule
     γ::Float64
@@ -630,7 +630,7 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where
     else
         #If you're below the hard min, then don't scale
         if o.paramscale_min
-            minthresh = o.m * length(dx)
+            minthresh = o.m * sqrt(length(dx))
         else
             minthresh = o.m
         end
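
As a rough per-tensor sketch of the behaviour these two hunks describe (illustrative only, not the package's exact apply!; prev_norm stands in for the norm kept in the rule's state):

using LinearAlgebra

function limit_growth(dx, prev_norm; γ = 1.1, m = 1e-3, paramscale_min = true)
    n = norm(dx)
    # Hard minimum: gradients at or below the threshold pass through
    # unscaled, so a tensor cannot get "trapped" near zero. With
    # paramscale_min the threshold grows with sqrt(#params), matching
    # how the norm of an i.i.d. gradient scales with tensor size.
    minthresh = paramscale_min ? m * sqrt(length(dx)) : m
    n <= minthresh && return dx
    # Otherwise cap the norm at γ times the previous step's norm.
    return n > γ * prev_norm ? dx .* (γ * prev_norm / n) : dx
end
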
@@ -659,19 +659,20 @@ Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks mome
 First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function
 to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
 """
-struct Apollo{T1} <: AbstractRule
+struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule
     opt::T1
-    r::Function #Maps non-first dims to rank
-    u::Int #Subspace update frequency (T in paper)
-    sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first
+    eta::T2
+    r::T3 #Maps non-first dims to rank
+    u::T4 #Subspace update frequency (T in paper)
+    sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
 end


-Apollo() = Apollo(AdamW(0.001), dim -> ceil(Int, sqrt(dim)), 100, true)
-Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims)
-Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), rank_function, u, sort_dims)
-Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims)
-Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, rank_function, u, sort_dims)
+Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
+Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims)
+Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims)
+Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims)
+Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims)

 #Use the base init and apply for 1D arrays
 init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
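
The reason for the separate eta field: Optimisers.jl's generic adjust works by rewriting matching hyperparameter fields on each rule, and until this commit Apollo's learning rate lived only inside the wrapped AdamW, out of adjust's reach. A minimal usage sketch, assuming the standard setup/adjust workflow (the array and numbers are illustrative):

using Optimisers

W = randn(Float32, 128, 64)   # a weight matrix; 1D arrays fall back to plain AdamW
st = Optimisers.setup(Apollo(1e-3, 32), W)

# With `eta` stored on Apollo itself, changing the learning rate now works:
st = Optimisers.adjust(st, 1e-4)
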
@@ -706,7 +707,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
         swapped = true
     end
     (mt, vt, βt), t, P = state
-    η = T(o.opt.eta)
+    η = T(o.eta) #This is what will get modified by adjust
     λ = T(o.opt.lambda)
     β = T.(o.opt.beta)
     ϵ = T(o.opt.epsilon)
@@ -728,6 +729,9 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
         return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size)
     end

+#Notes: chuck the AdamW from the struct, so that adjust will just work.
+
+

 """
     WeightDecay(λ = 5e-4)
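
The note in the last hunk points at the follow-up: if AdamW's remaining hyperparameters moved onto Apollo directly, adjust would reach all of them (beta, lambda, epsilon) rather than just eta. A hypothetical shape for such a struct, not part of this commit:

# Hypothetical flattened Apollo, sketched from the note above; field names
# follow AdamW's so the default `adjust` keyword path would match them.
struct Apollo <: AbstractRule
    eta::Float64
    beta::Tuple{Float64, Float64}
    lambda::Float64
    epsilon::Float64
    r::Function      # maps non-first dims to rank
    u::Int           # subspace update frequency (T in paper)
    sort_dims::Bool  # whether to swap dims when the second is smaller
end
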
