Skip to content

Commit e39add7

Browse files
authored
Touch ups
1 parent b46c0ef commit e39add7

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

src/Optimisers.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ include("rules.jl")
2323
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
2424
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
2525
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
26-
AccumGrad
26+
AccumGrad, Apollo, GradNormGrowthLimiter
2727

2828
VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))
2929

src/rules.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_mi
621621
# Initial optimiser state for `GradNormGrowthLimiter`: a scalar zero converted to the
# element type `T` of the parameter array `x`.
function init(o::GradNormGrowthLimiter, x::AbstractArray{T}) where {T}
    return T(0)
end
622622

623623
function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where T
624-
current_norm = Optimisers._norm(dx, 2)
624+
current_norm = _norm(dx, 2)
625625
if o.throw && !isfinite(current_norm)
626626
throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
627627
end
@@ -640,7 +640,6 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where
640640
ratio = current_norm / (state + o.ϵ)
641641
if ratio > o.γ
642642
λ = T((o.γ * state) / (current_norm + o.ϵ))
643-
print(":", current_norm, ":")
644643
return current_norm * λ, dx * λ
645644
else
646645
return current_norm, dx
@@ -653,8 +652,8 @@ nonfirstdims(x) = prod(size(x)[2:end])
653652
"""
654653
Apollo(η::Real, rank::Int; u = 100, sort_dims = false)
655654
Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false)
656-
Apollo(opt::Optimisers.AdamW, rank::Int; u = 100, sort_dims = false)
657-
Apollo(opt::Optimisers.AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false)
655+
Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = false)
656+
Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false)
658657
659658
Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage.
660659
First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function

0 commit comments

Comments
 (0)