|
|
adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))
+
+"""
+    add_mixed_precision([T], tree, model) -> new_tree
+
+Add mixed precision to an existing optimiser state `tree` for `model`.
+If `T` is not provided, `Float32` is used.
+
+Each leaf of the returned tree will contain a `MixedPrecision` rule wrapping the original rule,
+and its state will be preserved and converted to type `T`.
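+
+# Examples
+
+A minimal sketch (assumes `Adam` and `setup` from this package):
+
+```julia
+model = (; w = rand(Float16, 2, 2))               # toy model with low-precision parameters
+tree = Optimisers.setup(Adam(1e-3), model)
+tree = add_mixed_precision(Float32, tree, model)  # each rule is now wrapped in MixedPrecision{Float32}
+```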
+"""
+add_mixed_precision(tree, model) = add_mixed_precision(Float32, tree, model)
+
+function add_mixed_precision(T, tree, model)
+  cache = IdDict()  # maps each original Leaf to its replacement, so tied parameters stay tied
+  tree = _add_mixed_precision(T, tree, model; cache)
+  isempty(cache) && @warn "add_mixed_precision found no trainable parameters in this model"
+  return tree
+end
+
+function _add_mixed_precision(T, tree, x; cache)
+  ch, _ = functor(tree)  # children of this node of the state tree; the rebuild function is unused
+  return mapvalue((ti, xi) -> _add_mixed_precision(T, ti, xi; cache), ch, _trainable(x))
+end
+
+# Non-trainable positions in the state tree hold `nothing`; leave them unchanged.
+_add_mixed_precision(T, tree::Nothing, x; cache) = nothing
+
+function _add_mixed_precision(T, tree::Optimisers.Leaf, x; cache)
+  haskey(cache, tree) && return cache[tree]  # shared (tied) parameters get the same new Leaf
+  fT(z) = z isa AbstractFloat || isnumeric(z) ? T.(z) : z  # convert float scalars and arrays to T
+  if !(tree.rule isa MixedPrecision{T})
+    if tree.rule isa MixedPrecision  # already wrapped, but with a different precision
+      rulenew = MixedPrecision(T, tree.rule.rule)
+      statenew = fmap(fT, tree.state)
+    else  # wrap the rule, prepending a precision-T copy of the parameters to the state
+      rulenew = MixedPrecision(T, tree.rule)
+      statenew = (T.(x), fmap(fT, tree.state))
+    end
+    treenew = Leaf(rulenew, statenew, tree.frozen)
+  else  # already MixedPrecision{T}: nothing to do
+    treenew = tree
+  end
+  cache[tree] = treenew
+  return treenew
+end