2121import threading
2222import typing as tp
2323from typing import Any
24+ import warnings
2425from flax import config
2526
2627import jax
@@ -375,16 +376,20 @@ def __init__(
375376 metadata ['on_remove_axis' ] = var_t .on_remove_axis
376377
377378 if 'sharding' in metadata :
378- metadata ['sharding_names' ] = metadata .pop ('sharding' )
379+ metadata ['sharding_metadata' ] = metadata .pop ('sharding' )
380+
381+ if 'sharding_names' in metadata : # for bw compat
382+ warnings .warn ("'sharding_names' is deprecated. Use 'sharding_metadata' instead." , DeprecationWarning )
383+ metadata ['sharding_metadata' ] = metadata .pop ('sharding_names' )
379384
380385 object .__setattr__ (self , '_var_metadata' , metadata )
381386 # run create_value hooks
382387 value = self .create_value (self .raw_value )
383388
384389 # shard the value if applicable
385- if metadata .get ('eager_sharding' , using_eager_sharding ()) and 'sharding_names ' in metadata :
390+ if metadata .get ('eager_sharding' , using_eager_sharding ()) and 'sharding_metadata ' in metadata :
386391 value = core_spmd .shard_value (
387- value , metadata ['sharding_names ' ], metadata .get ('sharding_rules' , None ),
392+ value , metadata ['sharding_metadata ' ], metadata .get ('sharding_rules' , None ),
388393 metadata .get ('mesh' , None ))
389394
390395 # Create the ref out of the array value
@@ -394,6 +399,9 @@ def __init__(
394399 object .__setattr__ (self , 'raw_value' , value )
395400
396401 def __getattr__ (self , name : str ) -> tp .Any :
402+ if name == 'sharding_names' : # for backward compatibility
403+ warnings .warn ("'sharding_names' is deprecated. Use 'sharding_metadata' instead." , DeprecationWarning )
404+ return self .sharding_metadata
397405 if name in object .__getattribute__ (self , '_var_metadata' ):
398406 return self ._var_metadata [name ]
399407 return getattr (self .raw_value , name )
0 commit comments