diff --git a/Project.toml b/Project.toml
index 3f92e8e9..33cc3487 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
 Functors = "0.3"
-Yota = "0.8.1"
+Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"
diff --git a/docs/src/index.md b/docs/src/index.md
index 5a1e5210..863428b7 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -62,8 +62,6 @@ tree formed by the model and update the parameters using the gradients.
 There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
 (The method of `apply!` above is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.)
 
 For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
@@ -87,17 +85,18 @@ Yota is another modern automatic differentiation package, an alternative to Zygo
 Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
 but also returns a gradient component for the loss function.
-To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
-or, for the Flux model above:
+To extract what Optimisers.jl needs, you can write (for the Flux model above):
 
 ```julia
 using Yota
 
 loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-    sum(m(x)
+    sum(m(x))
 end;
-```
+
+# Or else, this may save computing ∇image:
+loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+```
 
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
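
A quick sketch (not part of the patch) of how the gradient returned by `Yota.grad` plugs into Optimisers.jl, following the docs example edited above; the `model`, `image`, and layer sizes below are illustrative stand-ins rather than names taken from the patch:

```julia
using Flux, Yota, Optimisers

# Hypothetical stand-ins for the `model` and `image` of the docs' Flux example:
model = Flux.Chain(Flux.Dense(10 => 5, tanh), Flux.Dense(5 => 2))
image = rand(Float32, 10, 1)

# Build the optimiser state tree for Adam (two momenta per parameter):
state = Optimisers.setup(Optimisers.Adam(), model)

# Gradient with respect to the model only, as in the updated docs example:
loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model)

# Non-mutating update: returns a new state tree and a new model.
state, model = Optimisers.update(state, model, ∇model)
```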