Skip to content

Commit a734ef9

Browse files
committed
Rename sharding_names to sharding_metadata
1 parent 7710c30 commit a734ef9

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

flax/core/spmd.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def shard_value(value, sharding_names, sharding_rules, mesh):
4545
f' with annotation {sharding_names=}. '
4646
'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
4747
pspec = get_pspec(sharding_names, sharding_rules)
48+
if isinstance(sharding_names, NamedSharding) and mesh is not None:
49+
assert sharding_names.mesh == mesh
4850
if mesh is not None:
4951
return _apply_sharding(value, NamedSharding(mesh, pspec))
5052
return _apply_sharding(value, pspec)
@@ -107,8 +109,10 @@ def composite_rules(rule1, rule2):
107109

108110

109111
def from_sharding_rules(
110-
sharding: Sharding, sharding_rules: LogicalRules
112+
sharding, sharding_rules: LogicalRules
111113
) -> Sharding:
114+
if isinstance(sharding, NamedSharding):
115+
sharding = sharding.spec
112116
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
113117
return tuple(
114118
rules[str(s)] if (s and str(s) in rules) else s for s in sharding

flax/nnx/variablelib.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import threading
2222
import typing as tp
2323
from typing import Any
24+
import warnings
2425
from flax import config
2526

2627
import jax
@@ -375,16 +376,20 @@ def __init__(
375376
metadata['on_remove_axis'] = var_t.on_remove_axis
376377

377378
if 'sharding' in metadata:
378-
metadata['sharding_names'] = metadata.pop('sharding')
379+
metadata['sharding_metadata'] = metadata.pop('sharding')
380+
381+
if 'sharding_names' in metadata: # for bw compat
382+
warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
383+
metadata['sharding_metadata'] = metadata.pop('sharding_names')
379384

380385
object.__setattr__(self, '_var_metadata', metadata)
381386
# run create_value hooks
382387
value = self.create_value(self.raw_value)
383388

384389
# shard the value if applicable
385-
if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_names' in metadata:
390+
if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_metadata' in metadata:
386391
value = core_spmd.shard_value(
387-
value, metadata['sharding_names'], metadata.get('sharding_rules', None),
392+
value, metadata['sharding_metadata'], metadata.get('sharding_rules', None),
388393
metadata.get('mesh', None))
389394

390395
# Create the ref out of the array value
@@ -394,6 +399,9 @@ def __init__(
394399
object.__setattr__(self, 'raw_value', value)
395400

396401
def __getattr__(self, name: str) -> tp.Any:
402+
if name == 'sharding_names': # for backward compatibility
403+
warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
404+
return self.sharding_metadata
397405
if name in object.__getattribute__(self, '_var_metadata'):
398406
return self._var_metadata[name]
399407
return getattr(self.raw_value, name)

0 commit comments

Comments
 (0)