diff --git a/docs_nnx/mnist_tutorial.ipynb b/docs_nnx/mnist_tutorial.ipynb index a1d7a0526..3f4016f51 100644 --- a/docs_nnx/mnist_tutorial.ipynb +++ b/docs_nnx/mnist_tutorial.ipynb @@ -14,14 +14,8 @@ "\n", "Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n", "\n", - "Let’s get started!" - ] - }, - { - "cell_type": "markdown", - "id": "1", - "metadata": {}, - "source": [ + "Let’s get started!\n", + "\n", "## 1. Install Flax\n", "\n", "If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):" @@ -263,7 +257,7 @@ "\n", "In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n", "\n", - "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n", + "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.\n", "\n", "During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics." ] @@ -401,7 +395,7 @@ "\n", "@nnx.jit\n", "def pred_step(model: CNN, batch):\n", - " logits = model(batch['image'])\n", + " logits = model(batch['image'], None)\n", " return logits.argmax(axis=1)" ] }, @@ -441,6 +435,80 @@ " ax.axis('off')" ] }, + { + "cell_type": "markdown", + "id": "65342ab4", + "metadata": {}, + "source": [ + "# 8. Export the model\n", + "\n", + "Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49cace09", + "metadata": {}, + "outputs": [], + "source": [ + "from orbax.export import JaxModule, ExportManager, ServingConfig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "421309d4", + "metadata": {}, + "outputs": [], + "source": [ + "def exported_predict(model, y):\n", + " return model(y, None)\n", + "\n", + "jax_module = JaxModule(model, exported_predict)" + ] + }, + { + "cell_type": "markdown", + "id": "787136af", + "metadata": {}, + "source": [ + "We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f2ad72e", + "metadata": {}, + "outputs": [], + "source": [ + "sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]" + ] + }, + { + "cell_type": "markdown", + "id": "31e9668a", + "metadata": {}, + "source": [ + "Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18cdf9ad", + "metadata": {}, + "outputs": [], + "source": [ + "export_mgr = ExportManager(jax_module, [\n", + " ServingConfig('mnist_server', input_signature=sig)\n", + "])\n", + "\n", + "output_dir='/tmp/mnist_export'\n", + "export_mgr.save(output_dir)" + ] + }, { "cell_type": "markdown", "id": "28", diff --git a/docs_nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md index 563ebffec..176d900cd 100644 --- a/docs_nnx/mnist_tutorial.md +++ b/docs_nnx/mnist_tutorial.md @@ -20,8 +20,6 @@ Flax NNX is a Python neural network library built upon [JAX](https://github.com/ Let’s get started! -+++ - ## 1. Install Flax If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook): @@ -141,7 +139,7 @@ nnx.display(optimizer) In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over. -In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. +In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics. @@ -237,7 +235,7 @@ model.eval() # Switch to evaluation mode. @nnx.jit def pred_step(model: CNN, batch): - logits = model(batch['image']) + logits = model(batch['image'], None) return logits.argmax(axis=1) ``` @@ -254,6 +252,38 @@ for i, ax in enumerate(axs.flatten()): ax.axis('off') ``` +# 8. Export the model + +Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method. + +```{code-cell} ipython3 +from orbax.export import JaxModule, ExportManager, ServingConfig +``` + +```{code-cell} ipython3 +def exported_predict(model, y): + return model(y, None) + +jax_module = JaxModule(model, exported_predict) +``` + +We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`. + +```{code-cell} ipython3 +sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)] +``` + +Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class. + +```{code-cell} ipython3 +export_mgr = ExportManager(jax_module, [ + ServingConfig('mnist_server', input_signature=sig) +]) + +output_dir='/tmp/mnist_export' +export_mgr.save(output_dir) +``` + Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset. Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html). diff --git a/examples/mnist/README.md b/examples/mnist/README.md index 2a8169271..386bc5fea 100644 --- a/examples/mnist/README.md +++ b/examples/mnist/README.md @@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default ``` -I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69 -I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14 +I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25 ``` ### How to run diff --git a/examples/mnist/main.py b/examples/mnist/main.py index 887ecf71e..fe500cb96 100644 --- a/examples/mnist/main.py +++ b/examples/mnist/main.py @@ -26,7 +26,7 @@ from ml_collections import config_flags import tensorflow as tf -import train +import train # pylint: disable=g-bad-import-order FLAGS = flags.FLAGS diff --git a/examples/mnist/train.py b/examples/mnist/train.py index 0886a1963..a603e7069 100644 --- a/examples/mnist/train.py +++ b/examples/mnist/train.py @@ -20,147 +20,187 @@ # See issue #620. # pytype: disable=wrong-keyword-args +from functools import partial +from typing import Any +from pathlib import Path from absl import logging -from flax import linen as nn +from flax import nnx from flax.metrics import tensorboard -from flax.training import train_state import jax -import jax.numpy as jnp import ml_collections -import numpy as np import optax +import tensorflow as tf import tensorflow_datasets as tfds +tf.random.set_seed(0) # Set the random seed for reproducibility. -class CNN(nn.Module): + +class CNN(nnx.Module): """A simple CNN model.""" - @nn.compact - def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=256)(x) - x = nn.relu(x) - x = nn.Dense(features=10)(x) + def __init__(self, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) + self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs) + self.dropout1 = nnx.Dropout(rate=0.025) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) + self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs) + self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) + self.linear1 = nnx.Linear(3136, 256, rngs=rngs) + self.dropout2 = nnx.Dropout(rate=0.025) + self.linear2 = nnx.Linear(256, 10, rngs=rngs) + + def __call__(self, x, rngs: nnx.Rngs): + x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs)))) + x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x)))) + x = x.reshape(x.shape[0], -1) # flatten + x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs)) + x = self.linear2(x) return x -@jax.jit -def apply_model(state, images, labels): - """Computes gradients, loss and accuracy for a single batch.""" - - def loss_fn(params): - logits = state.apply_fn({'params': params}, images) - one_hot = jax.nn.one_hot(labels, 10) - loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) - return loss, logits - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss, logits), grads = grad_fn(state.params) - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) - return grads, loss, accuracy +def loss_fn(model: CNN, batch, rngs): + logits = model(batch['image'], rngs) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=batch['label'] + ).mean() + return loss, logits -@jax.jit -def update_model(state, grads): - return state.apply_gradients(grads=grads) +@nnx.jit +def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, rngs): + """Train for a single step.""" + grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(model, batch, rngs) + metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. + optimizer.update(model, grads) # In-place updates. -def train_epoch(state, train_ds, batch_size, rng): - """Train for a single epoch.""" - train_ds_size = len(train_ds['image']) - steps_per_epoch = train_ds_size // batch_size - perms = jax.random.permutation(rng, len(train_ds['image'])) - perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch - perms = perms.reshape((steps_per_epoch, batch_size)) - epoch_loss = [] - epoch_accuracy = [] +@nnx.jit +def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): + loss, logits = loss_fn(model, batch, None) + metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. - for perm in perms: - batch_images = train_ds['image'][perm, ...] - batch_labels = train_ds['label'][perm, ...] - grads, loss, accuracy = apply_model(state, batch_images, batch_labels) - state = update_model(state, grads) - epoch_loss.append(loss) - epoch_accuracy.append(accuracy) - train_loss = np.mean(epoch_loss) - train_accuracy = np.mean(epoch_accuracy) - return state, train_loss, train_accuracy - -def get_datasets(): +def get_datasets( + config: ml_collections.ConfigDict, +) -> tuple[tf.data.Dataset, tf.data.Dataset]: """Load MNIST train and test datasets into memory.""" - ds_builder = tfds.builder('mnist') - ds_builder.download_and_prepare() - train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) - test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) - train_ds['image'] = jnp.float32(train_ds['image']) / 255.0 - test_ds['image'] = jnp.float32(test_ds['image']) / 255.0 - return train_ds, test_ds + batch_size = config.batch_size + train_ds: tf.data.Dataset = tfds.load('mnist', split='train') + test_ds: tf.data.Dataset = tfds.load('mnist', split='test') + + train_ds = train_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } + ) # normalize train set + test_ds = test_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } + ) # normalize the test set. + + # Create a shuffled dataset by allocating a buffer size of 1024 to randomly + # draw elements from. + train_ds = train_ds.shuffle(1024) + # Group into batches of `batch_size` and skip incomplete batches, prefetch the + # next sample to improve latency. + train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) + # Group into batches of `batch_size` and skip incomplete batches, prefetch the + # next sample to improve latency. + test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) - -def create_train_state(rng, config): - """Creates initial `TrainState`.""" - cnn = CNN() - params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] - tx = optax.sgd(config.learning_rate, config.momentum) - return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx) + return train_ds, test_ds -def train_and_evaluate( - config: ml_collections.ConfigDict, workdir: str -) -> train_state.TrainState: +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> None: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. - workdir: Directory where the tensorboard summaries are written to. - - Returns: - The train state (which includes the `.params`). + workdir: Directory path to store metrics. """ - train_ds, test_ds = get_datasets() - rng = jax.random.key(0) + train_ds, test_ds = get_datasets(config) + + # Instantiate the model. + model = CNN(rngs=nnx.Rngs(0)) + + learning_rate = config.learning_rate + momentum = config.momentum summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) - rng, init_rng = jax.random.split(rng) - state = create_train_state(init_rng, config) + optimizer = nnx.Optimizer( + model, optax.sgd(learning_rate, momentum), wrt=nnx.Param + ) + metrics = nnx.MultiMetric( + accuracy=nnx.metrics.Accuracy(), + loss=nnx.metrics.Average('loss'), + ) + rngs = nnx.Rngs(0) for epoch in range(1, config.num_epochs + 1): - rng, input_rng = jax.random.split(rng) - state, train_loss, train_accuracy = train_epoch( - state, train_ds, config.batch_size, input_rng - ) - _, test_loss, test_accuracy = apply_model( - state, test_ds['image'], test_ds['label'] - ) - - logging.info( + # Run the optimization for one step and make a stateful update to the + # following: + # - The train state's model parameters + # - The optimizer state + # - The training loss and accuracy batch metrics + model.train() # Switch to train mode + + for batch in train_ds.as_numpy_iterator(): + train_step(model, optimizer, metrics, batch, rngs) + # Compute the training metrics. + train_metrics = metrics.compute() + metrics.reset() # Reset the metrics for the test set. + + # Compute the metrics on the test set after each training epoch. + model.eval() # Switch to eval mode + for batch in test_ds.as_numpy_iterator(): + eval_step(model, metrics, batch) + + # Compute the eval metrics. + eval_metrics = metrics.compute() + metrics.reset() # Reset the metrics for the next training epoch. + + logging.info( # pylint: disable=logging-not-lazy 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,' ' test_accuracy: %.2f' % ( epoch, - train_loss, - train_accuracy * 100, - test_loss, - test_accuracy * 100, + train_metrics['loss'], + train_metrics['accuracy'] * 100, + eval_metrics['loss'], + eval_metrics['accuracy'] * 100, ) ) - summary_writer.scalar('train_loss', train_loss, epoch) - summary_writer.scalar('train_accuracy', train_accuracy, epoch) - summary_writer.scalar('test_loss', test_loss, epoch) - summary_writer.scalar('test_accuracy', test_accuracy, epoch) + # Write the metrics to TensorBoard. + summary_writer.scalar('train_loss', train_metrics['loss'], epoch) + summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch) + summary_writer.scalar('test_loss', eval_metrics['loss'], epoch) + summary_writer.scalar('test_accuracy', eval_metrics['accuracy'], epoch) summary_writer.flush() - return state + + # Export the model to a SavedModel directory. + from orbax.export import JaxModule, ExportManager, ServingConfig + + def exported_predict(model, y): + return model(y, None) + + model.eval() + jax_module = JaxModule(model, exported_predict) + sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)] + export_mgr = ExportManager(jax_module, [ + ServingConfig('mnist_server', input_signature=sig) + ]) + + output_dir= Path(workdir) / 'mnist_export' + export_mgr.save(str(output_dir)) diff --git a/examples/mnist/train_test.py b/examples/mnist/train_test.py index fecc2f36c..ae1993f72 100644 --- a/examples/mnist/train_test.py +++ b/examples/mnist/train_test.py @@ -16,9 +16,11 @@ import pathlib import tempfile +import sys from absl.testing import absltest import jax +import flax.nnx as nnx from jax import numpy as jnp import numpy as np import tensorflow as tf @@ -36,23 +38,19 @@ class TrainTest(absltest.TestCase): def setUp(self): super().setUp() + if sys.version_info < (3, 13): + self.skipTest('Tensorflow 2.20 required for this test, which conflicts with tensorflow_text.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], "GPU") def test_cnn(self): """Tests CNN module used as the trainable model.""" - rng = jax.random.key(0) - inputs = jnp.ones((1, 28, 28, 3), jnp.float32) - output, variables = train.CNN().init_with_output(rng, inputs) + inputs = jnp.ones((1, 28, 28, 1), jnp.float32) + cnn = train.CNN(nnx.Rngs(0)) + cnn.eval() + output = cnn(inputs, None) self.assertEqual((1, 10), output.shape) - self.assertEqual( - CNN_PARAMS, - sum( - np.prod(arr.shape) - for arr in jax.tree_util.tree_leaves(variables["params"]) - ), - ) def test_train_and_evaluate(self): """Tests training and evaluation code by running a single step.""" diff --git a/pyproject.toml b/pyproject.toml index c66a75904..d3aa27247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "typing_extensions>=4.2", "PyYAML>=5.4.1", "treescope>=0.1.7", + "orbax-export>=0.0.8", ] classifiers = [ "Development Status :: 3 - Alpha",