Skip to content

Commit 168aa3b

Browse files
author
Flax Authors
committed
Merge pull request #4284 from 8bitmp3:lint-glossary
PiperOrigin-RevId: 686677676
2 parents 3bf732c + 46ac862 commit 168aa3b

File tree

4 files changed

+47
-54
lines changed

4 files changed

+47
-54
lines changed

docs_nnx/glossary.rst

Lines changed: 0 additions & 50 deletions
This file was deleted.

docs_nnx/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ Learn more
182182

183183
.. card:: :material-regular:`import_contacts;2em` Glossary
184184
:class-card: sd-text-black sd-bg-light
185-
:link: glossary.html
185+
:link: nnx_glossary.html
186186

187187

188188
----
@@ -196,7 +196,7 @@ Learn more
196196
why
197197
guides/index
198198
examples/index
199-
glossary
199+
nnx_glossary
200200
The Flax philosophy <philosophyhttps://flax.readthedocs.io/en/latest/philosophy.html>
201201
How to contribute <https://flax.readthedocs.io/en/latest/contributing.html>
202202
api_reference/index

docs_nnx/nnx_glossary.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
*****************
2+
Flax NNX glossary
3+
*****************
4+
5+
For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/latest/glossary.html>`__.
6+
7+
.. glossary::
8+
9+
Filter
10+
A way to extract only certain :term:`nnx.Variable<Variable>` objects out of a Flax NNX :term:`Module<Module>` (``nnx.Module``). This is usually done by calling :meth:`nnx.split <flax.nnx.split>` upon the :class:`nnx.Module<flax.nnx.Module>`. Refer to the `Filter guide <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to learn more.
11+
12+
Folding in
13+
In Flax, `folding in <https://jax.readthedocs.io/en/latest/_autosummary/jax.random.fold_in.html>`__ means generating a new `JAX pseudorandom number generator (PRNG) <https://jax.readthedocs.io/en/latest/random-numbers.html>`__ key, given an input PRNG key and integer. This is typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this in JAX with `jax.random.split <https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html>`__, but this method will effectively create two PRNG keys, which is slower. Learn how Flax generates new PRNG keys automatically in the `Randomness/PRNG guide <https://flax.readthedocs.io/en/latest/guides/randomness.html>`__.
14+
15+
GraphDef
16+
:class:`nnx.GraphDef<flax.nnx.GraphDef>` is a class that represents all the static, stateless, and Pythonic parts of a Flax :term:`Module<Module>` (:class:`nnx.Module<flax.nnx.Module>`).
17+
18+
Merge
19+
Refer to :term:`Split and merge<Split and merge>`.
20+
21+
Module
22+
:class:`nnx.Module <flax.nnx.Module>` is a dataclass that enables defining and initializing parameters in a referentially-transparent form. It is responsible for storing and updating :term:`Variable<Variable> objects and parameters within itself.
23+
24+
Params / parameters
25+
:class:`nnx.Param <flax.nnx.Param>` is a particular subclass of :class:`nnx.Variable <flax.nnx.Variable>` that generally contains the trainable weights.
26+
27+
PRNG states
28+
A Flax :class:`nnx.Module <flax.nnx.Module>` can keep a reference of a `pseudorandom number generator (PRNG) <https://jax.readthedocs.io/en/latest/random-numbers.html>`__ state object :class:`nnx.Rngs <flax.nnx.Rngs>` that can generate new `JAX PRNG <https://jax.readthedocs.io/en/latest/random-numbers.html>`__ keys. These keys are used to generate random JAX arrays through `JAX's functional PRNGs <https://jax.readthedocs.io/en/latest/random-numbers.html>`__.
29+
You can use a PRNG state with different seeds to add more fine-grained control to your model (for example, to have independent random numbers for parameters and dropout masks).
30+
Refer to the Flax `Randomness/PRNG guide <https://flax.readthedocs.io/en/latest/guides/randomness.html>`__
31+
for more details.
32+
33+
Split and merge
34+
:meth:`nnx.split <flax.nnx.split>` is a way to represent an :class:`nnx.Module <flax.nnx.Module>` by two parts: 1) a static Flax NNX :term:`GraphDef <GraphDef>` that captures its Pythonic static information; and 2) one or more :term:`Variable state(s)<Variable state>` that capture its `JAX arrays <https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array>`__ (``jax.Array``) in the form of `JAX pytrees <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__. They can be merged back to the original ``nnx.Module`` using :meth:`nnx.merge <flax.nnx.merge>`.
35+
36+
Transformation
37+
A Flax NNX transformation (transform) is a wrapped version of a `JAX transformation <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the function that is being transformed to take the Flax NNX :term:`Module<Module>` (``nnx.Module``) as input or output. For example, a "lifted" version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ is :meth:`nnx.jit <flax.nnx.jit>`. Check out the `Flax NNX transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ to learn more.
38+
39+
Variable
40+
The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.
41+
42+
Variable state
43+
:class:`nnx.VariableState <flax.nnx.VariableState>` is a purely functional `JAX pytree <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__ of all the :term:`Variables<Variable>` inside a :term:`Module<Module>`. Since it is pure, it can be an input or output of a `JAX transformation <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split <flax.nnx.split>` on the :class:`nnx.Module <flax.nnx.Module>`. (Refer to :term:`splitting<Split and merge>` and :term:`Module<Module>` to learn more.)

docs_nnx/why.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ Why Flax NNX?
44
In 2020, the Flax team released the Flax Linen API to support modeling research on JAX, with a focus on scaling
55
and performance. We have learned a lot from users since then. The team introduced certain ideas that have proven to be beneficial to users, such as:
66

7-
* Organizing variables into `collections <https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections>`_.
8-
* Automatic and efficient `pseudorandom number generator (PRNG) management <https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences>`_.
7+
* Organizing variables into `collections <https://flax.readthedocs.io/en/latest/nnx_glossary.html#term-Variable-collections>`_.
8+
* Automatic and efficient `pseudorandom number generator (PRNG) management <https://flax.readthedocs.io/en/latest/nnx_glossary.html#term-RNG-sequences>`_.
99
* `Variable metadata <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning>`_
1010
for `Single Program Multi Data (SPMD) <https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD>`_ annotations, optimizer metadata, and other use cases.
1111

0 commit comments

Comments
 (0)