Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) #196
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -599,6 +599,54 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T | |
return (mt, st, βt .* β), dx′ | ||
end | ||
|
||
|
||
""" | ||
Apollo(opt, r, u) | ||
|
||
Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. | ||
`opt` is an AdamW optimizer, `r` is the random projection rank (smaller for lower memory use), and `u` is the random projection update interval. | ||
""" | ||
struct Apollo{T1} <: AbstractRule | ||
opt::T1 | ||
r::Int #Subspace rank | ||
u::Int #Subspace update frequency (T in paper) | ||
end | ||
|
||
#Use the base init and apply for 1D arrays | ||
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x) | ||
apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx) | ||
|
||
function init(o::Apollo, x::AbstractArray{T,2}) where T | ||
For arrays of >2D (e.g. the weight of a Conv layer), should there be methods to reshape to a matrix and reshape back?
Yes, that is on my list. There is also this pesky assertion about the dimension ordering, which means that some matrices will have to be transposed. I'm not sure a lazy transpose as W comes in and goes out will be optimal, so I might have to write a new path for those. I suspect we (as in "humanity") don't know if this actually helps, so I'm going to make sure this can be controlled by a flag.
Both of these are included now. |
||
rank = min(o.r, ceil(Int, size(x,2) / 2)) | ||
P = randn(T, rank, size(x,1)) .* T(1/rank) | ||
((similar(x, rank, size(x,2)) .= 0, similar(x, rank, size(x,2)) .= 0, o.opt.beta), 0, P) | ||
end | ||
|
||
function apply!(o::Apollo, state, x::AbstractArray{T,2}, dx) where T | ||
(mt, vt, βt), t, P = state | ||
η = T(o.opt.eta) | ||
λ = T(o.opt.lambda) | ||
β = T.(o.opt.beta) | ||
ϵ = T(o.opt.epsilon) | ||
if mod(t, o.u) == 0 #Redraw the random projection every `u` steps | ||
rank = min(o.r, ceil(Int, size(x,2) / 2)) | ||
P = randn(T, rank, size(x,1)) .* T(1/rank) | ||
end | ||
R = P * dx | ||
Optimisers.@.. mt = β[1] * mt + (1 - β[1]) * R | ||
Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(R) | ||
Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) | ||
s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ) | ||
S = Diagonal(s) | ||
dx′′ = η * dx * S + λ * x | ||
return ((mt, vt, βt .* β), t+1, P), dx′′ | ||
end | ||
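The review thread on `init` asks about arrays with more than two dimensions and about transposing matrices whose dimension ordering does not fit. A minimal sketch of one way to do that round trip is shown here. The helper names `as_matrix`, `orient`, and `restore` are hypothetical and are not part of Optimisers.jl, and the choice to keep the first dimension and to prefer "at least as many columns as rows" is an assumption made only for illustration.

```julia
# Hypothetical helpers, not part of Optimisers.jl: flatten an N-d array into a
# matrix by keeping the first dimension and merging all the others.
as_matrix(x::AbstractArray) = reshape(x, size(x, 1), :)

# Orient the matrix so that it has at least as many columns as rows
# (this orientation is an assumption made for the sketch).
# Returns the (possibly lazily transposed) matrix and a flag recording the swap.
function orient(m::AbstractMatrix)
    swapped = size(m, 1) > size(m, 2)
    return (swapped ? transpose(m) : m), swapped
end

# Undo `orient` and `as_matrix`, restoring the original shape.
function restore(m::AbstractMatrix, swapped::Bool, original_size::Tuple)
    m = swapped ? transpose(m) : m
    return reshape(collect(m), original_size)
end

x = rand(Float32, 3, 3, 4, 8)        # e.g. a Conv weight
m, swapped = orient(as_matrix(x))    # 3×96 matrix, no swap needed
y = restore(m, swapped, size(x))
@assert y == x
```

The lazy `transpose` avoids a copy on the way in; whether that is faster than a dedicated code path for tall matrices is exactly the open question raised in the thread.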
|
||
|
||
|
||
|
||
|
||
""" | ||
WeightDecay(λ = 5e-4) | ||
WeightDecay(; [lambda]) | ||
|
If this must be AdamW, then you could do this, which supplies defaults:
I wanted the user to be able to pass in either an Int or a function for rank (the latter where they can scale the rank based on the dim), so I've written some custom constructors with defaults instead of this approach.
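As a rough illustration of the constructor approach described in the last comment, the sketch below accepts either an `Int` or a function for the rank. The type name `ApolloSketch`, the keyword names, and the default values are all assumptions for this example; they are not taken from the PR.

```julia
# Hypothetical stand-in for the rule; the name `ApolloSketch` and the field
# layout are assumptions for illustration only.
struct ApolloSketch{O, F}
    opt::O   # inner optimiser (AdamW in the PR; any value works in this sketch)
    r::F     # function mapping a dimension to a rank
    u::Int   # projection update interval
end

# An integer rank becomes a function that caps the rank at the dimension.
ApolloSketch(opt, r::Int, u::Int) = ApolloSketch(opt, dim -> min(dim, r), u)

# Keyword constructor with illustrative defaults (the values are assumptions).
ApolloSketch(opt; rank = dim -> max(1, dim ÷ 4), update_interval = 100) =
    ApolloSketch(opt, rank, update_interval)

a = ApolloSketch(:adamw, 16, 10)                 # fixed rank
b = ApolloSketch(:adamw; rank = dim -> dim ÷ 2)  # rank scales with the dimension
@assert a.r(8) == 8       # capped by the dimension
@assert a.r(256) == 16
@assert b.r(64) == 32
```

Storing the rank as a function in both cases means `init` and `apply!` only ever need one call form, `o.r(dim)`, regardless of what the user passed in.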