Commit 3705b0b

add_mixed_precision
1 parent 82c1cd8 commit 3705b0b

File tree

4 files changed: +73 -1 lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ Optimisers.AccumGrad
 Optimisers.ClipGrad
 Optimisers.ClipNorm
 Optimisers.MixedPrecision
+Optimisers.add_mixed_precision
 Optimisers.OptimiserChain
 Optimisers.SignDecay
 Optimisers.WeightDecay

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad, MixedPrecision
+       AccumGrad, MixedPrecision, add_mixed_precision

 VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

src/rules.jl

Lines changed: 43 additions & 0 deletions

@@ -925,3 +925,46 @@ end

 adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
 adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))
+
+
+"""
+    add_mixed_precision([T], tree, model) -> new_tree
+
+Add mixed precision to an existing optimiser state `tree` for `model`.
+If `T` is not provided, `Float32` is used.
+
+Each leaf of the returned tree will contain a `MixedPrecision` rule wrapping the original rule,
+and the states will be preserved and converted to type `T`.
+"""
+add_mixed_precision(tree, model) = add_mixed_precision(Float32, tree, model)
+
+function add_mixed_precision(T, tree, model)
+  cache = IdDict()
+  tree = _add_mixed_precision(T, tree, model; cache)
+  isempty(cache) && @warn "add_mixed_precision found no trainable parameters in this model"
+  return tree
+end
+
+function _add_mixed_precision(T, tree, x; cache)
+  ch, re = functor(tree)
+  return mapvalue((ti, xi) -> _add_mixed_precision(T, ti, xi; cache), ch, _trainable(x))
+end
+
+function _add_mixed_precision(T, tree::Optimisers.Leaf, x; cache)
+  haskey(cache, tree) && return cache[tree]  # preserve tied parameters: same Leaf in, same Leaf out
+  fT(z) = z isa AbstractFloat || isnumeric(z) ? T.(z) : z
+  if !(tree.rule isa MixedPrecision{T})
+    if tree.rule isa MixedPrecision  # different precision: replace the wrapper, don't nest
+      rulenew = MixedPrecision(T, tree.rule.rule)
+      statenew = fmap(fT, tree.state)
+    else
+      rulenew = MixedPrecision(T, tree.rule)
+      statenew = (T.(x), fmap(fT, tree.state))  # keep a precision-T copy of the parameters
+    end
+    treenew = Leaf(rulenew, statenew, tree.frozen)
+  else
+    treenew = tree
+  end
+  cache[tree] = treenew
+  return treenew
+end
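For context, a minimal usage sketch of the new function (not part of this diff; the model and gradient below are illustrative stand-ins, and `MixedPrecision` is assumed from the parent commit):

using Optimisers

model = (; W = rand(Float16, 4, 4))                # any functor-compatible container works

opt_state = Optimisers.setup(AdamW(), model)       # state is created in Float16, like the params
opt_state = add_mixed_precision(opt_state, model)  # each Leaf now holds MixedPrecision{Float32}

grad = (; W = ones(Float16, 4, 4))                 # stand-in for a real gradient
opt_state, model = Optimisers.update!(opt_state, model, grad)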

test/rules.jl

Lines changed: 28 additions & 0 deletions

@@ -286,3 +286,31 @@ end

   @test_throws ArgumentError OptimiserChain(MixedPrecision(Adam()))
 end
+
+@testset "add_mixed_precision" begin
+  d = rand(Float16, 2, 2)
+  d2 = rand(Float16, 2)
+  model = Foo(Foo(d, d2), d)  # d appears twice, so its state is tied
+  opt_state = Optimisers.setup(AdamW(), model)
+  @test opt_state.x.x === opt_state.y
+  @test opt_state.x.y.state[1] isa Vector{Float16}
+  @test opt_state.x.y.state[2] isa Vector{Float16}
+  @test opt_state.x.y.state[3] isa Tuple{Float16, Float16}
+
+  opt_state_new = add_mixed_precision(opt_state, model)
+
+  @test opt_state_new.x.x.rule isa MixedPrecision{Float32}
+  @test opt_state_new.x.x === opt_state_new.y
+  @test opt_state_new.x.x.state[1] isa Matrix{Float32}
+  @test opt_state_new.x.x.state[1] ≈ model.x.x
+  @test opt_state_new.x.y.state[2][1] isa Vector{Float32}
+  @test opt_state_new.x.y.state[2][2] isa Vector{Float32}
+  @test opt_state_new.x.y.state[2][3] isa Tuple{Float32, Float32}
+
+  opt_state_new2 = add_mixed_precision(Float64, opt_state_new, model)
+
+  @test opt_state_new2.x.x.rule isa MixedPrecision{Float64}  # MixedPrecision{Float32} replaced
+  @test opt_state_new2.x.x.rule.rule isa AdamW  # no nesting of MixedPrecision
+  @test opt_state_new2.x.x.state[1] isa Matrix{Float64}
+  @test opt_state_new2.x.x.state[2][1] isa Matrix{Float64}
+end
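The `===` tests above rely on the `IdDict` cache in `_add_mixed_precision`: when the same array appears twice in a model, `setup` gives both occurrences one shared `Leaf`, and the cache ensures the rewritten tree keeps that sharing. A sketch of the same behaviour with a plain named tuple (names here are illustrative, not from the test suite):

w = rand(Float16, 3)
tied = (a = w, b = w)                # the same array reachable twice
st = Optimisers.setup(Adam(), tied)
@assert st.a === st.b                # setup shares a single Leaf

st2 = add_mixed_precision(st, tied)
@assert st2.a === st2.b              # sharing survives the rewrite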
