diff --git a/docs_nnx/guides/Optimization Cookbook.ipynb b/docs_nnx/guides/Optimization Cookbook.ipynb new file mode 100644 index 000000000..bd1228348 --- /dev/null +++ b/docs_nnx/guides/Optimization Cookbook.ipynb @@ -0,0 +1,994 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3b8f2147-e42f-44ec-974b-a6efb7dae7a0", + "metadata": {}, + "source": [ + "# A Flax Optimization Cookbook" + ] + }, + { + "cell_type": "markdown", + "id": "ca85ff65-539b-4529-b7cc-03556a484e5f", + "metadata": {}, + "source": [ + "This notebook works through some common problems that come up in nontrivial training loops for Flax models. For clarity, all sections below will be training the following toy model. Both raw Jax and Flax code are shown for comparison." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c1f9f773-822f-425a-86e2-487405051362", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "from flax import nnx\n", + "jax.config.update('jax_num_cpu_devices', 8)\n", + "import jax.numpy as jnp\n", + "from jax import tree\n", + "import optax\n", + "import itertools as it\n", + "import functools as ft\n", + "from collections import namedtuple" + ] + }, + { + "cell_type": "markdown", + "id": "21df7ab0-7fa7-474f-aefd-e4fa8ed23965", + "metadata": {}, + "source": [ + "Here is the NNX version:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "82b748b4-abd8-429c-9b95-ca6f99d0d5d1", + "metadata": {}, + "outputs": [], + "source": [ + "def nnx_model(rngs):\n", + " return nnx.Sequential(nnx.Linear(2,8, rngs=rngs), nnx.Linear(8,8, rngs=rngs))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "66cb1438-eaab-493f-9ea6-64939ce4de89", + "metadata": {}, + "outputs": [], + "source": [ + "model = nnx_model(nnx.Rngs(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ecc46d9c-0dec-48d2-bedb-c78e227ffad3", + "metadata": {}, + "outputs": [], + "source": [ + "def nnx_loss_fn(model, x, y):\n", + " return jnp.sum((model(x) - y) ** 2)" + ] + }, + { + "cell_type": "markdown", + "id": "1034c378-751a-431e-8d36-a3959c8396a8", + "metadata": {}, + "source": [ + "And here is its equivalent raw Jax representation:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5699e265-4a92-4bba-8a3d-6ba2dc3bd4a8", + "metadata": {}, + "outputs": [], + "source": [ + "keys = map(ft.partial(jax.random.fold_in, jax.random.key(0)), it.count())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "dc08d188-8e48-4d3e-914c-1b69076be624", + "metadata": {}, + "outputs": [], + "source": [ + "param_init = jax.nn.initializers.lecun_normal()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "cfe6796e-01ce-4b98-9f31-b644ad3ee886", + "metadata": {}, + "outputs": [], + "source": [ + "def make_linear(size, keys):\n", + " return {\n", + " 'w': param_init(next(keys), size),\n", + " 'b': jnp.zeros(size[1])\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b315fd6b-9cad-4a8c-81bb-029411494164", + "metadata": {}, + "outputs": [], + "source": [ + "def jax_params(keys):\n", + " return [make_linear((2, 8), keys), make_linear((8, 8), keys)]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "6a418818-ef3a-476f-be39-6f99ae2dd678", + "metadata": {}, + "outputs": [], + "source": [ + "def jax_model(params, x):\n", + " for p in params:\n", + " x = x @ p['w'] + p['b']\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1971ac56-38ae-4bc0-bbd2-8ec825676438", + "metadata": {}, + "outputs": [], + "source": [ + "def jax_loss_fn(params, x, y):\n", + " return jnp.sum((y - jax_model(params, x))**2)" + ] + }, + { + "cell_type": "markdown", + "id": "ee0fcb67-65f2-4952-ae83-d0e058e3c259", + "metadata": {}, + "source": [ + "We'll operate on the following fake data:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "1931ee46-8f60-462a-876b-e53911f94484", + "metadata": {}, + "outputs": [], + "source": [ + "x = jax.random.normal(next(keys), (32, 2))\n", + "y = jax.random.normal(next(keys), (32, 8))" + ] + }, + { + "cell_type": "markdown", + "id": "f7cf15ad-093f-4d12-9816-4d8b7fdc5350", + "metadata": {}, + "source": [ + "And we'll use Adam to update the parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "1d3a8c9c-6b64-444c-bcc1-52a4441234ba", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = optax.adam(1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "5849d6d5-ddd4-4a2b-aad4-aae9908be97a", + "metadata": {}, + "source": [ + "# Exponential Moving Average\n", + "\n", + "Neural networks see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Jax training loop to accommodate calculating exponential moving averages. " + ] + }, + { + "cell_type": "markdown", + "id": "3b30d52e-7bd7-4a59-9236-f05dfa297d0c", + "metadata": {}, + "source": [ + "## EMA in Pure Jax" + ] + }, + { + "cell_type": "markdown", + "id": "d552ff36-c514-4e66-81a2-282875a16c1b", + "metadata": {}, + "source": [ + "To start, we will see how to keep track of exponential moving averages in raw Jax. Although the raw Jax implementation is simple and easy to understand, it does not allow for mutable state, so the parameters, optimizer state, and EMA parameters must be passed into and out of every training step explicitly. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "5d3ccd55-0734-4e49-856e-cd4e3c491770", + "metadata": {}, + "outputs": [], + "source": [ + "def ema_update(ema, new_val, decay=0.9):\n", + " return decay * ema + (1 - decay) * new_val" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "da835ddc-cf02-452a-8b33-4ba5ecce6a70", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def train_step(opt_state, params, ema_params, x, y):\n", + " loss, grads = jax.value_and_grad(jax_loss_fn)(params, x, y)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " ema_params = tree.map(ema_update, ema_params, params)\n", + " return opt_state, params, ema_params, loss" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1a48a903-7a7a-4041-ae74-c5f4913ec8b6", + "metadata": {}, + "outputs": [], + "source": [ + "params = jax_params(keys)\n", + "opt_state = optimizer.init(params)\n", + "ema_params = params" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "453d0425-b748-436d-a76f-fd8d5e677bcb", + "metadata": {}, + "outputs": [], + "source": [ + "losses = []\n", + "for _ in range(50):\n", + " opt_state, params, ema_params, loss = train_step(opt_state, params, ema_params, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "9ae32560-c7a6-485c-81fb-53f4c5786104", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## EMA in Flax" + ] + }, + { + "cell_type": "markdown", + "id": "6a678a68-840d-4000-8170-8792f4c5d547", + "metadata": {}, + "source": [ + "Now, we can see how to implement an exponential moving average in Flax. The code below is almost identical to the pure jax version above, but because NNX allows for mutable operations, we no longer need to explicitly pass around the full state object. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "d4b0f436-0d6f-4dcc-9776-1ef6f994903e", + "metadata": {}, + "outputs": [], + "source": [ + "model = nnx_model(nnx.Rngs(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "2b3418d8-29c2-49ae-9422-0cb4531c5372", + "metadata": {}, + "outputs": [], + "source": [ + "nnx_optimizer = nnx.Optimizer(\n", + " model,\n", + " tx=optimizer,\n", + " wrt=nnx.Param,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "40d71ebc-d7c8-40a3-b31d-95e1ac2e6021", + "metadata": {}, + "outputs": [], + "source": [ + "class Ema(nnx.Module):\n", + " def __init__(self, model):\n", + " self.ema = nnx.merge(*nnx.split(model)) # Make a copy\n", + " def update(self, model):\n", + " self.ema = tree.map(ema_update, self.ema, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "dbe10023-bc26-4f9b-8274-53ffb5f6b0f4", + "metadata": {}, + "outputs": [], + "source": [ + "ema = Ema(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "28b540fe-a298-4999-9652-8c309ebd4547", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def nnx_train_step(model, nnx_optimizer, ema, x, y):\n", + " loss, grads = nnx.value_and_grad(nnx_loss_fn)(model, x, y)\n", + " nnx_optimizer.update(model, grads)\n", + " ema.update(model)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "6b3cb660-8982-4d66-b6e5-80e70e1138f0", + "metadata": {}, + "outputs": [], + "source": [ + "losses = []\n", + "for _ in range(50):\n", + " loss = nnx_train_step(model, nnx_optimizer, ema, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "7d3d4144-2c2b-49e3-a856-1cf61faee9f3", + "metadata": {}, + "source": [ + "# Low Rank Adaptation" + ] + }, + { + "cell_type": "markdown", + "id": "306f65b2-1e81-4313-bcb6-7b0d7f6ebe7d", + "metadata": {}, + "source": [ + "The pattern for adding low rank adaptation to an optimization loop is very similar to adding an exponential moving average. As before, we create a new pytree with the same structure as our model parameters, but here we store low rank additions to these parameters rather than weighted average values. 
" + ] + }, + { + "cell_type": "markdown", + "id": "44b08f4d-c181-4d73-973c-f6c861fbfab5", + "metadata": {}, + "source": [ + "## Lora in Jax" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "f6efaed5-ca16-45d8-815c-7184db092dca", + "metadata": {}, + "outputs": [], + "source": [ + "base_params = jax_params(keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "d93f0944-bfe1-4765-90e1-5439868dd928", + "metadata": {}, + "outputs": [], + "source": [ + "def init_lora_param(a, k=2):\n", + " if len(a.shape) == 2:\n", + " return {'A': param_init(next(keys), (a.shape[0], k)), 'B': jnp.zeros((k, a.shape[1]))}\n", + " else:\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "6a51c904-5be7-4c0d-8818-a3b8a1dd483d", + "metadata": {}, + "outputs": [], + "source": [ + "jax_lora_params = tree.map(init_lora_param, base_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "a7c19228-34fb-438f-9eb4-2b8b9630e228", + "metadata": {}, + "outputs": [], + "source": [ + "opt_state = optimizer.init(jax_lora_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "cb6d27b1-cb0e-421a-b256-93d7bf788a05", + "metadata": {}, + "outputs": [], + "source": [ + "def apply_lora_param(base_params, lora_params):\n", + " if lora_params is None:\n", + " return base_params\n", + " return base_params + (lora_params['A'] @ lora_params['B'])" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "d1a9373a-0ac0-4bba-8bab-e513b0104ba6", + "metadata": {}, + "outputs": [], + "source": [ + "def jax_lora_loss(lora_params, params, x, y):\n", + " params = tree.map(apply_lora_param, params, lora_params)\n", + " return jax_loss_fn(params, x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "2546b0ac-00a2-494b-851d-e3594408420c", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def lora_train_step(params, lora_params, opt_state, x, y):\n", + " loss, grads = jax.value_and_grad(jax_lora_loss)(lora_params, params, x, y)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + " lora_params = optax.apply_updates(lora_params, updates)\n", + " return params, lora_params, opt_state, loss" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "de6037cb-877f-4d8c-887e-fa5898234dc6", + "metadata": {}, + "outputs": [], + "source": [ + "losses = []\n", + "for _ in range(50):\n", + " params, jax_lora_params, opt_state, loss = lora_train_step(params, jax_lora_params, opt_state, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "40843e51-2914-4f28-826e-e16e4110be83", + "metadata": {}, + "source": [ + "## LORA in Flax" + ] + }, + { + "cell_type": "markdown", + "id": "b8ea3128-fd4e-46ae-b9b7-75ece8499bc3", + "metadata": {}, + "source": [ + "If Flax, we just need to wrap the optax optimizer with `nnx.Optimizer` to provide a mutable interface. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "4a618426-409c-4041-befa-d7124dcf9450", + "metadata": {}, + "outputs": [], + "source": [ + "nnx_lora_params = tree.map(init_lora_param, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "69d0779e-e9e4-4f01-a1da-a9795955efa8", + "metadata": {}, + "outputs": [], + "source": [ + "def nnx_lora_loss(lora_params, params, x, y):\n", + " params = tree.map(apply_lora_param, params, lora_params)\n", + " return nnx_loss_fn(params, x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "d2358d42-c4c4-4999-a4db-d5082fc97a71", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def nnx_lora_train_step(nnx_model, nnx_lora_params, nnx_optimizer, x, y):\n", + " loss, grads = nnx.value_and_grad(nnx_lora_loss)(nnx_lora_params, nnx_model, x, y)\n", + " nnx_optimizer.update(nnx_lora_params, grads)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "d1cebe93-ac45-45f2-a656-b0535d9e7ea9", + "metadata": {}, + "outputs": [], + "source": [ + "nnx_lora_optimizer = nnx.Optimizer(\n", + " nnx_lora_params,\n", + " tx=optimizer,\n", + " wrt=nnx.Param,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "8388053d-ed88-4ff0-8ffb-d4e99c7bdbae", + "metadata": {}, + "outputs": [], + "source": [ + "losses = []\n", + "for _ in range(50):\n", + " loss = nnx_lora_train_step(model, nnx_lora_params, nnx_lora_optimizer, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "af26eaff-ba76-4853-9153-e9b75342597d", + "metadata": {}, + "source": [ + "# LBFGS" + ] + }, + { + "cell_type": "markdown", + "id": "8f840a47-e40a-44ea-aadf-63e052a8826c", + "metadata": {}, + "source": [ + "## LBFGS in Jax" + ] + }, + { + "cell_type": "code", + "execution_count": 328, + "id": "141591c8-f8eb-4a07-aa4a-79eb4bfaaaa0", + "metadata": {}, + "outputs": [], + "source": [ + "def make_lbfgs_state(lbfgs):\n", + " params = make_params(keys)\n", + " opt_state = lbfgs.init(params)\n", + " return (params, opt_state)" + ] + }, + { + "cell_type": "code", + "execution_count": 329, + "id": "abb48c12-7ae0-48db-adf2-291451c95e8e", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def train_step(x, y, params, opt_state):\n", + " local_loss = lambda p: loss_fn(p, x, y)\n", + " value_and_grad_fn = optax.value_and_grad_from_state(local_loss)\n", + " loss, grad = value_and_grad_fn(params, state=opt_state)\n", + " updates, opt_state = lbfgs.update(grad, opt_state, params,\n", + " value=loss, grad=grad, value_fn=local_loss)\n", + " params = optax.apply_updates(params, updates)\n", + " return params, opt_state, loss" + ] + }, + { + "cell_type": "code", + "execution_count": 330, + "id": "a127078e-61e7-4f75-9962-16c860f5bed8", + "metadata": {}, + "outputs": [], + "source": [ + "lbfgs = optax.lbfgs()\n", + "params, opt_state = make_lbfgs_state(lbfgs)" + ] + }, + { + "cell_type": "code", + "execution_count": 331, + "id": "a335c4f9-4880-4498-9ecc-4254e3841376", + "metadata": {}, + "outputs": [], + "source": [ + "losses = []\n", + "for _ in range(50):\n", + " loss = train_step(x, y, params, opt_state)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "ec10b575-4afd-47e1-bcad-8a6e4635cf74", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## LBFGS in Flax" + ] + }, + { + "cell_type": "markdown", + "id": "fb826e53-5152-4505-b831-37e481cbee6f", + "metadata": {}, + "source": [ + "TODO" + 
] + }, + { + "cell_type": "markdown", + "id": "1101a839-0903-446e-b3df-3eba72069da6", + "metadata": {}, + "source": [ + "# Per-Parameter Learning Rates" + ] + }, + { + "cell_type": "markdown", + "id": "4770182a-4974-4902-bf23-3612f7b769ee", + "metadata": {}, + "source": [ + "In some training regimes, you will want to optimize different parameters with different learning rates. " + ] + }, + { + "cell_type": "markdown", + "id": "df9d1592-b66d-45ca-8108-74189b05c11f", + "metadata": {}, + "source": [ + "## In Jax" + ] + }, + { + "cell_type": "markdown", + "id": "a9b4b762-2a18-4825-93b6-aaed2238d2e6", + "metadata": {}, + "source": [ + "First, we map from each leaf to the type of parameter it is (weight or bias)." + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "83bb7b1e-0495-486c-87ca-f8de979eedbb", + "metadata": {}, + "outputs": [], + "source": [ + "params = jax_params(keys)\n", + "param_tys = jax.tree.map_with_path(lambda p, _: p[-1].key, params)" + ] + }, + { + "cell_type": "markdown", + "id": "9d812fbe-eb7a-4ea0-8408-886d70001ec6", + "metadata": {}, + "source": [ + "Next, we create a dictionary giving the optimizer (and hence the learning rate) to use for each parameter type." + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "a12ee567-7059-4352-bc8b-d4f023a19fa8", + "metadata": {}, + "outputs": [], + "source": [ + "rates = {'w': optax.adam(1e-3), 'b': optax.adam(1e-2)}" + ] + }, + { + "cell_type": "markdown", + "id": "a8e3bb0e-7911-476a-b7f7-b7c70ee01eb5", + "metadata": {}, + "source": [ + "Finally, we can make a compound optimizer that routes each parameter to the appropriate learning rate. " + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "8dbc9cd2-1770-4edb-b497-2f3f6a2ee48a", + "metadata": {}, + "outputs": [], + "source": [ + "joint_optimizer = optax.partition(rates, param_tys)\n", + "opt_state = joint_optimizer.init(params)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "fc7d4347-e230-470b-9cf6-c2a40101ea06", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def train_step(opt_state, params, x, y):\n", + " loss, grads = jax.value_and_grad(jax_loss_fn)(params, x, y)\n", + " updates, opt_state = joint_optimizer.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " return opt_state, params, loss" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "d44bd5a1-cf6a-4a00-a24b-df284b615307", + "metadata": {}, + "outputs": [], + "source": [ + "losses = []\n", + "for _ in range(50):\n", + " opt_state, params, loss = train_step(opt_state, params, x, y)\n", + " losses.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "cd477e5d-e1ef-4d44-bc8f-c32c3a4075b4", + "metadata": {}, + "source": [ + "## In Flax" + ] + }, + { + "cell_type": "markdown", + "id": "79d5938e-3160-4838-b016-1898741fd506", + "metadata": {}, + "source": [ + "TODO" + ] + }, + { + "cell_type": "markdown", + "id": "75a39c26-a3d2-4847-b2d6-08696c691183", + "metadata": {}, + "source": [ + "# Gradient Accumulation" + ] + }, + { + "cell_type": "markdown", + "id": "2764d19f-9d26-4cdd-bab9-bd422167b90f", + "metadata": {}, + "source": [ + "To accumulate gradients over several steps before applying an update, just wrap the optimizer in `optax.MultiSteps`, which accumulates gradients and only applies an update every `every_k_schedule` steps (here, every 3):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c041536-e4d5-45ed-9997-2c1227e683a4", + "metadata": {}, + "outputs": [], + "source": [ + "optax.MultiSteps(optimizer, every_k_schedule=3)" + ] + }, + { + "cell_type": "markdown", + "id": "fb8b64bf-1436-45ea-9382-323b2bf584d6", + "metadata": {}, + "source": [ + "# Sharding Optimization State Differently from Parameters" + ] +
}, + { + "cell_type": "markdown", + "id": "cb172c07-29e2-49a2-b0f9-159a2fc94cca", + "metadata": {}, + "source": [ + "## Jax Version" + ] + }, + { + "cell_type": "markdown", + "id": "6d6d1dd3-fe6c-4142-ad8e-5827f12ae00f", + "metadata": {}, + "source": [ + "Say we're doing data parallelism. We want to replicate our parameters across all GPUs so that the forward and backward passes don't need to communicate parameters. " + ] + }, + { + "cell_type": "markdown", + "id": "235995f0-26d4-4b4e-a6c3-1e94a55f655a", + "metadata": {}, + "source": [ + "But we don't need to replicate the optimizer state, as it's not involved in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimizer state to be sharded differently from the parameters themselves." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "995dcc24-7d0c-40ac-9fb5-a959f65f8f6a", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3ce6d29d-6e60-4fe6-9152-a59b765e4d7d", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2, 4), (\"x\", \"y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", + "jax.set_mesh(mesh);" + ] + }, + { + "cell_type": "markdown", + "id": "df4e9497-2a5a-4848-966d-670803fae4f8", + "metadata": {}, + "source": [ + "To do this, we can change our initializer to take a `sharding` argument. " + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c5c3cd10-4faa-4491-bbf5-a9df6440f3e8", + "metadata": {}, + "outputs": [], + "source": [ + "def make_params(keys, sharding):\n", + " return {\n", + " 'w': param_init(next(keys), (2, 8), out_sharding=sharding),\n", + " 'b': jnp.zeros(8)\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "95c72628-9047-4228-b187-f0f604c4bc7d", + "metadata": {}, + "source": [ + "We'll pass in a sharding when we initialize the optimizer, which will shard the optimizer state the same way. But when we initialize the model parameters themselves, we won't provide a sharding, allowing for data parallelism. Building an abstract version of the sharded parameters with `jax.eval_shape` lets us create the optimizer state with the desired layout without ever materializing a sharded copy of the parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "4679d22d-3cf2-4138-9239-582d83b78b4c", + "metadata": {}, + "outputs": [], + "source": [ + "opt = optimizer.init(jax.eval_shape(lambda: make_params(keys, P('x', 'y'))))" + ] + }, + { + "cell_type": "markdown", + "id": "9d08f5ca-6ec7-45f0-b747-684098254d70", + "metadata": {}, + "source": [ + "## Flax Version\n", + "\n", + "In Flax, we can build the `nnx.Optimizer` from an abstract copy of the model produced by `jax.eval_shape` (the `ghost_model` below), so the optimizer state picks up the parameter sharding without allocating another full copy of the parameters. Checking the Adam `mu` buffer for the kernel confirms that it is sharded across the mesh." + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "2eb8790e-1a4a-4154-a12d-afeef5cb2306", + "metadata": {}, + "outputs": [], + "source": [ + "def make_model(sharding):\n", + " return nnx.Linear(2, 8, rngs=nnx.Rngs(0), kernel_init=ft.partial(param_init, out_sharding=sharding))" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "7ee3bf9f-3284-4bbd-8ab9-811bccb87013", + "metadata": {}, + "outputs": [], + "source": [ + "model = make_model(P('x', 'y'))" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "74043647-2a24-4f59-8a4c-1c303f40bcde", + "metadata": {}, + "outputs": [], + "source": [ + "ghost_model = jax.eval_shape(lambda: make_model(P('x', 'y')))" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "10812c50-a18c-4b7a-ad2e-9db544d6fc27", + "metadata": {}, + "outputs": [], + "source": [ + "opt = nnx.Optimizer(ghost_model, optax.adam(1e-3), wrt=nnx.Param)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "04a95502-0331-41bd-a6ef-1c80db2aaa72", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ShapedArray(float32[2@x,8@y])" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.typeof(opt.opt_state[0].mu['kernel'][...])" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/guides/Optimization Cookbook.md b/docs_nnx/guides/Optimization Cookbook.md new file mode 100644 index 000000000..9b26fc830 --- /dev/null +++ b/docs_nnx/guides/Optimization Cookbook.md @@ -0,0 +1,302 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.13.8 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# A Flax Optimization Cookbook + + +# Exponential Moving Average + +Neural networks see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. It is easy to modify the standard Jax training loop to accommodate calculating exponential moving averages. + + +## EMA in Pure Jax + + +To start, we will see how to keep track of exponential moving averages in raw Jax. Although the raw Jax implementation is simple and easy to understand, it does not allow for mutable state, so we bundle the parameters, optimizer state, and EMA parameters into a `state` namedtuple and thread it through every training step explicitly. 
+ +```python +import jax.numpy as jnp +import jax +from jax import tree +import optax +import itertools as it +import functools as ft +from collections import namedtuple +``` + +```python +state = namedtuple('state', 'params opt_state ema_params') +``` + +```python +keys = map(ft.partial(jax.random.fold_in, jax.random.key(0)), it.count()) +``` + +```python +x = jax.random.normal(next(keys), (32, 2)) +y = jax.random.normal(next(keys), (32, 5)) +``` + +```python +param_init = jax.nn.initializers.lecun_normal() +``` + +```python +def make_params(keys): +    return { +        'w': param_init(next(keys), (2, 5)), +        'b': jnp.zeros(5) +    } +``` + +```python +optimizer = optax.adam(1e-3) +``` + +```python +def make_state(): +    params = make_params(keys) +    opt_state = optimizer.init(params) +    return state(params, opt_state, params) +``` + +```python +def ema_update(ema, new_val, decay=0.9): +    return decay * ema + (1 - decay) * new_val +``` + +```python +def model(params, x): +    return x @ params['w'] + params['b'] +``` + +```python +def loss_fn(params, x, y): +    return jnp.sum((y - model(params, x))**2) +``` + +```python +@jax.jit +def train_step(x, y, st): +    loss, grads = jax.value_and_grad(loss_fn)(st.params, x, y) +    updates, opt_state = optimizer.update(grads, st.opt_state) +    params = optax.apply_updates(st.params, updates) +    ema_params = tree.map(ema_update, st.ema_params, params) +    return state(params, opt_state, ema_params), loss +``` + +```python +st = make_state() +``` + +```python +losses = [] +for _ in range(50): +    st, loss = train_step(x, y, st) +    losses.append(loss) +``` + +## EMA in Flax + + +Now, we can see how to implement an exponential moving average in Flax. The code below is almost identical to the pure Jax version above, but because NNX allows for mutable operations, we no longer need to explicitly pass around the full state object. + +```python +from flax import nnx +``` + +```python +nnx_model = nnx.Linear(2,5, rngs=nnx.Rngs(42)) +``` + +```python +nnx_optimizer = nnx.Optimizer( +    nnx_model, +    tx=optimizer, +    wrt=nnx.Param, +) +``` + +```python +def nnx_loss_fn(model, x, y): +    return jnp.sum((model(x) - y) ** 2) +``` + +```python +class Ema(nnx.Module): +    def __init__(self, model): +        self.ema = nnx.merge(*nnx.split(model))  # Make a copy +    def update(self, model): +        self.ema = tree.map(ema_update, self.ema, model) +``` + +```python +ema = Ema(nnx_model) +``` + +```python +@nnx.jit +def nnx_train_step(nnx_model, nnx_optimizer, ema, x, y): +    loss, grads = nnx.value_and_grad(nnx_loss_fn)(nnx_model, x, y) +    nnx_optimizer.update(nnx_model, grads) +    ema.update(nnx_model) +    return loss +``` + +```python +losses = [] +for _ in range(50): +    loss = nnx_train_step(nnx_model, nnx_optimizer, ema, x, y) +    losses.append(loss) +``` + +# Low Rank Adaptation + + +The pattern for adding low rank adaptation to an optimization loop is very similar to adding an exponential moving average. As before, we create a new pytree with the same structure as our model parameters, but here we store low rank additions to these parameters rather than weighted average values. 
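+ +In the code below, every two-dimensional weight `w` gets a pair of trainable factors `A` (shape `(m, k)`) and `B` (shape `(k, n)`), so the effective weight becomes `w + A @ B`; one-dimensional leaves such as biases get no adapter. Only the factors are passed to the optimizer, so the base parameters stay frozen, and because `B` starts at zero the adapted model initially matches the base model exactly.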
+ + +## LoRA in Jax + +```python +def init_lora_param(a, k=2): +    if len(a.shape) == 2: +        return {'A': param_init(next(keys), (a.shape[0], k)), 'B': jnp.zeros((k, a.shape[1]))} +    else: +        return None +``` + +```python +params = make_params(keys) +``` + +```python +lora_params = tree.map(init_lora_param, params) +``` + +```python +opt_state = optimizer.init(lora_params) +``` + +```python +def apply_lora_param(base_params, lora_params): +    if lora_params is None: +        return base_params +    return base_params + (lora_params['A'] @ lora_params['B']) +``` + +```python +def lora_loss(lora_params, params, x, y): +    params = tree.map(apply_lora_param, params, lora_params) +    return loss_fn(params, x, y) +``` + +```python +@jax.jit +def lora_train_step(x, y, params, lora_params, opt_state): +    loss, grads = jax.value_and_grad(lora_loss)(lora_params, params, x, y) +    updates, opt_state = optimizer.update(grads, opt_state) +    lora_params = optax.apply_updates(lora_params, updates) +    return params, lora_params, opt_state, loss +``` + +```python +losses = [] +for _ in range(50): +    params, lora_params, opt_state, loss = lora_train_step(x, y, params, lora_params, opt_state) +    losses.append(loss) +``` + +## LoRA in Flax + + +In Flax, we just need to wrap the optax optimizer with `nnx.Optimizer` to provide a mutable interface. + +```python +lora_params = tree.map(init_lora_param, nnx_model) +``` + +```python +def nnx_lora_loss(lora_params, params, x, y): +    params = tree.map(apply_lora_param, params, lora_params) +    return nnx_loss_fn(params, x, y) +``` + +```python +@nnx.jit +def nnx_lora_train_step(nnx_model, nnx_lora_params, nnx_optimizer, x, y): +    loss, grads = nnx.value_and_grad(nnx_lora_loss)(nnx_lora_params, nnx_model, x, y) +    nnx_optimizer.update(nnx_lora_params, grads) +    return loss +``` + +```python +nnx_optimizer = nnx.Optimizer( +    lora_params, +    tx=optimizer, +    wrt=nnx.Param, +) +``` + +```python +losses = [] +for _ in range(50): +    loss = nnx_lora_train_step(nnx_model, lora_params, nnx_optimizer, x, y) +    losses.append(loss) +``` + +# LBFGS + + +## LBFGS in Jax + +```python +def make_lbfgs_state(lbfgs): +    params = make_params(keys) +    opt_state = lbfgs.init(params) +    return (params, opt_state) +``` + +```python +@jax.jit +def train_step(x, y, params, opt_state): +    local_loss = lambda p: loss_fn(p, x, y) +    value_and_grad_fn = optax.value_and_grad_from_state(local_loss) +    loss, grad = value_and_grad_fn(params, state=opt_state) +    updates, opt_state = lbfgs.update(grad, opt_state, params, +                                      value=loss, grad=grad, value_fn=local_loss) +    params = optax.apply_updates(params, updates) +    return params, opt_state, loss +``` + +```python +lbfgs = optax.lbfgs() +params, opt_state = make_lbfgs_state(lbfgs) +``` + +```python +losses = [] +for _ in range(50): +    params, opt_state, loss = train_step(x, y, params, opt_state) +    losses.append(loss) +``` + +## LBFGS in Flax + + +# TODO +- Per-param LR +- LBFGS (see the sketch below for one possible approach) +- Opt sharding different from variable sharding +- Gradient accumulation
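+ +As a possible starting point for the "LBFGS in Flax" item above, here is a minimal sketch (not necessarily the recipe this guide will settle on), assuming the model's only trainable state is `nnx.Param`: pull the parameters out of the NNX model with `nnx.split`, run the same optax L-BFGS step on the resulting pytree, and write the result back with `nnx.update`. + +```python +# Sketch: L-BFGS on an NNX model by optimizing its parameter pytree directly. +# nnx.split extracts the parameters, nnx.merge rebuilds a model for the loss, +# and nnx.update writes the optimized parameters back into the live model. +graphdef, lbfgs_params = nnx.split(nnx_model, nnx.Param) + +def flat_loss(p): +    return nnx_loss_fn(nnx.merge(graphdef, p), x, y) + +lbfgs = optax.lbfgs() +lbfgs_state = lbfgs.init(lbfgs_params) + +@jax.jit +def nnx_lbfgs_step(p, opt_state): +    value_and_grad_fn = optax.value_and_grad_from_state(flat_loss) +    loss, grad = value_and_grad_fn(p, state=opt_state) +    updates, opt_state = lbfgs.update(grad, opt_state, p, +                                      value=loss, grad=grad, value_fn=flat_loss) +    p = optax.apply_updates(p, updates) +    return p, opt_state, loss + +losses = [] +for _ in range(50): +    lbfgs_params, lbfgs_state, loss = nnx_lbfgs_step(lbfgs_params, lbfgs_state) +    losses.append(loss) + +# Write the optimized parameters back into the model. +nnx.update(nnx_model, lbfgs_params) +``` + +This sidesteps `nnx.Optimizer` entirely and works on the parameter pytree, which may or may not be the approach the final section should take.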