Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) #196

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using Functors: functor, fmap, fmap_with_path,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra

using Random: randn!

include("interface.jl")
export AbstractRule

Expand All @@ -23,7 +25,7 @@ include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad
AccumGrad, Apollo, GradNormGrowthLimiter

VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

Expand Down
134 changes: 134 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,140 @@
return (mt, st, βt .* β), dx′
end


"""
GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the default value for m correspond to the original paper (i.e. m=0 i suppose)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m=0 makes sense when this is applied to the entire model, but could be fatal when applied tensor-wise. I think it is better to have non-footgun defaults, and make it clearer that this isn't a faithful reproduction?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept a non-zero default, but I've tweaked the docs to clarify that this method isn't quite the same as in those papers. (I also switched the "scaling m by the number of parameters" to using sqrt).


Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum that the gradient norm can grow
from one step to the next. This implementation also introduces `m` a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor
from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this explain what it does do, mathematically, before explaining that it's different to some paper?

γ controls the maximum that the gradient norm can grow from one step to the next.

I don't know what this means without reading the code. Can you write like if norm(dx, 2) > γ * norm(dx_prev, 2) to explain the condition, and explain exactly what happens if this is violated?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"""
struct GradNormGrowthLimiter <: AbstractRule
γ::Float64
m::Float64 #Min grad norm, to stop a tensor getting stuck near zero
ϵ::Float64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't allow unicode field names, suggest:

Suggested change
γ::Float64
m::Float64 #Min grad norm, to stop a tensor getting stuck near zero
ϵ::Float64
gamma::Float64
mu::Float64 # Min grad norm, to stop a tensor getting stuck near zero
epsilon::Float64

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but changed the variable names to avoid eg. gamma.

throw::Bool
paramscale_min::Bool
end

GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have greek-letter keyword options, nor field names -- the API should never ask the user to type these. They are used only in documentation / as local variables. Probably the first 3 should be positional.

Bikeshedding names bit, to avoid overly long things, the constructor could be:

Suggested change
GradNormGrowthLimiter= 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min)
NormGrowLimit= 1.1, m = 1e-3, ε = 1e-8; throw = true, scale = true) = NormGrowLimit(γ, m, ε, throw, scale)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with NormGrowthCap here.


init(o::GradNormGrowthLimiter, x::AbstractArray{T}) where T = T(0)

function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where T
current_norm = _norm(dx, 2)
if o.throw && !isfinite(current_norm)
throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))

Check warning on line 626 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L626

Added line #L626 was not covered by tests
end
if state == 0
return (current_norm), dx
else
#If you're below the hard min, then don't scale
if o.paramscale_min
minthresh = o.m * sqrt(length(dx))
else
minthresh = o.m

Check warning on line 635 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L635

Added line #L635 was not covered by tests
end
if current_norm < minthresh
return current_norm, dx
end
ratio = current_norm / (state + o.ϵ)
if ratio > o.γ
λ = T((o.γ * state) / (current_norm + o.ϵ))
return current_norm * λ, dx * λ
else
return current_norm, dx
end
end
end

nonfirstdims(x) = prod(size(x)[2:end])

"""
Apollo(η::Real, rank::Int; u = 100, sort_dims = false)
Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false)
Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = false)
Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false)

Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage.
First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function
to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
"""
struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule
opt::T1
eta::T2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why store opt and eta?

Copy link
Author

@murrellb murrellb Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally just stored opt, but then getting adjust working seemed tricky (likely a skill issue on my part though). Options were to include all the other AdamW params directly in this struct, or have an AdamW that only applies to the low-rank moments (which doesn't use eta, so its eta is redundant), and a separate eta that gets tweaked by adjust. The latter seemed better because then you can just wrap an existing AdamW in this.

Edit: another reason for storing an AdamW is that the AdamW is used instead of Apollo on regular arrays. But I just realized that now "adjust" won't work for regular arrays. I'll try figuring this out...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Storing an AdamW seems fine, surely we can make adjust just work through onto the inner struct.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made adjust work on the inner Adam now, so have dropped the additional eta.

r::T3 #Maps non-first dims to rank
u::T4 #Subspace update frequency (T in paper)
sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These have fixed types, right?

Suggested change
u::T4 #Subspace update frequency (T in paper)
sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
u::Int # Subspace update frequency (T in paper)
sort_dims::Bool # Whether to swap the dims of x and dx when the second dim is smaller than the first

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup.

end


Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this method just be created by giving a default to eta in the next one?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fixed via a different route.

Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you want max?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this.

Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims)
Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims)
Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims)

