diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index fcb15f060..5adf35561 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -19,9 +19,6 @@ from flax.typing import Initializer as Initializer from .bridge import wrappers as wrappers -from .bridge.variables import ( - register_variable_name_type_pair as register_variable_name_type_pair, -) from .filterlib import WithTag as WithTag from .filterlib import PathContains as PathContains from .filterlib import OfType as OfType @@ -163,6 +160,9 @@ from .variablelib import VariableState as VariableState from .variablelib import VariableMetadata as VariableMetadata from .variablelib import with_metadata as with_metadata +from .variablelib import variable_type_from_name as variable_type_from_name +from .variablelib import variable_name_from_type as variable_name_from_type +from .variablelib import register_variable_name_type_pair as register_variable_name_type_pair from .visualization import display as display from .extract import to_tree as to_tree from .extract import from_tree as from_tree diff --git a/flax/nnx/bridge/__init__.py b/flax/nnx/bridge/__init__.py index 7ed1b46ab..a2c8a4a81 100644 --- a/flax/nnx/bridge/__init__.py +++ b/flax/nnx/bridge/__init__.py @@ -19,5 +19,4 @@ from .wrappers import lazy_init as lazy_init from .wrappers import ToLinen as ToLinen from .wrappers import to_linen as to_linen -from .variables import NNXMeta as NNXMeta -from .variables import register_variable_name_type_pair as register_variable_name_type_pair \ No newline at end of file +from .variables import NNXMeta as NNXMeta \ No newline at end of file diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 121bb98eb..b3392c865 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -20,7 +20,7 @@ from flax.core import meta from flax.nnx import spmd from flax.nnx import traversals -from flax.nnx import variablelib as variableslib +from flax.nnx import variablelib from 
flax.nnx.module import GraphDef import typing as tp @@ -29,56 +29,9 @@ B = TypeVar('B') -####################################################### -### Variable type <-> Linen collection name mapping ### -####################################################### -# Assumption: the mapping is 1-1 and unique. - -VariableTypeCache: dict[str, tp.Type[variableslib.Variable[tp.Any]]] = {} - - -def variable_type(name: str) -> tp.Type[variableslib.Variable[tp.Any]]: - """Given a Linen-style collection name, get or create its corresponding NNX Variable type.""" - if name not in VariableTypeCache: - VariableTypeCache[name] = type(name, (variableslib.Variable,), {}) - return VariableTypeCache[name] - - -def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str: - """Given an NNX Variable type, get or create its Linen-style collection name. - - Should output the exact inversed result of `variable_type()`.""" - for name, t in VariableTypeCache.items(): - if typ == t: - return name - name = typ.__name__ - if name in VariableTypeCache: - raise ValueError( - 'Name {name} is already registered in the registry as {VariableTypeCache[name]}. ' - 'It cannot be linked with this type {typ}.' - ) - register_variable_name_type_pair(name, typ) - return name - - -def register_variable_name_type_pair(name, typ, overwrite = False): - """Register a pair of Linen collection name and its NNX type.""" - if not overwrite and name in VariableTypeCache: - raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. 
' - 'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.') - VariableTypeCache[name] = typ - - -# add known variable type names -register_variable_name_type_pair('params', variableslib.Param) -register_variable_name_type_pair('batch_stats', variableslib.BatchStat) -register_variable_name_type_pair('cache', variableslib.Cache) -register_variable_name_type_pair('intermediates', variableslib.Intermediate) - - def sort_variable_types(types: tp.Iterable[type]): def _variable_parents_count(t: type): - return sum(1 for p in t.mro() if issubclass(p, variableslib.Variable)) + return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) parent_count = {t: _variable_parents_count(t) for t in types} return sorted(types, key=lambda t: -parent_count[t]) @@ -91,7 +44,7 @@ def _variable_parents_count(t: type): class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): """Default Flax metadata class for `nnx.VariableState`.""" - var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False) + var_type: type[variablelib.Variable[tp.Any]] = struct.field(pytree_node=False) value: Any = struct.field(pytree_node=True) metadata: dict[str, tp.Any] = struct.field(pytree_node=False) @@ -114,11 +67,11 @@ def get_partition_spec(self) -> jax.sharding.PartitionSpec: nnx_var = self.to_nnx_variable().to_state() return spmd.get_partition_spec(nnx_var).value - def to_nnx_variable(self) -> variableslib.Variable: + def to_nnx_variable(self) -> variablelib.Variable: return self.var_type(self.value, **self.metadata) -def is_vanilla_variable(vs: variableslib.VariableState) -> bool: +def is_vanilla_variable(vs: variablelib.VariableState) -> bool: """A variables state is vanilla if its metadata is essentially blank. Returns False only if it has non-empty hooks or any non-built-in attribute. 
@@ -132,7 +85,7 @@ def is_vanilla_variable(vs: variableslib.VariableState) -> bool: return True -def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: +def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: linen_type = metadata['linen_meta_type'] @@ -151,9 +104,9 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str: return str(keypath[0].key) -def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: +def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variablelib.Variable: """Convert a Linen variable to an NNX variable.""" - vtype = variable_type(col) + vtype = variablelib.variable_type_from_name(col) if isinstance(x, NNXMeta): assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' return x.to_nnx_variable() @@ -196,14 +149,14 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: for kp, v in traversals.flatten_mapping( nnx_attrs, is_leaf=lambda _, x: isinstance( - x, variableslib.Variable | variableslib.VariableState | GraphDef + x, variablelib.Variable | variablelib.VariableState | GraphDef ), ).items(): - if isinstance(v, variableslib.Variable): - col_name = variable_type_name(type(v)) + if isinstance(v, variablelib.Variable): + col_name = variablelib.variable_name_from_type(type(v)) v = to_linen_var(v.to_state()) - elif isinstance(v, variableslib.VariableState): - col_name = variable_type_name(v.type) + elif isinstance(v, variablelib.VariableState): + col_name = variablelib.variable_name_from_type(v.type) v = to_linen_var(v) else: col_name = 'nnx' # it must be an nnx.GraphDef, for some ToLinen submodule diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index ab673644c..f5f6f43ba 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -21,6 +21,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph 
+from flax.nnx import variablelib from flax.nnx.bridge import variables as bv from flax.nnx.module import GraphDef, Module from flax.nnx.object import Object @@ -271,7 +272,7 @@ def _update_variables(self, module): # Each variable type goes to its own linen collection, and # each attribute goes to its own linen variable for typ, state in zip(types, state_by_types): - collection = bv.variable_type_name(typ) + collection = variablelib.variable_name_from_type(typ) if self.is_mutable_collection(collection): for k, v in state.raw_mapping.items(): v = jax.tree.map(bv.to_linen_var, v, diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index b2c066096..83da9c7a0 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -44,8 +44,6 @@ AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] -VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} - @dataclasses.dataclass @@ -966,3 +964,51 @@ def split_flat_state( ) return flat_states + + + +################################################### +### Variable type/class <-> string name mapping ### +################################################### +# Assumption: the mapping is 1-1 and unique. + +VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} + + +def variable_type_from_name(name: str) -> tp.Type[Variable[tp.Any]]: + """Given a Linen-style collection name, get or create its NNX Variable class.""" + if name not in VariableTypeCache: + VariableTypeCache[name] = type(name, (Variable,), {}) + return VariableTypeCache[name] + + +def variable_name_from_type(typ: tp.Type[Variable[tp.Any]]) -> str: + """Given an NNX Variable type, get its Linen-style collection name. 
+ + Should output the exact inverse result of `variable_type_from_name()`.""" + for name, t in VariableTypeCache.items(): + if typ == t: + return name + name = typ.__name__ + if name in VariableTypeCache: + raise ValueError( + f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. ' + f'It cannot be linked with this type {typ}.' + ) + register_variable_name_type_pair(name, typ) + return name + + +def register_variable_name_type_pair(name, typ, overwrite = False): + """Register a pair of Linen collection name and its NNX type.""" + if not overwrite and name in VariableTypeCache: + raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. ' + 'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.') + VariableTypeCache[name] = typ + + +# add known variable type names +register_variable_name_type_pair('params', Param) +register_variable_name_type_pair('batch_stats', BatchStat) +register_variable_name_type_pair('cache', Cache) +register_variable_name_type_pair('intermediates', Intermediate) \ No newline at end of file