Merge pull request #2 from FluxML/master
Catching up
murrellb authored Dec 19, 2024
2 parents 346669f + 34250b2 commit a655ca0
Showing 4 changed files with 28 additions and 14 deletions.
24 changes: 19 additions & 5 deletions docs/src/index.md
@@ -1,5 +1,23 @@
# Optimisers.jl

Optimisers.jl defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models.

This was written as the new training system for [Flux.jl](https://github.com/FluxML/Flux.jl) neural networks,
and is also used by [Lux.jl](https://github.com/LuxDL/Lux.jl).
But it can be used separately on any array, or anything else understood by [Functors.jl](https://github.com/FluxML/Functors.jl).

## Installation

In the Julia REPL, type
```julia
]add Optimisers
```

or
```julia-repl
julia> import Pkg; Pkg.add("Optimisers")
```

## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref Optimisers.apply!) and [`init`](@ref Optimisers.init).
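For instance, a minimal sketch of such a rule (a hypothetical `DecayDescent` with an illustrative 1/√t schedule; the name and field are assumptions, not part of this changeset):

```julia
using Optimisers

# A toy rule: gradient descent whose step size shrinks as 1/sqrt(t).
struct DecayDescent <: Optimisers.AbstractRule
    eta::Float64
end

# Per-array state: just a step counter.
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

# `apply!` takes the rule, its state, the parameter array and the gradient;
# it returns the updated state and the adjusted gradient to be subtracted from x.
function Optimisers.apply!(o::DecayDescent, state, x, dx)
    newdx = (o.eta / sqrt(state)) .* dx
    return state + 1, newdx
end
```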
@@ -38,7 +56,6 @@ state for every trainable array. Then at each step, [`update`](@ref Optimisers.u
to adjust the model:

```julia

using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu # define a model to train
@@ -54,7 +71,6 @@ end;

state_tree, model = Optimisers.update(state_tree, model, ∇model);
@show sum(model(image)); # reduced

```
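For reference, the same setup/update pattern works on a plain parameter array; a self-contained sketch (the rule, array and toy loss below are illustrative, not taken from the example above):

```julia
using Optimisers, Zygote

x = [1.0, 2.0, 3.0]                        # any parameter array, or a nested model
state = Optimisers.setup(Descent(0.1), x)  # initialise optimiser state for x

grad = Zygote.gradient(p -> sum(abs2, p), x)[1]  # gradient of a toy loss
state, x = Optimisers.update(state, x, grad)     # returns new state and new parameters
```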

Notice that a completely new instance of the model is returned. Internally, this
@@ -91,7 +107,6 @@ Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
identical trees of nested `NamedTuple`s.)

```julia

using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
@@ -113,7 +128,6 @@ opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y); # now reduced

```
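As an aside on the two update functions appearing in these examples: `update` is copying, while `update!` is allowed to mutate its inputs; both return the new state and parameters. A small sketch with illustrative values:

```julia
using Optimisers

params = [1.0, 2.0];
grads  = [0.1, 0.2];
state  = Optimisers.setup(Adam(0.01), params);

# Out-of-place: returns new state and parameters, inputs left untouched.
state, params = Optimisers.update(state, params, grads);

# In-place variant: same return values, but may overwrite state and params for speed.
state, params = Optimisers.update!(state, params, grads);
```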

Besides the parameters stored in `params` and gradually optimised, any other model state
@@ -297,7 +311,7 @@ similarly to what [`destructure`](@ref Optimisers.destructure) does but without
concatenating the arrays into a flat vector.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

-```julia
+```julia-repl
julia> using Flux, Optimisers
julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));
2 changes: 1 addition & 1 deletion src/destructure.jl
@@ -38,7 +38,7 @@ This is what [`destructure`](@ref Optimisers.destructure) returns, and `re(p)` w
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
# Example
-```julia
+```julia-repl
julia> using Flux, Optimisers
julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
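As a hedged illustration of the `destructure`/`re` behaviour described in this docstring (the nested container below is arbitrary, not taken from the example above):

```julia
using Optimisers

# Flatten all parameter arrays into one vector `p`; `re` rebuilds the original structure.
p, re = Optimisers.destructure((a = [1.0, 2.0], b = ([3.0], 4.0)))
# p == [1.0, 2.0, 3.0]   (the scalar 4.0 is not a trainable array, so it is left in place)

nt = re([10.0, 20.0, 30.0])   # same nesting, new parameter values
# nt.a == [10.0, 20.0] and nt.b == ([30.0], 4.0)
```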
10 changes: 5 additions & 5 deletions src/rules.jl
@@ -528,11 +528,11 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
"""
-struct AdamW{T1,T2,T3,T4} <: AbstractRule
-  eta::T1
-  beta::T2
-  epsilon::T3
-  lambda::T4
+struct AdamW <: AbstractRule
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  lambda::Float64
+  epsilon::Float64
  couple::Bool
end

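A hedged usage sketch of the `couple` switch documented above (the keyword constructor and the `eta`/`lambda` argument names are assumptions, not shown in this changeset):

```julia
using Optimisers

rule = AdamW(eta = 1e-3, lambda = 1e-2)                             # coupled decay, the default
rule_decoupled = AdamW(eta = 1e-3, lambda = 1e-2, couple = false)   # closer to the original paper

state = Optimisers.setup(rule, [1.0, 2.0, 3.0])
```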
6 changes: 3 additions & 3 deletions src/trainables.jl
@@ -32,10 +32,10 @@ julia> trainables(x)
1-element Vector{AbstractArray}:
[1.0, 2.0, 3.0]
julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);

julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
```
