|
43 | 43 | "from jax import numpy as jnp\n", |
44 | 44 | "from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n", |
45 | 45 | "import optax\n", |
| 46 | + "import flax\n", |
46 | 47 | "from flax import nnx\n", |
47 | 48 | "\n", |
48 | 49 | "# Ignore this if you are already running on a TPU or GPU\n", |
|
56 | 57 | "cell_type": "markdown", |
57 | 58 | "metadata": {}, |
58 | 59 | "source": [ |
59 | | - "Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs. \n", |
| 60 | + "Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs.\n", |
60 | 61 | "\n", |
61 | 62 | "In this guide we use a standard FSDP layout and shard our devices on two axes - `data` and `model`, for doing batch data parallelism and tensor parallelism." |
62 | 63 | ] |
|
75 | 76 | "cell_type": "markdown", |
76 | 77 | "metadata": {}, |
77 | 78 | "source": [ |
78 | | - "> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Check the flag and read on to learn how to use the feature." |
| 79 | + "> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function." |
79 | 80 | ] |
80 | 81 | }, |
81 | 82 | { |
|
84 | 85 | "metadata": {}, |
85 | 86 | "outputs": [], |
86 | 87 | "source": [ |
87 | | - "import flax\n", |
88 | | - "assert flax.config.flax_always_shard_variable is True" |
| 88 | + "nnx.use_eager_sharding(True)\n", |
| 89 | + "assert nnx.using_eager_sharding()" |
| 90 | + ] |
| 91 | + }, |
| 92 | + { |
| 93 | + "cell_type": "markdown", |
| 94 | + "id": "c24144d8", |
| 95 | + "metadata": {}, |
| 96 | + "source": [ |
| 97 | + "The `nnx.use_eager_sharding` function can also be used as a context manager to toggle the eager sharding feature within a specific scope." |
| 98 | + ] |
| 99 | + }, |
| 100 | + { |
| 101 | + "cell_type": "code", |
| 102 | + "execution_count": null, |
| 103 | + "id": "2d849e2e", |
| 104 | + "metadata": {}, |
| 105 | + "outputs": [], |
| 106 | + "source": [ |
| 107 | + "with nnx.use_eager_sharding(False):\n", |
| 108 | + " assert not nnx.using_eager_sharding()" |
| 109 | + ] |
| 110 | + }, |
| 111 | + { |
| 112 | + "cell_type": "markdown", |
| 113 | + "id": "c9f808ec", |
| 114 | + "metadata": {}, |
| 115 | + "source": [ |
| 116 | + "You can also enable eager sharding on a per-variable basis by passing `eager_sharding=False` during variable initialization. The mesh can also be passed this way." |
| 117 | + ] |
| 118 | + }, |
| 119 | + { |
| 120 | + "cell_type": "code", |
| 121 | + "execution_count": null, |
| 122 | + "id": "67bbd440", |
| 123 | + "metadata": {}, |
| 124 | + "outputs": [], |
| 125 | + "source": [ |
| 126 | + "nnx.Param(jnp.ones(4,4), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)" |
89 | 127 | ] |
90 | 128 | }, |
91 | 129 | { |
|
256 | 294 | "with jax.set_mesh(auto_mesh):\n", |
257 | 295 | " # Create your input data, sharded along `data` dimension, as in data parallelism\n", |
258 | 296 | " x = jax.device_put(jnp.ones((16, 4)), P('data', None))\n", |
259 | | - " \n", |
| 297 | + "\n", |
260 | 298 | " # Run the model forward function, jitted\n", |
261 | 299 | " y = jax.jit(lambda m, x: m(x))(linear, x)\n", |
262 | 300 | " print(y.sharding.spec) # sharded: ('data', 'model')\n", |
|
313 | 351 | " def create_sublayers(r):\n", |
314 | 352 | " return DotReluDot(depth, r)\n", |
315 | 353 | " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", |
316 | | - " \n", |
| 354 | + "\n", |
317 | 355 | " def __call__(self, x):\n", |
318 | 356 | " def scan_over_layers(x, layer):\n", |
319 | 357 | " return layer(x), None\n", |
|
364 | 402 | " # Model and optimizer\n", |
365 | 403 | " model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))\n", |
366 | 404 | " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", |
367 | | - " \n", |
| 405 | + "\n", |
368 | 406 | " # The loop\n", |
369 | 407 | " for i in range(5):\n", |
370 | 408 | " model, loss = train_step(model, optimizer, input, label)\n", |
|
496 | 534 | " def create_sublayers(r):\n", |
497 | 535 | " return LogicalDotReluDot(depth, r)\n", |
498 | 536 | " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", |
499 | | - " \n", |
| 537 | + "\n", |
500 | 538 | " def __call__(self, x):\n", |
501 | 539 | " def scan_over_layers(x, layer):\n", |
502 | 540 | " return layer(x), None\n", |
|
617 | 655 | " def create_sublayers(r):\n", |
618 | 656 | " return ExplicitDotReluDot(depth, r)\n", |
619 | 657 | " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", |
620 | | - " \n", |
| 658 | + "\n", |
621 | 659 | " def __call__(self, x):\n", |
622 | 660 | " def scan_over_layers(x, layer):\n", |
623 | 661 | " return layer(x), None\n", |
|