
Commit 5c032b4

danielsuo authored and Flax Authors committed
[flax:examples:mnist] Update mnist example to use NNX.
PiperOrigin-RevId: 815862166
1 parent 75fd8fa commit 5c032b4

File tree

5 files changed (+678, −652 lines)

docs_nnx/mnist_tutorial.ipynb

Lines changed: 549 additions & 549 deletions
Large diffs are not rendered by default.

docs_nnx/mnist_tutorial.md

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b
   grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
   (loss, logits), grads = grad_fn(model, batch)
   metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
-  optimizer.update(grads)  # In-place updates.
+  optimizer.update(model, grads)  # In-place updates.

 @nnx.jit
 def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
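For context, this one-line change tracks the newer `nnx.Optimizer` API used throughout this commit, in which the model is passed to `update` alongside the gradients and the optimizer is constructed with `wrt=nnx.Param` (as in the `train.py` diff below). A minimal sketch of the call pattern; the toy `nnx.Linear` model and the shapes here are illustrative only, not part of the example:

```python
import jax.numpy as jnp
import optax
from flax import nnx

# Toy stand-in for the example's CNN; the 4 -> 2 feature sizes are made up.
model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(0.01, 0.9), wrt=nnx.Param)

def loss_fn(model, x, y):
  # Simple squared-error loss for illustration.
  return ((model(x) - y) ** 2).mean()

@nnx.jit
def train_step(model, optimizer, x, y):
  loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
  optimizer.update(model, grads)  # New signature: the model is passed explicitly.
  return loss

loss = train_step(model, optimizer, jnp.ones((8, 4)), jnp.ones((8, 2)))
```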

examples/mnist/README.md

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn
 [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

 ```
-I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
-I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
+I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25
 ```

 ### How to run

examples/mnist/main.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from ml_collections import config_flags
 import tensorflow as tf

-import train
+import train  # pylint: disable=g-bad-import-order


 FLAGS = flags.FLAGS

examples/mnist/train.py

Lines changed: 126 additions & 99 deletions
@@ -20,147 +20,174 @@

 # See issue #620.
 # pytype: disable=wrong-keyword-args
+import functools
+from typing import Any

 from absl import logging
-from flax import linen as nn
+from flax import nnx
 from flax.metrics import tensorboard
-from flax.training import train_state
 import jax
-import jax.numpy as jnp
 import ml_collections
-import numpy as np
 import optax
+import tensorflow as tf
 import tensorflow_datasets as tfds

+tf.random.set_seed(0)  # Set the random seed for reproducibility.

-class CNN(nn.Module):
+
+class CNN(nnx.Module):
   """A simple CNN model."""

-  @nn.compact
+  def __init__(self, *, rngs: nnx.Rngs):
+    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
+    self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)
+    self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs)
+    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
+    self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)
+    self.avg_pool = functools.partial(
+        nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)
+    )
+    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
+    self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs)
+    self.linear2 = nnx.Linear(256, 10, rngs=rngs)
+
   def __call__(self, x):
-    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
-    x = nn.relu(x)
-    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
-    x = nn.relu(x)
-    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-    x = x.reshape((x.shape[0], -1))  # flatten
-    x = nn.Dense(features=256)(x)
-    x = nn.relu(x)
-    x = nn.Dense(features=10)(x)
+    x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x)))))
+    x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
+    x = x.reshape(x.shape[0], -1)  # flatten
+    x = nnx.relu(self.dropout2(self.linear1(x)))
+    x = self.linear2(x)
     return x


-@jax.jit
-def apply_model(state, images, labels):
-  """Computes gradients, loss and accuracy for a single batch."""
-
-  def loss_fn(params):
-    logits = state.apply_fn({'params': params}, images)
-    one_hot = jax.nn.one_hot(labels, 10)
-    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
-    return loss, logits
-
-  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (loss, logits), grads = grad_fn(state.params)
-  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-  return grads, loss, accuracy
+def loss_fn(model: CNN, batch) -> tuple[jax.Array, Any]:
+  logits = model(batch['image'])
+  loss = optax.softmax_cross_entropy_with_integer_labels(
+      logits=logits, labels=batch['label']
+  ).mean()
+  return loss, logits


-@jax.jit
-def update_model(state, grads):
-  return state.apply_gradients(grads=grads)
+@nnx.jit
+def train_step(
+    model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
+) -> None:
+  """Train for a single step."""
+  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
+  (loss, logits), grads = grad_fn(model, batch)
+  metrics.update(
+      loss=loss, logits=logits, labels=batch['label']
+  )  # In-place updates.
+  optimizer.update(model, grads)  # In-place updates.


-def train_epoch(state, train_ds, batch_size, rng):
-  """Train for a single epoch."""
-  train_ds_size = len(train_ds['image'])
-  steps_per_epoch = train_ds_size // batch_size
+@nnx.jit
+def eval_step(model: CNN, metrics: nnx.MultiMetric, batch) -> None:
+  loss, logits = loss_fn(model, batch)
+  metrics.update(
+      loss=loss, logits=logits, labels=batch['label']
+  )  # In-place updates.

-  perms = jax.random.permutation(rng, len(train_ds['image']))
-  perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
-  perms = perms.reshape((steps_per_epoch, batch_size))

-  epoch_loss = []
-  epoch_accuracy = []
-
-  for perm in perms:
-    batch_images = train_ds['image'][perm, ...]
-    batch_labels = train_ds['label'][perm, ...]
-    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
-    state = update_model(state, grads)
-    epoch_loss.append(loss)
-    epoch_accuracy.append(accuracy)
-  train_loss = np.mean(epoch_loss)
-  train_accuracy = np.mean(epoch_accuracy)
-  return state, train_loss, train_accuracy
-
-
-def get_datasets():
+def get_datasets(
+    config: ml_collections.ConfigDict,
+) -> tuple[tf.data.Dataset, tf.data.Dataset]:
   """Load MNIST train and test datasets into memory."""
-  ds_builder = tfds.builder('mnist')
-  ds_builder.download_and_prepare()
-  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
-  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
-  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
-  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
-  return train_ds, test_ds
+  batch_size = config.batch_size
+  train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
+  test_ds: tf.data.Dataset = tfds.load('mnist', split='test')
+
+  train_ds = train_ds.map(
+      lambda sample: {
+          'image': tf.cast(sample['image'], tf.float32) / 255,
+          'label': sample['label'],
+      }
+  )  # normalize train set
+  test_ds = test_ds.map(
+      lambda sample: {
+          'image': tf.cast(sample['image'], tf.float32) / 255,
+          'label': sample['label'],
+      }
+  )  # normalize the test set.
+
+  # Create a shuffled dataset by allocating a buffer size of 1024 to randomly
+  # draw elements from.
+  train_ds = train_ds.shuffle(1024)
+  # Group into batches of `batch_size` and skip incomplete batches, prefetch the
+  # next sample to improve latency.
+  train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
+  # Group into batches of `batch_size` and skip incomplete batches, prefetch the
+  # next sample to improve latency.
+  test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

-
-def create_train_state(rng, config):
-  """Creates initial `TrainState`."""
-  cnn = CNN()
-  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
-  tx = optax.sgd(config.learning_rate, config.momentum)
-  return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
+  return train_ds, test_ds


-def train_and_evaluate(
-    config: ml_collections.ConfigDict, workdir: str
-) -> train_state.TrainState:
+def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> None:
   """Execute model training and evaluation loop.

   Args:
     config: Hyperparameter configuration for training and evaluation.
-    workdir: Directory where the tensorboard summaries are written to.
-
-  Returns:
-    The train state (which includes the `.params`).
+    workdir: Directory path to store metrics.
   """
-  train_ds, test_ds = get_datasets()
-  rng = jax.random.key(0)
+  train_ds, test_ds = get_datasets(config)
+
+  # Instantiate the model.
+  model = CNN(rngs=nnx.Rngs(0))
+
+  learning_rate = config.learning_rate
+  momentum = config.momentum

   summary_writer = tensorboard.SummaryWriter(workdir)
   summary_writer.hparams(dict(config))

-  rng, init_rng = jax.random.split(rng)
-  state = create_train_state(init_rng, config)
+  optimizer = nnx.Optimizer(
+      model, optax.sgd(learning_rate, momentum), wrt=nnx.Param
+  )
+  metrics = nnx.MultiMetric(
+      accuracy=nnx.metrics.Accuracy(),
+      loss=nnx.metrics.Average('loss'),
+  )

   for epoch in range(1, config.num_epochs + 1):
-    rng, input_rng = jax.random.split(rng)
-    state, train_loss, train_accuracy = train_epoch(
-        state, train_ds, config.batch_size, input_rng
-    )
-    _, test_loss, test_accuracy = apply_model(
-        state, test_ds['image'], test_ds['label']
-    )
-
-    logging.info(
+    # Run the optimization for one step and make a stateful update to the
+    # following:
+    # - The train state's model parameters
+    # - The optimizer state
+    # - The training loss and accuracy batch metrics
+    model.train()  # Switch to train mode
+
+    for batch in train_ds.as_numpy_iterator():
+      train_step(model, optimizer, metrics, batch)
+    # Compute the training metrics.
+    train_metrics = metrics.compute()
+    metrics.reset()  # Reset the metrics for the test set.
+
+    # Compute the metrics on the test set after each training epoch.
+    model.eval()  # Switch to eval mode
+    for batch in test_ds.as_numpy_iterator():
+      eval_step(model, metrics, batch)
+
+    # Compute the eval metrics.
+    eval_metrics = metrics.compute()
+    metrics.reset()  # Reset the metrics for the next training epoch.
+
+    logging.info(  # pylint: disable=logging-not-lazy
         'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,'
         ' test_accuracy: %.2f'
         % (
            epoch,
-            train_loss,
-            train_accuracy * 100,
-            test_loss,
-            test_accuracy * 100,
+            train_metrics['loss'],
+            train_metrics['accuracy'] * 100,
+            eval_metrics['loss'],
+            eval_metrics['accuracy'] * 100,
         )
     )

-    summary_writer.scalar('train_loss', train_loss, epoch)
-    summary_writer.scalar('train_accuracy', train_accuracy, epoch)
-    summary_writer.scalar('test_loss', test_loss, epoch)
-    summary_writer.scalar('test_accuracy', test_accuracy, epoch)
+    # Write the metrics to TensorBoard.
+    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
+    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
+    summary_writer.scalar('test_loss', eval_metrics['loss'], epoch)
+    summary_writer.scalar('test_accuracy', eval_metrics['accuracy'], epoch)

   summary_writer.flush()
-  return state
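For reference, the rewritten `train_and_evaluate` reads only `batch_size`, `learning_rate`, `momentum`, and `num_epochs` from the config and reports metrics through TensorBoard summaries instead of returning a train state. Below is a minimal sketch of driving it directly; the field values and the workdir path are illustrative, and the real example wires these up via `ml_collections` config flags in `main.py`:

```python
import ml_collections

import train  # the updated examples/mnist/train.py

# Build a config with the fields the new train loop reads.
config = ml_collections.ConfigDict()
config.learning_rate = 0.1  # illustrative values, not necessarily the example defaults
config.momentum = 0.9
config.batch_size = 32
config.num_epochs = 10

# Runs the full train/eval loop and writes TensorBoard summaries to workdir.
train.train_and_evaluate(config, workdir='/tmp/mnist')
```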
