
Commit 7710c30

Author: Flax Authors
Merge pull request #5079 from samanklesaria:eager_sharding_context
PiperOrigin-RevId: 831520025
2 parents b08cb20 + 719f687 commit 7710c30

5 files changed: +189 -22 lines changed


docs_nnx/guides/flax_gspmd.ipynb

Lines changed: 47 additions & 9 deletions

@@ -43,6 +43,7 @@
 "from jax import numpy as jnp\n",
 "from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n",
 "import optax\n",
+"import flax\n",
 "from flax import nnx\n",
 "\n",
 "# Ignore this if you are already running on a TPU or GPU\n",
@@ -56,7 +57,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs. \n",
+"Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs.\n",
 "\n",
 "In this guide we use a standard FSDP layout and shard our devices on two axes - `data` and `model`, for doing batch data parallelism and tensor parallelism."
 ]
@@ -75,7 +76,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Check the flag and read on to learn how to use the feature."
+"> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded models. If your project already used the Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function."
 ]
 },
 {
@@ -84,8 +85,45 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import flax\n",
-"assert flax.config.flax_always_shard_variable is True"
+"nnx.use_eager_sharding(True)\n",
+"assert nnx.using_eager_sharding()"
+]
+},
+{
+"cell_type": "markdown",
+"id": "c24144d8",
+"metadata": {},
+"source": [
+"The `nnx.use_eager_sharding` function can also be used as a context manager to toggle the eager sharding feature within a specific scope."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "2d849e2e",
+"metadata": {},
+"outputs": [],
+"source": [
+"with nnx.use_eager_sharding(False):\n",
+"  assert not nnx.using_eager_sharding()"
+]
+},
+{
+"cell_type": "markdown",
+"id": "c9f808ec",
+"metadata": {},
+"source": [
+"You can also toggle eager sharding on a per-variable basis by passing the `eager_sharding` argument during variable initialization. The mesh can also be passed this way."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "67bbd440",
+"metadata": {},
+"outputs": [],
+"source": [
+"nnx.Param(jnp.ones((4, 4)), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)"
 ]
 },
 {
@@ -256,7 +294,7 @@
 "with jax.set_mesh(auto_mesh):\n",
 "  # Create your input data, sharded along `data` dimension, as in data parallelism\n",
 "  x = jax.device_put(jnp.ones((16, 4)), P('data', None))\n",
-"  \n",
+"\n",
 "  # Run the model forward function, jitted\n",
 "  y = jax.jit(lambda m, x: m(x))(linear, x)\n",
 "  print(y.sharding.spec) # sharded: ('data', 'model')\n",
@@ -313,7 +351,7 @@
 "    def create_sublayers(r):\n",
 "      return DotReluDot(depth, r)\n",
 "    self.layers = create_sublayers(rngs.fork(split=num_layers))\n",
-"    \n",
+"\n",
 "  def __call__(self, x):\n",
 "    def scan_over_layers(x, layer):\n",
 "      return layer(x), None\n",
@@ -364,7 +402,7 @@
 "  # Model and optimizer\n",
 "  model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))\n",
 "  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n",
-"  \n",
+"\n",
 "  # The loop\n",
 "  for i in range(5):\n",
 "    model, loss = train_step(model, optimizer, input, label)\n",
@@ -496,7 +534,7 @@
 "    def create_sublayers(r):\n",
 "      return LogicalDotReluDot(depth, r)\n",
 "    self.layers = create_sublayers(rngs.fork(split=num_layers))\n",
-"    \n",
+"\n",
 "  def __call__(self, x):\n",
 "    def scan_over_layers(x, layer):\n",
 "      return layer(x), None\n",
@@ -617,7 +655,7 @@
 "    def create_sublayers(r):\n",
 "      return ExplicitDotReluDot(depth, r)\n",
 "    self.layers = create_sublayers(rngs.fork(split=num_layers))\n",
-"    \n",
+"\n",
 "  def __call__(self, x):\n",
 "    def scan_over_layers(x, layer):\n",
 "      return layer(x), None\n",

docs_nnx/guides/flax_gspmd.md

Lines changed: 23 additions & 9 deletions

@@ -29,6 +29,7 @@ import jax
 from jax import numpy as jnp
 from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
 import optax
+import flax
 from flax import nnx
 
 # Ignore this if you are already running on a TPU or GPU
@@ -37,7 +38,7 @@ if not jax._src.xla_bridge.backends_are_initialized():
 print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
 ```
 
-Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs. 
+Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs.
 
 In this guide we use a standard FSDP layout and shard our devices on two axes - `data` and `model`, for doing batch data parallelism and tensor parallelism.
 
@@ -46,11 +47,24 @@ In this guide we use a standard FSDP layout and shard our devices on two axes -
 auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))
 ```
 
-> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Check the flag and read on to learn how to use the feature.
+> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded models. If your project already used the Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function.
 
 ```{code-cell} ipython3
-import flax
-assert flax.config.flax_always_shard_variable is True
+nnx.use_eager_sharding(True)
+assert nnx.using_eager_sharding()
+```
+
+The `nnx.use_eager_sharding` function can also be used as a context manager to toggle the eager sharding feature within a specific scope.
+
+```{code-cell} ipython3
+with nnx.use_eager_sharding(False):
+  assert not nnx.using_eager_sharding()
+```
+
+You can also toggle eager sharding on a per-variable basis by passing the `eager_sharding` argument during variable initialization. The mesh can also be passed this way.
+
+```{code-cell} ipython3
+nnx.Param(jnp.ones((4, 4)), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)
 ```
 
 ## Shard a single-array model
@@ -107,7 +121,7 @@ You should still make sure to `jax.jit` for maximum performance, and also to exp
 with jax.set_mesh(auto_mesh):
   # Create your input data, sharded along `data` dimension, as in data parallelism
   x = jax.device_put(jnp.ones((16, 4)), P('data', None))
-  
+
   # Run the model forward function, jitted
   y = jax.jit(lambda m, x: m(x))(linear, x)
   print(y.sharding.spec) # sharded: ('data', 'model')
@@ -153,7 +167,7 @@ class MultiDotReluDot(nnx.Module):
     def create_sublayers(r):
       return DotReluDot(depth, r)
     self.layers = create_sublayers(rngs.fork(split=num_layers))
-    
+
   def __call__(self, x):
     def scan_over_layers(x, layer):
       return layer(x), None
@@ -182,7 +196,7 @@ with jax.set_mesh(auto_mesh):
   # Model and optimizer
   model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))
   optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
-  
+
   # The loop
   for i in range(5):
     model, loss = train_step(model, optimizer, input, label)
@@ -266,7 +280,7 @@ class LogicalMultiDotReluDot(nnx.Module):
     def create_sublayers(r):
       return LogicalDotReluDot(depth, r)
    self.layers = create_sublayers(rngs.fork(split=num_layers))
-    
+
   def __call__(self, x):
     def scan_over_layers(x, layer):
       return layer(x), None
@@ -354,7 +368,7 @@ class ExplicitMultiDotReluDot(nnx.Module):
     def create_sublayers(r):
       return ExplicitDotReluDot(depth, r)
     self.layers = create_sublayers(rngs.fork(split=num_layers))
-    
+
  def __call__(self, x):
    def scan_over_layers(x, layer):
      return layer(x), None
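
The removed cell above asserted directly on `flax.config.flax_always_shard_variable`; after this change that flag only supplies the fallback default that `nnx.using_eager_sharding()` reports when no `nnx.use_eager_sharding` context is active (see `flax/nnx/variablelib.py` below). A minimal sketch of that relationship, assuming Flax >= 0.12 with this commit applied and a fresh session in which `nnx.use_eager_sharding` has not yet been called; this snippet is not part of the commit:

    import flax
    from flax import nnx

    # With no use_eager_sharding(...) call or context active, the eager-sharding
    # default comes from the legacy config flag.
    assert nnx.using_eager_sharding() == flax.config.flax_always_shard_variable

    # Calling nnx.use_eager_sharding(...) overrides that default from this point on.
    nnx.use_eager_sharding(False)
    assert nnx.using_eager_sharding() is False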

flax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -203,6 +203,8 @@
 from .variablelib import register_variable_name as register_variable_name
 from .variablelib import use_refs as use_refs
 from .variablelib import using_refs as using_refs
+from .variablelib import use_eager_sharding as use_eager_sharding
+from .variablelib import using_eager_sharding as using_eager_sharding
 from .visualization import display as display
 from .extract import to_tree as to_tree
 from .extract import from_tree as from_tree

flax/nnx/variablelib.py

Lines changed: 94 additions & 4 deletions

@@ -64,10 +64,103 @@
 @dataclasses.dataclass
 class VariableContext(threading.local):
   mutable_variable_stack: list[bool] = dataclasses.field(default_factory=list)
+  eager_shard_stack: list[bool] = dataclasses.field(default_factory=list)
+
 
 
 VARIABLE_CONTEXT = VariableContext()
 
+class UseEagerShardContext:
+  def __init__(self, prev_value: bool | None, new_value: bool):
+    self.prev_value: bool | None = prev_value
+    self.new_value: bool = new_value
+
+  def __enter__(self):
+    if self.prev_value is not None:
+      VARIABLE_CONTEXT.eager_shard_stack.insert(-1, self.prev_value)
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    VARIABLE_CONTEXT.eager_shard_stack.pop()
+
+  def __call__(self, f: F) -> F:
+    # undo eager stack change
+    VARIABLE_CONTEXT.eager_shard_stack.pop()
+    if self.prev_value is not None:
+      VARIABLE_CONTEXT.eager_shard_stack.append(self.prev_value)
+
+    @functools.wraps(f)
+    def use_eager_sharding_wrapper(*args, **kwargs):
+      VARIABLE_CONTEXT.eager_shard_stack.append(self.new_value)
+      try:
+        return f(*args, **kwargs)
+      finally:
+        VARIABLE_CONTEXT.eager_shard_stack.pop()
+
+    return use_eager_sharding_wrapper  # type: ignore[return-value]
+
+def using_eager_sharding() -> bool:
+  """Returns whether Variables are using eager sharding by default.
+
+  Example::
+
+    >>> from flax import nnx
+    >>> nnx.use_eager_sharding(True)
+    <...>
+    >>> nnx.using_eager_sharding()
+    True
+    >>> nnx.use_eager_sharding(False)
+    <...>
+    >>> nnx.using_eager_sharding()
+    False
+
+
+  Returns:
+    A boolean indicating if Variables are using eager sharding by default.
+  """
+  do_eager_sharding = config.flax_always_shard_variable
+  if VARIABLE_CONTEXT.eager_shard_stack:
+    do_eager_sharding = VARIABLE_CONTEXT.eager_shard_stack[-1]
+  return do_eager_sharding
+
+def use_eager_sharding(value: bool, /):
+  """Sets whether Variables should use eager sharding by default or not.
+
+  Example usage::
+
+    >>> from flax import nnx
+    >>> # Use eager sharding by default
+    >>> nnx.use_eager_sharding(True)
+    <...>
+    >>> # Variable will now use eager sharding
+    >>> nnx.using_eager_sharding()
+    True
+
+  It can also be used as a context manager to temporarily
+  change the default behavior for a block of code::
+
+    >>> nnx.use_eager_sharding(False)
+    <...>
+    >>> with nnx.use_eager_sharding(True):
+    ...   nnx.using_eager_sharding()
+    True
+    >>> # it will reset outside
+    >>> v = nnx.Variable(jax.numpy.ones((2, 3)))
+    >>> nnx.using_eager_sharding()
+    False
+
+  Args:
+    value: A boolean indicating if Variables should use eager sharding by default.
+
+  Returns:
+    A context manager that resets the context to the previous value.
+  """
+  if VARIABLE_CONTEXT.eager_shard_stack:
+    prev_value = VARIABLE_CONTEXT.eager_shard_stack[-1]
+    VARIABLE_CONTEXT.eager_shard_stack[-1] = value
+  else:
+    prev_value = None
+  VARIABLE_CONTEXT.eager_shard_stack.append(value)
+  return UseEagerShardContext(prev_value, value)
 
 def using_refs() -> bool:
   """Returns whether Variables are using ArrayRefs by default.
@@ -289,10 +382,7 @@ def __init__(
     value = self.create_value(self.raw_value)
 
     # shard the value if applicable
-    do_eager_sharding = config.flax_always_shard_variable
-    if 'eager_sharding' in metadata:
-      do_eager_sharding = metadata['eager_sharding']
-    if do_eager_sharding and 'sharding_names' in metadata:
+    if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_names' in metadata:
       value = core_spmd.shard_value(
         value, metadata['sharding_names'], metadata.get('sharding_rules', None),
         metadata.get('mesh', None))
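
The `metadata.get('eager_sharding', using_eager_sharding())` line above fixes the precedence: an explicit per-variable `eager_sharding` argument wins over the innermost `use_eager_sharding` context, which in turn wins over the `flax_always_shard_variable` config flag. A hedged sketch of that precedence, assuming Flax >= 0.12 with this commit applied and at least 8 devices (for example via the fake-devices `XLA_FLAGS` trick from the guide); this snippet is not part of the commit:

    import jax
    from jax import numpy as jnp
    from flax import nnx

    mesh = jax.make_mesh((2, 4), ('data', 'model'))

    with jax.set_mesh(mesh):
      with nnx.use_eager_sharding(False):
        # Context default is "no eager sharding": `a` is created without a sharding constraint.
        a = nnx.Param(jnp.ones((4, 8)), sharding_names=(None, 'model'))

        # The per-variable argument overrides the surrounding context, so `b` is
        # placed according to (None, 'model') at creation time.
        b = nnx.Param(
          jnp.ones((4, 8)),
          sharding_names=(None, 'model'),
          eager_sharding=True,  # wins over use_eager_sharding(False)
          mesh=mesh,            # the mesh can also be passed explicitly
        )

Because `UseEagerShardContext` also implements `__call__`, the object returned by `nnx.use_eager_sharding` can additionally wrap a function, mirroring `nnx.use_refs`.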

tests/nnx/spmd_test.py

Lines changed: 23 additions & 0 deletions

@@ -192,6 +192,21 @@ def __call__(self, x: jax.Array):
     self.assertEqual(badds, [(0, 'layers'), (0, 'layers')])
     self.assertEqual(bremoves, [(0, 'layers')])
 
+
+  @parameterized.product(use_eager_sharding=[True, False])
+  def test_eager_sharding_context(self, use_eager_sharding):
+    rngs = nnx.Rngs(0)
+    with nnx.use_eager_sharding(use_eager_sharding):
+      mesh = jax.make_mesh(((2, 2)), ("data", "model"))
+      with jax.set_mesh(mesh):
+        w = nnx.Param(
+          rngs.lecun_normal()((4, 8)),
+          sharding_names=(None, 'model'))
+      if use_eager_sharding:
+        assert has_sharding_spec(w)
+      else:
+        assert not has_sharding_spec(w)
+
   @parameterized.product(use_ref=[True, False])
   def test_logical_rules(self, use_ref):
     self.enter_context(nnx.use_refs(use_ref))
@@ -302,6 +317,14 @@ def test_explicit_sharding_mesh_context(self):
       P('row', 'col'),
     )
 
+def has_sharding_spec(array):
+  sharding = array.sharding
+  if hasattr(sharding, 'spec'):
+    # For NamedSharding or PositionalSharding
+    return sharding.spec is not None and any(
+      s is not None for s in sharding.spec
+    )
+  return False
 
 if __name__ == '__main__':
   absltest.main()
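
For reference, a rough interactive equivalent of what `test_eager_sharding_context` asserts in the eager case, assuming this commit plus at least 4 devices; the expected output is an assumption inferred from the test, not captured from a run:

    import jax
    from flax import nnx

    rngs = nnx.Rngs(0)
    with nnx.use_eager_sharding(True):
      mesh = jax.make_mesh((2, 2), ('data', 'model'))
      with jax.set_mesh(mesh):
        w = nnx.Param(rngs.lecun_normal()((4, 8)), sharding_names=(None, 'model'))

    # With eager sharding on, the param already carries a non-trivial spec,
    # which is exactly what has_sharding_spec(w) checks.
    print(w.sharding.spec)  # expected: PartitionSpec(None, 'model')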
