|
41 | 41 |
|
42 | 42 | julia> grad3 = gradient((x,p) -> p(x), 5, poly3s) |
43 | 43 | (2.0, (θ3 = [1.0, 5.0, 25.0],))</code></pre><p>The first entry is <code>∂f/∂x</code> as before, but the second entry is more interesting. For <code>poly2</code>, we get <code>∂f/∂θ</code> as <code>grad2[2]</code> directly. It is a vector, because <code>θ</code> is a vector, and has elements <code>[∂f/∂θ[1], ∂f/∂θ[2], ∂f/∂θ[3]]</code>.</p><p>For <code>poly3s</code>, however, we get a <code>NamedTuple</code> whose fields correspond to those of the struct <code>Poly3</code>. This is called a <em>structural gradient</em>. And the nice thing about them is that they work for arbitrarily complicated structures, for instance:</p><pre><code class="language-julia-repl hljs">julia> grad4 = gradient(|>, 5, poly4) |
44 | | -(1.0, (outer = (θ3 = [1.0, 17.5, 306.25],), inner = (θ3 = [0.5, 2.5, 12.5],)))</code></pre><p>Here <code>grad4.inner.θ3</code> corresponds to <code>poly4.inner.θ3</code>. These matching nested structures are at the core of how Flux works.</p><div class="admonition is-info" id="Implicit-gradients-14b778e96ba14cca"><header class="admonition-header">Implicit gradients<a class="admonition-anchor" href="#Implicit-gradients-14b778e96ba14cca" title="Permalink"></a></header><div class="admonition-body"><p>Earlier versions of Flux used a different way to relate parameters and gradients, which looks like this:</p><pre><code class="language-julia hljs">g1 = gradient(() -> poly1(5), Params([θ])) |
| 44 | +(1.0, (outer = (θ3 = [1.0, 17.5, 306.25],), inner = (θ3 = [0.5, 2.5, 12.5],)))</code></pre><p>Here <code>grad4[2].inner.θ3</code> corresponds to <code>poly4.inner.θ3</code>. These matching nested structures are at the core of how Flux works.</p><div class="admonition is-info" id="Implicit-gradients-14b778e96ba14cca"><header class="admonition-header">Implicit gradients<a class="admonition-anchor" href="#Implicit-gradients-14b778e96ba14cca" title="Permalink"></a></header><div class="admonition-body"><p>Earlier versions of Flux used a different way to relate parameters and gradients, which looks like this:</p><pre><code class="language-julia hljs">g1 = gradient(() -> poly1(5), Params([θ])) |
45 | 45 | g1[θ] == [1.0, 5.0, 25.0]</code></pre><p>Here <code>Params</code> is a set of references to global variables using <code>objectid</code>, and <code>g1 isa Grads</code> is a dictionary from these to their gradients. This method of <code>gradient</code> takes a zero-argument function, which only <em>implicitly</em> depends on <code>θ</code>.</p></div></div><h3><img src="../../../assets/zygote-crop.png" width="40px"/> <a href="https://github.com/FluxML/Zygote.jl">Zygote.jl</a></h3><p>Flux's <a href="../../../reference/training/enzyme/#Flux.gradient-Tuple{Any, Vararg{Union{EnzymeCore.Const, EnzymeCore.Duplicated}}}"><code>gradient</code></a> function by default calls a companion packages called <a href="https://github.com/FluxML/Zygote.jl">Zygote</a>. Zygote performs source-to-source automatic differentiation, meaning that <code>gradient(f, x)</code> hooks into Julia's compiler to find out what operations <code>f</code> contains, and transforms this to produce code for computing <code>∂f/∂x</code>.</p><p>Zygote can in principle differentiate almost any Julia code. However, it's not perfect, and you may eventually want to read its <a href="https://fluxml.ai/Zygote.jl/dev/limitations/">page about limitations</a>. In particular, a major limitation is that mutating an array is not allowed.</p><p>Flux can also be used with other automatic differentiation (AD) packages. It was originally written using <a href="https://github.com/FluxML/Tracker.jl">Tracker</a>, a more traditional operator-overloading approach. The future might be <a href="https://github.com/EnzymeAD/Enzyme.jl">Enzyme</a>, and Flux now builds in an easy way to use this instead, turned on by wrapping the model in <code>Duplicated</code>. (For details, see the <a href="../../../reference/training/enzyme/#autodiff-enzyme">Enzyme page</a> in the manual.)</p><pre><code class="language-julia hljs">julia> using Enzyme: Const, Duplicated |
46 | 46 |
|
47 | 47 | julia> grad3e = Flux.gradient((x,p) -> p(x), Const(5.0), Duplicated(poly3s)) |
|
95 | 95 | Flux.train!((m,x,y) -> (m(x) - y)^2, model3, data, Descent(0.01)) |
96 | 96 | end</code></pre><p>The same code will also work with <code>model1</code> or <code>model2</code> instead. Here's how to plot the desired and actual outputs:</p><pre><code class="language-julia hljs">using Plots |
97 | 97 | plot(x -> 2x-x^3, -2, 2, label="truth") |
98 | | -scatter!(x -> model3([x]), -2:0.1f0:2, label="fitted")</code></pre><p>More detail about what exactly the function <code>train!</code> is doing, and how to use rules other than simple <a href="../../../reference/training/optimisers/#Optimisers.Descent"><code>Descent</code></a>, is what the next page in this guide is about: <a href="../../training/training/#man-training">training</a>.</p></article><nav class="docs-footer"><a class="docs-footer-prevpage" href="../overview/">« Fitting a Line</a><a class="docs-footer-nextpage" href="../../training/training/">Training »</a><div class="flexbox-break"></div><p class="footer-message">Powered by <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> and the <a href="https://julialang.org/">Julia Programming Language</a>.</p></nav></div><div class="modal" id="documenter-settings"><div class="modal-background"></div><div class="modal-card"><header class="modal-card-head"><p class="modal-card-title">Settings</p><button class="delete"></button></header><section class="modal-card-body"><p><label class="label">Theme</label><div class="select"><select id="documenter-themepicker"><option value="auto">Automatic (OS)</option><option value="documenter-light">documenter-light</option><option value="documenter-dark">documenter-dark</option><option value="catppuccin-latte">catppuccin-latte</option><option value="catppuccin-frappe">catppuccin-frappe</option><option value="catppuccin-macchiato">catppuccin-macchiato</option><option value="catppuccin-mocha">catppuccin-mocha</option></select></div></p><hr/><p>This document was generated with <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> version 1.15.0 on <span class="colophon-date" title="Sunday 26 October 2025 18:57">Sunday 26 October 2025</span>. Using Julia version 1.12.1.</p></section><footer class="modal-card-foot"></footer></div></div></div></body></html> |
| 98 | +scatter!(x -> model3([x]), -2:0.1f0:2, label="fitted")</code></pre><p>More detail about what exactly the function <code>train!</code> is doing, and how to use rules other than simple <a href="../../../reference/training/optimisers/#Optimisers.Descent"><code>Descent</code></a>, is what the next page in this guide is about: <a href="../../training/training/#man-training">training</a>.</p></article><nav class="docs-footer"><a class="docs-footer-prevpage" href="../overview/">« Fitting a Line</a><a class="docs-footer-nextpage" href="../../training/training/">Training »</a><div class="flexbox-break"></div><p class="footer-message">Powered by <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> and the <a href="https://julialang.org/">Julia Programming Language</a>.</p></nav></div><div class="modal" id="documenter-settings"><div class="modal-background"></div><div class="modal-card"><header class="modal-card-head"><p class="modal-card-title">Settings</p><button class="delete"></button></header><section class="modal-card-body"><p><label class="label">Theme</label><div class="select"><select id="documenter-themepicker"><option value="auto">Automatic (OS)</option><option value="documenter-light">documenter-light</option><option value="documenter-dark">documenter-dark</option><option value="catppuccin-latte">catppuccin-latte</option><option value="catppuccin-frappe">catppuccin-frappe</option><option value="catppuccin-macchiato">catppuccin-macchiato</option><option value="catppuccin-mocha">catppuccin-mocha</option></select></div></p><hr/><p>This document was generated with <a href="https://github.com/JuliaDocs/Documenter.jl">Documenter.jl</a> version 1.15.0 on <span class="colophon-date" title="Sunday 26 October 2025 19:10">Sunday 26 October 2025</span>. Using Julia version 1.12.1.</p></section><footer class="modal-card-foot"></footer></div></div></div></body></html> |
0 commit comments