
# See issue #620.
# pytype: disable=wrong-keyword-args
+import functools
+from typing import Any

from absl import logging
-from flax import linen as nn
+from flax import nnx
from flax.metrics import tensorboard
-from flax.training import train_state
import jax
-import jax.numpy as jnp
import ml_collections
-import numpy as np
import optax
+import tensorflow as tf
import tensorflow_datasets as tfds

+tf.random.set_seed(0)  # Set the random seed for reproducibility.

-class CNN(nn.Module):
+
+class CNN(nnx.Module):
  """A simple CNN model."""

-  @nn.compact
+  def __init__(self, *, rngs: nnx.Rngs):
+    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
+    self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)
+    self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs)
+    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
+    self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)
+    self.avg_pool = functools.partial(
+        nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)
+    )
+    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
+    self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs)
+    self.linear2 = nnx.Linear(256, 10, rngs=rngs)
+
  def __call__(self, x):
-    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
-    x = nn.relu(x)
-    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
-    x = nn.relu(x)
-    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-    x = x.reshape((x.shape[0], -1))  # flatten
-    x = nn.Dense(features=256)(x)
-    x = nn.relu(x)
-    x = nn.Dense(features=10)(x)
+    x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x)))))
+    x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
+    x = x.reshape(x.shape[0], -1)  # flatten
+    x = nnx.relu(self.dropout2(self.linear1(x)))
+    x = self.linear2(x)
    return x


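A note on the hard-coded 3136 passed to `linear1`: with `nnx.Conv`'s default 'SAME' padding, a 28x28x1 MNIST image keeps its 28x28 spatial size through each convolution and is halved by each 2x2 average pool, so the second pool emits 7x7x64 = 3136 flattened features. A minimal sanity check of the module defined above (illustrative only, not part of this change):

    from flax import nnx
    import jax.numpy as jnp

    model = CNN(rngs=nnx.Rngs(0))     # the CNN defined in this diff
    dummy = jnp.ones((1, 28, 28, 1))  # a single fake MNIST image
    logits = model(dummy)             # 28 -> 14 -> 7 spatially, 7 * 7 * 64 = 3136
    assert logits.shape == (1, 10)
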
-@jax.jit
-def apply_model(state, images, labels):
-  """Computes gradients, loss and accuracy for a single batch."""
-
-  def loss_fn(params):
-    logits = state.apply_fn({'params': params}, images)
-    one_hot = jax.nn.one_hot(labels, 10)
-    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
-    return loss, logits
-
-  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (loss, logits), grads = grad_fn(state.params)
-  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-  return grads, loss, accuracy
+def loss_fn(model: CNN, batch) -> tuple[jax.Array, Any]:
+  logits = model(batch['image'])
+  loss = optax.softmax_cross_entropy_with_integer_labels(
+      logits=logits, labels=batch['label']
+  ).mean()
+  return loss, logits


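Note that the new `loss_fn` also replaces the manual one-hot encoding plus `optax.softmax_cross_entropy` with `optax.softmax_cross_entropy_with_integer_labels`, which consumes the integer labels directly; the two formulations compute the same loss. A small illustrative check (not part of this change):

    import jax
    import jax.numpy as jnp
    import optax

    logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]])
    labels = jnp.array([0, 2])

    old_style = optax.softmax_cross_entropy(
        logits=logits, labels=jax.nn.one_hot(labels, 3)
    ).mean()
    new_style = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    assert jnp.allclose(old_style, new_style)
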
-@jax.jit
-def update_model(state, grads):
-  return state.apply_gradients(grads=grads)
+@nnx.jit
+def train_step(
+    model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
+) -> None:
+  """Train for a single step."""
+  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
+  (loss, logits), grads = grad_fn(model, batch)
+  metrics.update(
+      loss=loss, logits=logits, labels=batch['label']
+  )  # In-place updates.
+  optimizer.update(model, grads)  # In-place updates.


-def train_epoch(state, train_ds, batch_size, rng):
-  """Train for a single epoch."""
-  train_ds_size = len(train_ds['image'])
-  steps_per_epoch = train_ds_size // batch_size
+@nnx.jit
+def eval_step(model: CNN, metrics: nnx.MultiMetric, batch) -> None:
+  loss, logits = loss_fn(model, batch)
+  metrics.update(
+      loss=loss, logits=logits, labels=batch['label']
+  )  # In-place updates.

-  perms = jax.random.permutation(rng, len(train_ds['image']))
-  perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
-  perms = perms.reshape((steps_per_epoch, batch_size))

-  epoch_loss = []
-  epoch_accuracy = []
-
-  for perm in perms:
-    batch_images = train_ds['image'][perm, ...]
-    batch_labels = train_ds['label'][perm, ...]
-    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
-    state = update_model(state, grads)
-    epoch_loss.append(loss)
-    epoch_accuracy.append(accuracy)
-  train_loss = np.mean(epoch_loss)
-  train_accuracy = np.mean(epoch_accuracy)
-  return state, train_loss, train_accuracy
-
-
-def get_datasets():
+def get_datasets(
+    config: ml_collections.ConfigDict,
+) -> tuple[tf.data.Dataset, tf.data.Dataset]:
  """Load MNIST train and test datasets into memory."""
-  ds_builder = tfds.builder('mnist')
-  ds_builder.download_and_prepare()
-  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
-  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
-  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
-  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
-  return train_ds, test_ds
+  batch_size = config.batch_size
+  train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
+  test_ds: tf.data.Dataset = tfds.load('mnist', split='test')
+
+  train_ds = train_ds.map(
+      lambda sample: {
+          'image': tf.cast(sample['image'], tf.float32) / 255,
+          'label': sample['label'],
+      }
+  )  # Normalize the train set.
+  test_ds = test_ds.map(
+      lambda sample: {
+          'image': tf.cast(sample['image'], tf.float32) / 255,
+          'label': sample['label'],
+      }
+  )  # Normalize the test set.
+
+  # Create a shuffled dataset with a buffer of 1024 elements to randomly
+  # draw from.
+  train_ds = train_ds.shuffle(1024)
+  # Group into batches of `batch_size`, skip incomplete batches, and prefetch
+  # the next batch to improve latency.
+  train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
+  # Do the same for the test set.
+  test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

-
-def create_train_state(rng, config):
-  """Creates initial `TrainState`."""
-  cnn = CNN()
-  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
-  tx = optax.sgd(config.learning_rate, config.momentum)
-  return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
+  return train_ds, test_ds

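Each element these pipelines yield is a dict of batched tensors, which is what `train_step` and `eval_step` consume through `.as_numpy_iterator()` below. A quick way to inspect one batch (illustrative only; the batch size is an arbitrary example value, not taken from this change):

    import ml_collections

    config = ml_collections.ConfigDict({'batch_size': 32})
    train_ds, _ = get_datasets(config)
    batch = next(iter(train_ds.as_numpy_iterator()))
    print(batch['image'].shape, batch['image'].dtype)  # (32, 28, 28, 1) float32
    print(batch['label'].shape, batch['label'].dtype)  # (32,) int64
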

-def train_and_evaluate(
-    config: ml_collections.ConfigDict, workdir: str
-) -> train_state.TrainState:
+def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> None:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
-    workdir: Directory where the tensorboard summaries are written to.
-
-  Returns:
-    The train state (which includes the `.params`).
+    workdir: Directory path to store metrics.
  """
-  train_ds, test_ds = get_datasets()
-  rng = jax.random.key(0)
+  train_ds, test_ds = get_datasets(config)
+
+  # Instantiate the model.
+  model = CNN(rngs=nnx.Rngs(0))
+
+  learning_rate = config.learning_rate
+  momentum = config.momentum

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

-  rng, init_rng = jax.random.split(rng)
-  state = create_train_state(init_rng, config)
+  optimizer = nnx.Optimizer(
+      model, optax.sgd(learning_rate, momentum), wrt=nnx.Param
+  )
+  metrics = nnx.MultiMetric(
+      accuracy=nnx.metrics.Accuracy(),
+      loss=nnx.metrics.Average('loss'),
+  )

  for epoch in range(1, config.num_epochs + 1):
-    rng, input_rng = jax.random.split(rng)
-    state, train_loss, train_accuracy = train_epoch(
-        state, train_ds, config.batch_size, input_rng
-    )
-    _, test_loss, test_accuracy = apply_model(
-        state, test_ds['image'], test_ds['label']
-    )
-
-    logging.info(
+    # Run one optimization step per training batch, updating in place:
+    # - the model parameters
+    # - the optimizer state
+    # - the training loss and accuracy batch metrics
+    model.train()  # Switch to train mode.
+
+    for batch in train_ds.as_numpy_iterator():
+      train_step(model, optimizer, metrics, batch)
+    # Compute the training metrics.
+    train_metrics = metrics.compute()
+    metrics.reset()  # Reset the metrics for the test set.
+
+    # Compute the metrics on the test set after each training epoch.
+    model.eval()  # Switch to eval mode.
+    for batch in test_ds.as_numpy_iterator():
+      eval_step(model, metrics, batch)
+
+    # Compute the eval metrics.
+    eval_metrics = metrics.compute()
+    metrics.reset()  # Reset the metrics for the next training epoch.
+
+    logging.info(  # pylint: disable=logging-not-lazy
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,'
        ' test_accuracy: %.2f'
        % (
            epoch,
-            train_loss,
-            train_accuracy * 100,
-            test_loss,
-            test_accuracy * 100,
+            train_metrics['loss'],
+            train_metrics['accuracy'] * 100,
+            eval_metrics['loss'],
+            eval_metrics['accuracy'] * 100,
        )
    )

-    summary_writer.scalar('train_loss', train_loss, epoch)
-    summary_writer.scalar('train_accuracy', train_accuracy, epoch)
-    summary_writer.scalar('test_loss', test_loss, epoch)
-    summary_writer.scalar('test_accuracy', test_accuracy, epoch)
+    # Write the metrics to TensorBoard.
+    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
+    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
+    summary_writer.scalar('test_loss', eval_metrics['loss'], epoch)
+    summary_writer.scalar('test_accuracy', eval_metrics['accuracy'], epoch)

  summary_writer.flush()
-  return state
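For completeness, a minimal sketch of how the new entry point might be driven. The config fields match the ones read above (learning_rate, momentum, batch_size, num_epochs), but the concrete values and the surrounding wiring are assumptions for illustration, not part of this change:

    import ml_collections

    config = ml_collections.ConfigDict()
    config.learning_rate = 0.1
    config.momentum = 0.9
    config.batch_size = 32
    config.num_epochs = 10

    # Trains the NNX model, evaluates after every epoch, and writes
    # TensorBoard summaries under the given workdir.
    train_and_evaluate(config, workdir='/tmp/mnist')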