diff --git a/flax/nnx/README.md b/flax/nnx/README.md index 7a39eb132..14829c0fa 100644 --- a/flax/nnx/README.md +++ b/flax/nnx/README.md @@ -64,10 +64,10 @@ pip install git+https://github.com/google/flax.git ### Examples -* [LM1B](https://github.com/google/flax/tree/main/flax/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. +* [LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx): A language model trained on the 1 Billion Word Benchmark dataset. #### Toy Examples -* [Basic Example](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. -* [Using the Functional API](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. -* [Training a VAE](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. -* [Scan over layers](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. +* [Basic Example](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. +* [Using the Functional API](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. +* [Training a VAE](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. +* [Scan over layers](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`.