Skip to content

Commit

Permalink
Yota 0.8.2, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 31, 2022
1 parent e451d15 commit 08c23f2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Functors = "0.3"
Yota = "0.8.1"
Yota = "0.8.2"
Zygote = "0.6.40"
julia = "1.6"

Expand Down
11 changes: 5 additions & 6 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
(The method of `apply!` above is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.)
For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
Expand All @@ -87,17 +85,18 @@ Yota is another modern automatic differentiation package, an alternative to Zygo

Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
but also returns a gradient component for the loss function.
To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
or, for the Flux model above:
To extract what Optimisers.jl needs, you can write (for the Flux model above):

```julia
using Yota

loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x))
sum(m(x)
end;
```

# Or else, this may save computing ∇image:
loss, (_, ∇model) = grad(m -> sum(m(image)), model);
```
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
Expand Down

0 comments on commit 08c23f2

Please sign in to comment.