
Commit 08c23f2

Yota 0.8.2, etc
1 parent e451d15 commit 08c23f2

File tree

2 files changed, +6 -7 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Functors = "0.3"
-Yota = "0.8.1"
+Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"

docs/src/index.md

Lines changed: 5 additions & 6 deletions
@@ -62,8 +62,6 @@ tree formed by the model and update the parameters using the gradients.
 
 There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
 (The method of `apply!` above is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.)
 For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
@@ -87,17 +85,18 @@ Yota is another modern automatic differentiation package, an alternative to Zygo
 
 Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
 but also returns a gradient component for the loss function.
-To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
-or, for the Flux model above:
+To extract what Optimisers.jl needs, you can write (for the Flux model above):
 
 ```julia
 using Yota
 
 loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-    sum(m(x))
+    sum(m(x))
 end;
-```
 
+# Or else, this may save computing ∇image:
+loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+```
 
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 

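Not part of this commit, but as context for the doc change above: a minimal sketch of how the `∇model` returned by `Yota.grad` plugs into Optimisers.jl, assuming `model` and `image` are defined as in the Flux example on the same docs page.

```julia
using Optimisers, Yota

# Set up optimiser state once; for Adam this holds two momenta per parameter array.
state = Optimisers.setup(Optimisers.Adam(), model)

# Yota.grad returns the loss and one gradient per argument, including the closure itself.
loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model)

# `update` returns new state and model, defensively copying arrays that rules may mutate;
# `update!` is the in-place variant, free to mutate the old state and model for efficiency.
state, model = Optimisers.update(state, model, ∇model)
# state, model = Optimisers.update!(state, model, ∇model)
```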
0 commit comments
