Add Muon #203

Open · wants to merge 5 commits into base: master
5 changes: 3 additions & 2 deletions src/Optimisers.jl
@@ -1,9 +1,10 @@
module Optimisers

using Functors: functor, fmap, fmap_with_path,
KeyPath, haskeypath, getkeypath,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra
import LinearAlgebra: norm

include("interface.jl")
export AbstractRule
@@ -23,7 +24,7 @@ include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad, Muon

VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

90 changes: 83 additions & 7 deletions src/rules.jl
@@ -181,7 +181,7 @@ init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
function apply!(o::Rprop, state, x::AbstractArray{T}, dx) where T
ℓ, Γ = T.(o.ell), T.(o.gamma)
g, η = state

η = broadcast(g, η, dx) do g, η, dx
g * dx > 0 ? min(η * ℓ[2], Γ[2]) : g * dx < 0 ? max(η * ℓ[1], Γ[1]) : η
end
@@ -256,6 +256,7 @@ function apply!(o::Lion, state, x::AbstractArray{T}, dx) where T
return state, dx′
end


"""
RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
RAdam(; [eta, beta, epsilon])
@@ -599,14 +600,89 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
return (mt, st, βt .* β), dx′
end

nonfirstdims(x) = prod(size(x)[2:end])  # product of all trailing dimensions; equals 1 for vectors and bias-like arrays

"""
Muon(opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), η = 0.02, μ = 0.95, λ = 0.01, fallback = Returns(false))
Muon(; [opt, eta, mu, lambda, fallback])

Muon - MomentUm Orthogonalized by Newton-schulz (https://github.com/KellerJordan/Muon)

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step,
in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration.
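
Schematically (matching the implementation below), for a parameter `x` with gradient `dx` and momentum buffer `B`, one step computes

    B = μ * B + dx
    O = _newton_schulz5(μ * B + dx) * sqrt(max(1, size(x, 1) / nonfirstdims(x)))

and then subtracts `η * (O + λ * x)` from `x`; arrays with more than two dimensions are reshaped to `size(x, 1) × nonfirstdims(x)` matrices before the orthogonalisation.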

# Parameters
- Fallback optimizer (`opt`): Optimizer used for parameters that are effectively one-dimensional (`nonfirstdims(x) == 1`), or whenever the `fallback` function returns `true`
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating the weights
- Momentum (`μ == mu`): Controls the acceleration of gradient descent in the prominent direction
- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
- Fallback function (`fallback`): Controls when, in addition to effectively 1D arrays, the fallback optimizer should be used. It is passed the parameter array and must return a `Bool`; see the example below the note.

Note: Works best with large batch sizes and may not be suitable for fine-tuning.
In nanoGPT speedrun experiments, Muon is used for the 2D-and-higher weights of the internal layers, while AdamW is used for the 1D weights, embeddings, and heads.
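
For example, a crude size-based predicate could additionally route very wide matrices (a stand-in for embedding or head weights) to the AdamW fallback. This is purely illustrative: the threshold is arbitrary, and a predicate on the array alone cannot reliably identify such layers.

    Muon(fallback = x -> maximum(size(x)) > 10_000)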

`Optimisers.adjust!(optimiser_state, η::Real)` will adjust the fallback optimizer's `eta` to `η * (opt.eta / eta)` and Muon's `eta` to `η`, preserving their ratio,
but `Optimisers.adjust!(optimiser_state, eta = η)` will only adjust Muon's learning rate (allowing you to adjust the fallback optimizer's learning rate separately).
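
For instance, with the defaults `eta = 0.02` and `opt = AdamW(eta = 0.0003, ...)`, and `model` standing in for any model that `setup` accepts:

    st = Optimisers.setup(Muon(), model)
    Optimisers.adjust!(st, 0.002)        # Muon's eta -> 0.002, AdamW's eta -> 0.002 * 0.0003 / 0.02 = 3e-5
    Optimisers.adjust!(st, eta = 0.002)  # only Muon's eta changes; AdamW's eta is left at 0.0003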
@mcabbott (Member) commented on lines +619 to +625, Dec 21, 2024:
What it seems you really want is for setup to use Muon on some arrays, and AdamW on others. But instead this is rolled into this particular meta-optimisation rule, which is also called Muon. Maybe we should think about how to do that in a bit more generality?

One spelling would be setup(OptimiserIfElse(fun, AdamW(), Muon()), model) which does fun(x) ? init(AdamW(), x) : init(Muon(), x), with some struct OptimiserIfElse <: AbstractRule. But it's a "fake rule" which is digested at setup time.

Another would be setup(fun::Function, model::Any) which is like

opt_state = setup(model) do x
  ndims(x) == 1 ? AdamW() : Muon()
end

I'm sure we batted around such ideas when writing this package, but nobody had a concrete need. One thing we wondered was whether this which-rule function ought to see just x or, for instance, the field name, or the layer's type, or what? ndims(x) == 1 is a way of selecting bias but can't distinguish weight matrices from different layers.

Edit, a 3rd way is just to let you handle it. Too obscure, even if documented? Not sure.

opt_state = fmapstructure(model; exclude=isnumeric) do x
  rule = ndims(x) == 1 ? AdamW() : Muon()
  setup(rule, x)
end

The PR author replied:
I strongly agree that this functionality would be good to have in general, but as you say it isn't quite clear what to switch on.

But in this case, when people say "Muon" they mean "Muon for >2D with an AdamW fallback for 1D" - see https://github.com/KellerJordan/Muon/blob/master/muon.py
And the only things they use to switch are directly inferable from the tensor itself, which we have access to inside the optimiser, so we can get the same behaviour. So I think it makes sense to just call this Muon here. Also, it supports the same way of adjusting the two eta values during warmup/cooldown as the Python version (retaining their ratio), saving the user a little effort if they want to use it as-is.
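
For concreteness, a minimal sketch of the OptimiserIfElse spelling discussed above could look like the following (hypothetical, not part of this PR; like Muon above it simply re-evaluates the predicate at both init and apply! time):

    struct OptimiserIfElse{F} <: AbstractRule
      fun::F                   # predicate on the parameter array
      iftrue::AbstractRule
      iffalse::AbstractRule
    end
    _pick(o::OptimiserIfElse, x) = o.fun(x) ? o.iftrue : o.iffalse
    init(o::OptimiserIfElse, x::AbstractArray) = init(_pick(o, x), x)
    apply!(o::OptimiserIfElse, state, x, dx, dxs...) = apply!(_pick(o, x), state, x, dx, dxs...)

    # usage, as suggested above:
    opt_state = setup(OptimiserIfElse(x -> ndims(x) == 1, AdamW(), Muon()), model)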

"""
struct Muon <: AbstractRule
opt::AbstractRule
eta::Float64
mu::Float64
lambda::Float64
fallback::Function
end

Muon(;opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), eta = 0.02, mu = 0.95, lambda = 0.01, fallback = x -> false) = Muon(opt, eta, mu, lambda, fallback)

function init(o::Muon, x::AbstractArray)
if nonfirstdims(x) == 1 || o.fallback(x)
return init(o.opt, x)
else
return zero(x)
end
end

function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T
if nonfirstdims(x) == 1 || o.fallback(x)
return apply!(o.opt, state, x, dx)
else
η, μ, λ = T(o.eta), T(o.mu), T(o.lambda)
@.. state = μ * state + dx  # momentum buffer update (`@..` mutates when possible)
# Nesterov-style lookahead, orthogonalised by Newton-Schulz, then scaled by sqrt(rows/cols) for tall matrices
Ot = _newton_schulz5(μ .* state .+ dx) * T(sqrt(max(1, size(x,1)/nonfirstdims(x))))
dx′ = @lazy η * (Ot + λ * x)  # decoupled weight decay, as in AdamW
return state, dx′
end
end

# Five steps of the quintic Newton-Schulz iteration; the coefficients follow Keller Jordan's reference implementation.
function _inner_newton_schulz5(X::AbstractMatrix{T}) where T
a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0))
for _ in 1:5
A = X * X'
B = b * A + c * A * A
X = a * X + B * X
end
X
end
function _newton_schulz5(G::AbstractMatrix{T}) where T
X = G / (norm(G) + eps(T))  # normalise by the Frobenius norm so the iteration is stable
if size(G, 1) > size(G, 2)  # iterate on the wide orientation, transposing tall matrices
transpose(_inner_newton_schulz5(transpose(X)))
else
_inner_newton_schulz5(X)
end
end
_newton_schulz5(G::AbstractArray) = reshape(_newton_schulz5(reshape(G, size(G,1), :)), size(G))  # flatten >2D arrays to matrices first

adjust(r::Muon, η::Real) = adjust(r, eta = η, opt = adjust(r.opt, eta = (r.opt.eta / r.eta) * η))
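
A minimal usage sketch (assuming `model` is any Functors-compatible model and `grads` is the corresponding gradient, e.g. from Zygote):

    using Optimisers
    opt_state = Optimisers.setup(Muon(), model)   # Muon for matrices, the AdamW fallback for effectively-1D arrays
    opt_state, model = Optimisers.update!(opt_state, model, grads)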

"""
WeightDecay(λ = 5e-4)
WeightDecay(; [lambda])

Implements ``L_2`` regularisation, also known as ridge regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).

It does this by adding `λ .* x` to the gradient. This is equivalent to adding
`λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.

See also [`SignDecay`] for ``L_1`` normalisation.
@@ -644,7 +720,7 @@ function adjust(r::WeightDecay; gamma = nothing, kw...)
Implements ``L_1`` regularisation, also known as LASSO regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).

It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.

See also [`WeightDecay`] for ``L_2`` normalisation.
@@ -783,7 +859,7 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state)
if dx′ isa Zero
return (states′..., state), dx′
else
state′, dx′ = apply!(opt, state, x, dx′, dxs...)
return (states′..., state′), dx′
end
@@ -831,10 +907,10 @@ julia> m # n=2 gradients applied at once
"""
struct AccumGrad <: AbstractRule
n::Int

function AccumGrad(n::Int)
n > 0 || throw(ArgumentError("AccumGrad must accumulate at least one gradient"))
return new(n)
end
end

12 changes: 6 additions & 6 deletions test/rules.jl
@@ -8,7 +8,7 @@ RULES = [
# All the rules at default settings:
Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), Muon(),
# A few chained combinations:
OptimiserChain(SignDecay(0.001), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -183,7 +183,7 @@ end
# The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2),
# These weren't in Flux PR:
Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2),
]
# Our "model" is just a complex number
model = (w = zeros(ComplexF64, 1),)
@@ -226,7 +226,7 @@ end
@test static_loss(static_model) < last_loss
last_loss = static_loss(static_model)
end
@test static_loss(static_model) < 1.9
end
end

@@ -254,16 +254,16 @@ end
g1 = rand(5)
tree, x1 = Optimisers.update(tree, x, g1)
@test x1 ≈ x
@test x1 ≈ x0
g2 = rand(5)
tree, x2 = Optimisers.update(tree, x1, g2)
@test x2 ≈ x
@test x2 ≈ x0
g3 = rand(5)
tree, x3 = Optimisers.update(tree, x2, g3)
@test x3 ≈ x0 .- lr .* (g1 .+ g2 .+ g3) ./ 3
g4 = rand(5)

tree, x4 = Optimisers.update(tree, x3, g4)
@test x4 ≈ x3
end