Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) #196

Open · wants to merge 14 commits into base: master
48 changes: 48 additions & 0 deletions src/rules.jl
@@ -599,6 +599,54 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
    return (mt, st, βt .* β), dx′
end


"""
    Apollo(opt, r, u)

Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage.
`opt` is an AdamW optimizer, `r` is the random projection rank (smaller for lower memory use), and `u` is the random projection update interval.
"""
struct Apollo{T1} <: AbstractRule
    opt::T1
    r::Int  # Subspace rank
    u::Int  # Subspace update frequency (T in paper)
end
Member:
If this must be AdamW, then you could do this, which supplies defaults:

Suggested change:
-struct Apollo{T1} <: AbstractRule
-    opt::T1
-    r::Int  # Subspace rank
-    u::Int  # Subspace update frequency (T in paper)
-end
+@def struct Apollo{T1} <: AbstractRule
+    opt = AdamW()
+    r = 10  # Subspace rank
+    u = 10  # Subspace update frequency (T in paper)
+end
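For context: Optimisers.jl defines its built-in rules with this @def macro, which fills in constructors from the field defaults. Assuming it behaves here as it does for the existing rules, the suggestion above would permit calls like this sketch:

    Apollo()              # all defaults: AdamW(), r = 10, u = 10
    Apollo(AdamW(0.001))  # override the inner optimiser positionally
    Apollo(u = 100)       # override a single field by keyword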

Author:
I wanted the user to be able to pass in either an Int or a function for the rank (the latter so the rank can scale with the dimension), so I've written some custom constructors with defaults instead of this approach.
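Illustrative usage under that design (hypothetical signatures, not the PR's exact constructors):

    Apollo(AdamW(0.001), 64, 100)              # fixed rank of 64
    Apollo(AdamW(0.001), dim -> dim ÷ 4, 100)  # rank as a function of the dimension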


# Use the base init and apply! for 1-D arrays
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx)

function init(o::Apollo, x::AbstractArray{T,2}) where T
Member:
For arrays of more than two dimensions (e.g. the weight of a Conv), should there be methods that reshape to a matrix and back?

Author:
Yes, that is on my list. There is also a pesky assertion about the dimension ordering that means some matrices will have to be transposed:

[screenshot: assertion on dimension ordering from the reference implementation]

but I'm not sure a lazy transpose as W comes in and goes out will be optimal; we might have to write a new path for those. I suspect we (as in "humanity") don't know whether this actually helps, so I'll make sure it can be controlled by a flag.
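A minimal sketch of the reshape-to-matrix idea under discussion, assuming trailing dimensions are collapsed; illustrative only, not the PR's implementation (init would need a matching method):

    function apply!(o::Apollo, state, x::AbstractArray{T,N}, dx) where {T,N}
        x2  = reshape(x,  size(x, 1), :)   # collapse to a matrix
        dx2 = reshape(dx, size(dx, 1), :)
        state′, dx2′ = apply!(o, state, x2, dx2)
        return state′, reshape(dx2′, size(dx))  # restore the original shape
    end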

Author:
Both of these are included now.

    rank = min(o.r, ceil(Int, size(x,2) / 2))   # cap the rank at half the number of columns
    P = randn(T, rank, size(x,1)) .* T(1/rank)  # random projection matrix
    ((similar(x, rank, size(x,2)) .= 0, similar(x, rank, size(x,2)) .= 0, o.opt.beta), 0, P)  # ((mt, vt, βt), t, P)
end

function apply!(o::Apollo, state, x::AbstractArray{T,2}, dx) where T
    (mt, vt, βt), t, P = state
    η = T(o.opt.eta)
    λ = T(o.opt.lambda)
    β = T.(o.opt.beta)
    ϵ = T(o.opt.epsilon)
    if mod(t, o.u) == 0  # resample the random projection every u steps
        rank = min(o.r, ceil(Int, size(x,2) / 2))
        P = randn(T, rank, size(x,1)) .* T(1/rank)
    end
    R = P * dx  # project the gradient into the low-rank subspace
    Optimisers.@.. mt = β[1] * mt + (1 - β[1]) * R
    Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(R)
    # Bias-corrected Adam-style update in the subspace
    Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
    # Per-column scaling: norm of the adapted subspace update over norm of the raw projection
    s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ)
    S = Diagonal(s)
    dx′′ = η * dx * S + λ * x  # scale the full-rank gradient columnwise, plus weight decay
    return ((mt, vt, βt .* β), t+1, P), dx′′
end
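For reference, a usage sketch with the rule as defined above; the positional constructor comes straight from the struct, and setup/update are the standard Optimisers.jl API. In effect, apply! scales column j of the raw gradient by norm(Rhat[:,j]) / norm(R[:,j]), the ratio of the Adam-adapted projected update to the raw projection:

    using Optimisers
    rule = Apollo(AdamW(1e-3), 10, 100)  # rank 10, resample the projection every 100 steps
    ps = (W = randn(Float32, 64, 32),)   # toy 2-D parameter
    st = Optimisers.setup(rule, ps)
    gs = (W = randn(Float32, 64, 32),)   # stand-in gradient
    st, ps = Optimisers.update(st, ps, gs)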





"""
WeightDecay(λ = 5e-4)
WeightDecay(; [lambda])