- Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270).
- With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary.
- This also introduces `m` a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero.
- This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`).
+ Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
+ with Optimisers.jl this applies per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum factor by which the gradient norm can grow
+ from one step to the next. This implementation also introduces `m`, a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor
+ from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`).
  """
  struct GradNormGrowthLimiter <: AbstractRule
      γ::Float64
@@ -630,7 +630,7 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where
      else
          #If you're below the hard min, then don't scale
          if o.paramscale_min
-             minthresh = o.m * length(dx)
+             minthresh = o.m * sqrt(length(dx))
          else
              minthresh = o.m
          end
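A usage sketch for the rule above: it is meant to be chained per-tensor in front of the main optimiser, so that a tensor's gradient norm can grow by at most a factor of `γ` per step but is never rescaled once it falls below the `m` threshold. The keyword names `m` and `paramscale_min` are assumptions taken from the documented fields; the full constructor signature is not shown in this hunk.

```julia
using Optimisers

# Assumed constructor: γ positional, keywords named after the documented fields
# (m, paramscale_min). These names are hypothetical; the diff does not show them.
limiter = GradNormGrowthLimiter(1.1; m = 1e-4, paramscale_min = true)

# Chained upstream of the main rule. With γ = 1.1, a tensor whose gradient norm
# was 1.0 on the previous step has its new gradient rescaled to norm ≤ 1.1,
# unless that norm is already below the hard minimum.
rule  = OptimiserChain(limiter, AdamW(1e-3))
state = Optimisers.setup(rule, model)   # `model` is a placeholder for your model
```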
@@ -659,19 +659,20 @@ Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks mome
  First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function
  to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
  """
- struct Apollo{T1} <: AbstractRule
+ struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule
      opt::T1
-     r::Function #Maps non-first dims to rank
-     u::Int #Subspace update frequency (T in paper)
-     sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first
+     eta::T2
+     r::T3 #Maps non-first dims to rank
+     u::T4 #Subspace update frequency (T in paper)
+     sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
  end

- Apollo() = Apollo(AdamW(0.001), dim -> ceil(Int, sqrt(dim)), 100, true)
- Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims)
- Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), rank_function, u, sort_dims)
- Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims)
- Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, rank_function, u, sort_dims)
+ Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
+ Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims)
+ Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims)
+ Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims)
+ Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims)

  #Use the base init and apply for 1D arrays
  init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
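A quick usage sketch against the constructors above. The model and gradients are stand-in placeholders; only the 2D weight takes the low-rank Apollo path, while the 1D bias falls back to the wrapped AdamW via the `init`/`apply` fallback just above.

```julia
using Optimisers

# Any Functors-compatible container of arrays works as a "model" here.
model = (W = randn(Float32, 128, 256), b = zeros(Float32, 128))

# Fixed rank cap of 64 wrapping AdamW, refreshing the projection every u = 100 steps:
rule = Apollo(AdamW(1e-3), 64; u = 100)
# ...or compute the rank from the tensor's non-first dims:
# rule = Apollo(1e-3; rank_function = dim -> ceil(Int, sqrt(dim)))

state = Optimisers.setup(rule, model)
grads = (W = randn(Float32, 128, 256), b = zeros(Float32, 128))   # stand-in gradients
state, model = Optimisers.update!(state, model, grads)
```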
@@ -706,7 +707,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
          swapped = true
      end
      (mt, vt, βt), t, P = state
-     η = T(o.opt.eta)
+     η = T(o.eta) #This is what will get modified by adjust
      λ = T(o.opt.lambda)
      β = T.(o.opt.beta)
      ϵ = T(o.opt.epsilon)
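Keeping a top-level `eta` field on `Apollo`, rather than reading only `o.opt.eta`, is what lets `Optimisers.adjust!` retune the learning rate by field name. A small sketch, continuing the setup from the earlier example:

```julia
state = Optimisers.setup(Apollo(1e-3, 64), model)

# adjust! rebuilds each rule with the new value for its `eta` field,
# which is exactly what apply! reads via o.eta above.
Optimisers.adjust!(state, 1e-4)
# or by keyword: Optimisers.adjust!(state; eta = 1e-4)
```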
@@ -728,6 +729,9 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T