Check warning on line 675 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L672-L675

Added lines #L672 - L675 were not covered by tests

#Use the base init and apply for 1D arrays
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx)

function init(o::Apollo, x::AbstractArray{T}) where T
first_dim, second_dim = size(x,1), nonfirstdims(x)
if o.sort_dims && second_dim < first_dim
first_dim, second_dim = second_dim, first_dim

Check warning on line 684 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L684

Added line #L684 was not covered by tests
end
rank = o.r(second_dim)
P = similar(x, rank, first_dim)
randn!(P)
P .*= T(sqrt(1/rank))
((similar(x, rank, second_dim) .= 0, similar(x, rank, second_dim) .= 0, o.opt.beta), 1, P)
end


function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
swapped = false
original_size = size(x)
x = reshape(x, size(x,1), nonfirstdims(x))

dx = Broadcast.materialize(dx) #This is to stop the "gradient type" @lazy test from failing due to reshape.
dx = reshape(dx, size(x,1), nonfirstdims(x))
Comment on lines +707 to +708
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to materialize in matrix case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For everything except the whatever comes in during the "gradient type" test you don't need materialize. I wasn't 100% sure exactly what is coming in during those tests, so wasn't sure how to separate them from regular matrix/tensors. What do you suggest here?


first_dim, second_dim = size(x,1), size(x,2)
if o.sort_dims && second_dim < first_dim
first_dim, second_dim = second_dim, first_dim
x = x'
dx = dx'
swapped = true

Check warning on line 707 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L704-L707

Added lines #L704 - L707 were not covered by tests
end
(mt, vt, βt), t, P = state
η = T(o.eta) #This is what will get modified by adjust
λ = T(o.opt.lambda)
β = T.(o.opt.beta)
ϵ = T(o.opt.epsilon)
βt = T.(βt)
if mod(t, o.u) == 0
rank = o.r(second_dim)
randn!(P)
P .*= T(sqrt(1/rank))
end
R = P * dx
@.. mt = β[1] * mt + (1 - β[1]) * R
@.. vt = β[2] * vt + (1 - β[2]) * abs2(R)
Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ)
dx′′ = η * (dx .* reshape(s, 1, :)) + λ * x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines allocate a lot.

Rhat isn't used?

For the rest maybe it can be something like

sum1R2 = sum(abs2, R; dims=1)  # it's already the right shape, no need for [:] & reshape(s, 1, :)?
s = @. sqrt(sum1R2) / sqrt(Rhat + ϵ)
dx′′ = @lazy η * (dx * s) + λ * x  # one fused broadcast

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got something like this working, but the @lazy breaks things, so omitted for now.

if swapped
dx′′ = dx′′'

Check warning on line 727 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L727

Added line #L727 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dx′′ = dx′′'
dx′′ = transpose(dx′′)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this sort of branching introduces type instability. IDK if we care but perhaps worth some thought. Maybe there's a nicer way to just store everything transposed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe an optimization we can figure out later if it becomes an issue?

end
return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size)
end

#Notes: chuck the AdamW from the struct, so that adjust will just work.



"""
WeightDecay(λ = 5e-4)
WeightDecay(; [lambda])
Expand Down
3 changes: 2 additions & 1 deletion test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ RULES = [
# All the rules at default settings:
Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), Apollo(),
# A few chained combinations:
OptimiserChain(SignDecay(0.001), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
OptimiserChain(ClipGrad(0.5), Momentum()),
OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
OptimiserChain(GradNormGrowthLimiter(1.1), Apollo()),
# Not the default:
RMSProp(centred = true), AdamW(couple=false),
]
Expand Down
Loading