
Commit 1cc5511

Rename sharding_names to sharding_metadata

1 parent 7710c30 · commit 1cc5511

File tree

11 files changed: +59 -45 lines changed
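At a glance, this commit renames the `sharding_names` keyword on `nnx.Variable`/`nnx.Param` (and the matching internal metadata key) to `sharding_metadata` across Flax, keeping a deprecation shim for the old spelling. A minimal sketch of the user-facing change, assuming a single-device mesh so it runs anywhere; the axis names are illustrative, not from the diff:

import jax
import jax.numpy as jnp
from flax import nnx

# single-device mesh so the sketch runs without multiple accelerators
mesh = jax.make_mesh((1, 1), ('embed', 'mlp'))
with jax.set_mesh(mesh):
  # new spelling introduced by this commit
  w = nnx.Param(jnp.zeros((16, 32)), sharding_metadata=('embed', 'mlp'))
  # old spelling still accepted, but now emits a DeprecationWarning
  w_old = nnx.Param(jnp.zeros((16, 32)), sharding_names=('embed', 'mlp'))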

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 5 additions & 3 deletions

@@ -14,6 +14,8 @@
 
 import dataclasses
 import os
+
+from jax._src import sharding
 os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
 
 from matplotlib import pyplot as plt
@@ -56,15 +58,15 @@ class MLP(nnx.Module):
   def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
     self.w1 = nnx.Param(
       nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
-      sharding_names=mesh_rules('embed', 'mlp'),
+      sharding_metadata=mesh_rules('embed', 'mlp'),
     )
     self.b1 = nnx.Param(
       jnp.zeros((dmid,)),
-      sharding_names=mesh_rules('mlp'),
+      sharding_metadata=mesh_rules('mlp'),
     )
     self.w2 = nnx.Param(
       nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
-      sharding_names=mesh_rules('embed', 'mlp'),
+      sharding_metadata=mesh_rules('embed', 'mlp'),
     )
 
   def __call__(self, x: jax.Array):
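`mesh_rules` is a helper defined earlier in this example file, outside the diff; it translates the example's logical axis names into mesh axis names. A rough sketch of that shape — the dataclass fields and the concrete axis values here are assumptions, not taken from the commit:

import dataclasses

@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
  embed: str | None = None
  mlp: str | None = None
  data: str | None = None

  def __call__(self, *keys: str) -> tuple[str | None, ...]:
    # 'embed' / 'mlp' / 'data' are logical names; the values are mesh axes.
    return tuple(getattr(self, key) for key in keys)

mesh_rules = MeshRules(embed='data', mlp='model', data='data')
assert mesh_rules('embed', 'mlp') == ('data', 'model')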

flax/core/meta.py

Lines changed: 2 additions & 2 deletions

@@ -300,13 +300,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
   def to_nnx_metadata(self) -> dict[str, Any]:
     """Return a dict of metadata that can translate into an `nnx.Variable`."""
     metadata = dict(vars(self))
-    metadata['sharding_names'] = metadata.pop('names')
+    metadata['sharding_metadata'] = metadata.pop('names')
     return metadata
 
   @classmethod
   def from_nnx_metadata(cls, metadata: dict[str, Any]):
     """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
-    metadata['names'] = metadata.pop('sharding_names')
+    metadata['names'] = metadata.pop('sharding_metadata')
     fields = {x.name for x in dataclasses.fields(cls)}
     return cls(**{k: v for k, v in metadata.items() if k in fields})
 
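These two methods translate between `nn.Partitioned`'s `names` field and the NNX metadata key, which this commit renames. A sketch of the round trip under the new naming — the axis names and the keyword construction are illustrative:

from flax import linen as nn

p = nn.Partitioned(value=None, names=('in', 'out'))
meta = p.to_nnx_metadata()
assert meta['sharding_metadata'] == ('in', 'out')
p2 = nn.Partitioned.from_nnx_metadata(meta)
assert p2.names == ('in', 'out')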

flax/core/spmd.py

Lines changed: 5 additions & 1 deletion

@@ -45,6 +45,8 @@ def shard_value(value, sharding_names, sharding_rules, mesh):
       f' with annotation {sharding_names=}. '
       'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
   pspec = get_pspec(sharding_names, sharding_rules)
+  if isinstance(sharding_names, NamedSharding) and mesh is not None:
+    assert sharding_names.mesh == mesh
   if mesh is not None:
     return _apply_sharding(value, NamedSharding(mesh, pspec))
   return _apply_sharding(value, pspec)
@@ -107,8 +109,10 @@ def composite_rules(rule1, rule2):
 
 
 def from_sharding_rules(
-  sharding: Sharding, sharding_rules: LogicalRules
+  sharding, sharding_rules: LogicalRules
 ) -> Sharding:
+  if isinstance(sharding, NamedSharding):
+    sharding = sharding.spec
   rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
   return tuple(
     rules[str(s)] if (s and str(s) in rules) else s for s in sharding
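`from_sharding_rules` substitutes logical axis aliases with mesh axis names via the rules table; the new lines simply let a `NamedSharding` pass through by unwrapping it to its `PartitionSpec` first. A standalone sketch of the substitution itself, with made-up rules:

sharding_rules = (('embed', 'data'), ('mlp', 'model'))
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
sharding = ('embed', 'mlp', None)
resolved = tuple(
  rules[str(s)] if (s and str(s) in rules) else s for s in sharding
)
assert resolved == ('data', 'model', None)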

flax/linen/spmd.py

Lines changed: 2 additions & 2 deletions

@@ -290,15 +290,15 @@ def to_nnx_metadata(self) -> dict[str, Any]:
     """Return a dict of metadata that can translate into an `nnx.Variable`."""
     metadata = vars(self)
     if 'names' in metadata:
-      metadata['sharding_names'] = metadata.pop('names')
+      metadata['sharding_metadata'] = metadata.pop('names')
     if 'rules' in metadata:
       metadata['sharding_rules'] = metadata.pop('rules')
     return metadata
 
   @classmethod
   def from_nnx_metadata(cls, metadata: dict[str, Any]):
     """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
-    metadata['names'] = metadata.pop('sharding_names')
+    metadata['names'] = metadata.pop('sharding_metadata')
     metadata['rules'] = metadata.pop('sharding_rules')
     fields = {x.name for x in dataclasses.fields(cls)}
     return cls(**{k: v for k, v in metadata.items() if k in fields})
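`nn.LogicallyPartitioned` carries `rules` in addition to `names`, so its translation maps two keys. A sketch under the new naming — the field values and keyword construction are illustrative:

from flax import linen as nn

lp = nn.LogicallyPartitioned(
  value=None, names=('out-alias',), rules=(('out-alias', 'out'),))
meta = lp.to_nnx_metadata()
assert meta['sharding_metadata'] == ('out-alias',)
assert meta['sharding_rules'] == (('out-alias', 'out'),)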

flax/nnx/spmd.py

Lines changed: 9 additions & 9 deletions

@@ -45,9 +45,9 @@ def insert_field(fields, index, value):
 def _add_axis(x: tp.Any):
   if isinstance(x, variablelib.Variable):
     metadata = x.get_metadata()
-    if 'sharding_names' in metadata and metadata['sharding_names']:
-      sharding = metadata['sharding_names']
-      x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
+    if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
+      sharding = metadata['sharding_metadata']
+      x.set_metadata(sharding_metadata=insert_field(sharding, index, axis_name))
 
     for k, v in other_meta.items():
       if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
@@ -74,9 +74,9 @@ def remove_field(fields, index, value):
 
 def _remove_axis(x: tp.Any):
   if isinstance(x, variablelib.Variable):
-    if hasattr(x, 'sharding_names') and x.sharding_names is not None:
+    if hasattr(x, 'sharding_metadata') and x.sharding_metadata is not None:
       x.set_metadata(
-        sharding_names=remove_field(x.sharding_names, index, axis_name)
+        sharding_metadata=remove_field(x.sharding_metadata, index, axis_name)
       )
 
     for k, v in other_meta.items():
@@ -119,7 +119,7 @@ def with_partitioning(
   """A wrapper over any initializer to add sharding annotation data to a `Variable`."""
   return variablelib.with_metadata(
     initializer,
-    sharding_names=sharding,
+    sharding_metadata=sharding,
     mesh=mesh,
     **metadata,
   )
@@ -128,8 +128,8 @@
 def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
   """Given an `nnx.Variable`, return its `PartitionSpec`."""
   metadata = v.get_metadata()
-  if 'sharding_names' in metadata and metadata['sharding_names']:
-    sharding = metadata['sharding_names']
+  if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
+    sharding = metadata['sharding_metadata']
     if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
       context_rules = core_spmd.get_logical_axis_rules()
       local_rules = metadata.get('sharding_rules', ())
@@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh):
     lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
     abs_state, get_named_sharding(abs_state, mesh)
   )
-  return gdef, abs_state
+  return gdef, abs_state
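`nnx.with_partitioning` is the usual way these annotations enter a model, and after this commit it stores them under `sharding_metadata`. A sketch of typical usage, assuming a single-device mesh so it runs anywhere; the module shape and axis names are illustrative:

import jax
from flax import nnx

mesh = jax.make_mesh((1, 1), ('din', 'dout'))
with jax.set_mesh(mesh):
  linear = nnx.Linear(
    4, 8,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), ('din', 'dout')),
    rngs=nnx.Rngs(0),
  )
assert linear.kernel.sharding_metadata == ('din', 'dout')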

flax/nnx/variablelib.py

Lines changed: 11 additions & 3 deletions

@@ -21,6 +21,7 @@
 import threading
 import typing as tp
 from typing import Any
+import warnings
 from flax import config
 
 import jax
@@ -375,16 +376,20 @@ def __init__(
       metadata['on_remove_axis'] = var_t.on_remove_axis
 
     if 'sharding' in metadata:
-      metadata['sharding_names'] = metadata.pop('sharding')
+      metadata['sharding_metadata'] = metadata.pop('sharding')
+
+    if 'sharding_names' in metadata:  # for bw compat
+      warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
+      metadata['sharding_metadata'] = metadata.pop('sharding_names')
 
     object.__setattr__(self, '_var_metadata', metadata)
     # run create_value hooks
     value = self.create_value(self.raw_value)
 
     # shard the value if applicable
-    if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_names' in metadata:
+    if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_metadata' in metadata:
       value = core_spmd.shard_value(
-        value, metadata['sharding_names'], metadata.get('sharding_rules', None),
+        value, metadata['sharding_metadata'], metadata.get('sharding_rules', None),
         metadata.get('mesh', None))
 
     # Create the ref out of the array value
@@ -394,6 +399,9 @@ def __init__(
     object.__setattr__(self, 'raw_value', value)
 
   def __getattr__(self, name: str) -> tp.Any:
+    if name == 'sharding_names':  # for backward compatibility
+      warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
+      return self.sharding_metadata
     if name in object.__getattribute__(self, '_var_metadata'):
       return self._var_metadata[name]
     return getattr(self.raw_value, name)
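The `__getattr__` shim keeps old attribute reads working: `v.sharding_names` now warns and forwards to `sharding_metadata`. A quick sketch of the expected behavior; the mesh setup is illustrative and mirrors the tests below:

import warnings
import jax
import jax.numpy as jnp
from flax import nnx

mesh = jax.make_mesh((1,), ('model',))
with jax.set_mesh(mesh):
  v = nnx.Variable(jnp.ones((4,)), sharding_metadata=('model',))
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  names = v.sharding_names  # deprecated spelling, still resolves
assert names == ('model',)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)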

tests/nnx/bridge/wrappers_test.py

Lines changed: 5 additions & 5 deletions

@@ -174,15 +174,15 @@ def create_sharded_nnx_module(x):
     self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned)
     self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned)
     self.assertIsInstance(nnx_model.kernel, nnx.Variable)
-    assert nnx_model.kernel.sharding_names == ('in', 'out')
+    assert nnx_model.kernel.sharding_metadata == ('in', 'out')
     assert nnx_model.kernel[...].sharding.is_equivalent_to(
       jax.sharding.NamedSharding(
         self.mesh, jax.sharding.PartitionSpec('in', 'out')
       ),
       ndim=2,
     ), f'{nnx_model.kernel[...].sharding = }'
 
-    assert nnx_model.bias.sharding_names == ('out-alias',)
+    assert nnx_model.bias.sharding_metadata == ('out-alias',)
     assert nnx_model.bias.sharding_rules == (('out-alias', 'out'),)
     assert nnx_model.bias[...].sharding.is_equivalent_to(
       jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out')),
@@ -410,7 +410,7 @@ def test_nnx_to_linen_metadata(self):
     pspec_tree = nn.get_partition_spec(variables)
     assert y.shape == (1, 64)
     self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta)
-    assert variables['params']['kernel'].metadata['sharding_names'] == ('in', 'out')
+    assert variables['params']['kernel'].metadata['sharding_metadata'] == ('in', 'out')
     self.assertEqual(pspec_tree['params']['kernel'],
                      jax.sharding.PartitionSpec('in', 'out'))
     np.testing.assert_allclose(y, x @ variables['params']['kernel'].value)
@@ -519,8 +519,8 @@ def __call__(self, x):
     w, b = model.inner.dot['w'], model.inner.b
     np.testing.assert_allclose(model(x), x @ w + b)
     self.assertIsInstance(w, nnx.Param)
-    assert hasattr(w, 'sharding_names') and w.sharding_names == ('in', 'out')
-    assert hasattr(b, 'sharding_names') and b.sharding_names == ('out-alias', )
+    assert hasattr(w, 'sharding_metadata') and w.sharding_metadata == ('in', 'out')
+    assert hasattr(b, 'sharding_metadata') and b.sharding_metadata == ('out-alias', )
 
   def test_linen_nnx_linen(self):
     # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without

tests/nnx/nn/linear_test.py

Lines changed: 3 additions & 3 deletions

@@ -393,7 +393,7 @@ def test(self, module_args_kwargs_initargs):
     kwargs = {"rngs": nnx.Rngs(0)}
     sharding_names = ("din", "dout")
     metadata_kwargs = {
-      f"{key}_metadata": {"sharding_names": sharding_names[:le]}
+      f"{key}_metadata": {"sharding_metadata": sharding_names[:le]}
       for key, le, _ in metadata_argnames
     }
 
@@ -410,8 +410,8 @@ def test(self, module_args_kwargs_initargs):
     for attr_name, param_name in attrs:
       attr = getattr(module, attr_name) if attr_name is not None else module
       param = getattr(attr, param_name)
-      self.assertIsNotNone(param.sharding_names)
-      self.assertEqual(param.sharding_names, sharding_names[:le])
+      self.assertIsNotNone(param.sharding_metadata)
+      self.assertEqual(param.sharding_metadata, sharding_names[:le])
 
 
 if __name__ == '__main__':

tests/nnx/optimizer_test.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ def test_sharding_propagation(self):
     state = nnx.state(optimizer)
     partition_spec = nnx.get_partition_spec(state)
 
-    self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_names, ('a', 'b'))
+    self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_metadata, ('a', 'b'))
     self.assertEqual(
       partition_spec['opt_state'][0]['mu']['kernel'].value,
       jax.sharding.PartitionSpec('a', 'b'),

tests/nnx/spmd_test.py

Lines changed: 7 additions & 7 deletions

@@ -139,7 +139,7 @@ def __init__(self, rngs: nnx.Rngs):
       4,
       kernel_init=nnx.with_metadata(
         nnx.initializers.lecun_normal(),
-        sharding_names=('din', 'dout'),
+        sharding_metadata=('din', 'dout'),
         nickname=('in', 'out'),
         on_add_axis=lambda _, idx, name: kadds.append((idx, name)),
         on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)),
@@ -160,7 +160,7 @@ def __call__(self, x: jax.Array):
       x = self.linear(x)
       # test sharding layer axes is not present inside scan
       test.assertEqual(self.linear.kernel.shape, (4, 4))
-      test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout'))
+      test.assertEqual(self.linear.kernel.sharding_metadata, ('din', 'dout'))
       # at least a remove_axis was already called to remove the layer axis
       test.assertEqual(kremoves[-1], (0, 'layers'))
       test.assertEqual(bremoves[-1], (0, 'layers'))
@@ -175,7 +175,7 @@ def __call__(self, x: jax.Array):
     with jax.set_mesh(mesh):
       m = MLP(rngs=nnx.Rngs(0))
     self.assertEqual(m.linear.kernel.shape, (5, 4, 4))
-    self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout'))
+    self.assertEqual(m.linear.kernel.sharding_metadata, ('layers', 'din', 'dout'))
     self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out'))
     self.assertEqual(m.linear.bias.shape, (5, 4))
     # One add_axis called to add the `nnx.vmap` dimension
@@ -201,7 +201,7 @@ def test_eager_sharding_context(self, use_eager_sharding):
     with jax.set_mesh(mesh):
       w = nnx.Param(
         rngs.lecun_normal()((4, 8)),
-        sharding_names=(None, 'model'))
+        sharding_metadata=(None, 'model'))
       if use_eager_sharding:
         assert has_sharding_spec(w)
       else:
@@ -273,7 +273,7 @@ def test_explicit_sharding(self):
     )
     v = nnx.Variable(
       jnp.ones((4, 4)),
-      sharding_names=('row', 'col'),
+      sharding_metadata=('row', 'col'),
       mesh=mesh,
     )
     self.assertEqual(v.sharding.mesh, mesh)
@@ -291,7 +291,7 @@ def test_explicit_sharding_disable_jit(self):
     with jax.disable_jit(True):
       v = nnx.Variable(
         jnp.ones((4, 4)),
-        sharding_names=('row', 'col'),
+        sharding_metadata=('row', 'col'),
         mesh=mesh,
       )
     self.assertEqual(v.sharding.mesh, mesh)
@@ -309,7 +309,7 @@ def test_explicit_sharding_mesh_context(self):
     with jax.set_mesh(mesh):
      v = nnx.Variable(
        jnp.ones((4, 4)),
-        sharding_names=('row', 'col'),
+        sharding_metadata=('row', 'col'),
      )
     self.assertEqual(v.sharding.mesh, mesh)
     self.assertEqual(
