From 738837a86a226f689104accd3dc6e8381cc6101b Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 3 Nov 2025 19:10:02 -0800 Subject: [PATCH] fix List and improve Pytree --- flax/nnx/graph.py | 20 +++---- flax/nnx/helpers.py | 42 ++++++-------- flax/nnx/pytreelib.py | 103 +++++++++++++++++++++------------ tests/nnx/partitioning_test.py | 3 - 4 files changed, 96 insertions(+), 72 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index c533be3db..d826c6d5a 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -296,6 +296,15 @@ def get_node_impl_for_type( else: return None +# use type-aware sorting to support int keys +def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: + key, _ = item + if isinstance(key, int): + return (0, key) + elif isinstance(key, str): + return (1, key) + else: + raise ValueError(f'Unsupported key type: {type(key)!r}') class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): _mapping: dict[HA, HB] | tp.Mapping[HA, HB] @@ -316,16 +325,7 @@ def __len__(self) -> int: return len(self._mapping) def __hash__(self) -> int: - # use type-aware sorting to support int keys - def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: - key, _ = item - if isinstance(key, int): - return (0, key) - elif isinstance(key, str): - return (1, key) - else: - raise ValueError(f'Unsupported key type: {type(key)!r}') - return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn))) + return hash(tuple(sorted(self._mapping.items(), key=_type_aware_sort))) def __eq__(self, other: tp.Any) -> bool: return ( diff --git a/flax/nnx/helpers.py b/flax/nnx/helpers.py index 5c962d8a7..434bab9d5 100644 --- a/flax/nnx/helpers.py +++ b/flax/nnx/helpers.py @@ -116,17 +116,20 @@ def __init__(self, it: tp.Iterable[A] | None = None, /): for value in it: self.append(value) - def _getattr(self, key) -> A: - return vars(self)[key] # type: ignore[unsupported-operands] + def _get_elem(self, key: int) -> A: + return getattr(self, str(key)) - def _delattr(self, key) -> None: - vars(self).pop(key) + def _set_elem(self, key: int, value: A) -> None: + setattr(self, str(key), value) + + def _del_elem(self, key: int) -> None: + delattr(self, str(key)) def __len__(self) -> int: return self._length def append(self, value: A) -> None: - self._setattr(self._length, value) + self._set_elem(self._length, value) self._length += 1 def insert(self, index: int, value: A) -> None: @@ -139,15 +142,15 @@ def insert(self, index: int, value: A) -> None: # Shift elements to the right for i in range(self._length, index, -1): - self._setattr(i, self._getattr(i - 1)) + self._set_elem(i, self._get_elem(i - 1)) # Insert the new value - self._setattr(index, value) + self._set_elem(index, value) self._length += 1 def __iter__(self) -> tp.Iterator[A]: for i in range(self._length): - yield self._getattr(i) + yield self._get_elem(i) @tp.overload def __getitem__(self, index: int) -> A: ... @@ -159,10 +162,10 @@ def __getitem__(self, index: int | slice) -> A | tp.List[A]: index += self._length if index < 0 or index >= self._length: raise IndexError('Index out of bounds') - return self._getattr(index) + return self._get_elem(index) elif isinstance(index, slice): idxs = list(range(self._length))[index] - return [self._getattr(i) for i in idxs] + return [self._get_elem(i) for i in idxs] else: raise TypeError('Invalid index type') @@ -172,7 +175,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None: index += self._length if index < 0 or index >= self._length: raise IndexError('Index out of bounds') - self._setattr(index, value) + self._set_elem(index, value) elif isinstance(index, slice): if not isinstance(value, tp.Iterable): raise TypeError('Expected an iterable') @@ -181,7 +184,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None: if len(idxs) != len(values): raise ValueError('Length mismatch') for i, v in zip(idxs, values): - self._setattr(i, v) + self._set_elem(i, v) else: raise TypeError('Invalid index type') @@ -191,9 +194,9 @@ def __delitem__(self, index: int | slice) -> None: index += self._length if index < 0 or index >= self._length: raise IndexError('Index out of bounds') - self._delattr(index) + self._del_elem(index) for i in range(index + 1, self._length): - self._setattr(i - 1, self._getattr(i)) + self._set_elem(i - 1, self._get_elem(i)) self._length -= 1 elif isinstance(index, slice): idxs = list(range(self._length))[index] @@ -203,15 +206,8 @@ def __delitem__(self, index: int | slice) -> None: else: raise TypeError('Invalid index type') - @staticmethod - def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: - key, _ = item - if isinstance(key, int): - return (0, key) - elif isinstance(key, str): - return (1, key) - else: - raise ValueError(f'Unsupported key type: {type(key)!r}') + _pytree__has_int_keys = True + class Sequential(Module): """A Module that applies a sequence of callables. diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 53353cc39..157ddb0db 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -340,16 +340,18 @@ def _pytree_meta_construct(cls, self, *args, **kwargs): def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P: node = cls.__new__(cls, *args, **kwargs) vars_obj = vars(node) - vars_obj['_pytree__state'] = PytreeState() - vars_obj['_pytree__nodes'] = cls._pytree__nodes + object.__setattr__(node, '_pytree__state', PytreeState()) + object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes) cls._pytree_meta_construct(node, *args, **kwargs) if cls._pytree__is_pytree: missing: dict[str, bool] = {} for name, value in vars(node).items(): - if name not in vars_obj['_pytree__nodes']: + if name not in node._pytree__nodes: missing[name] = is_data(value) if missing: - vars_obj['_pytree__nodes'] = vars_obj['_pytree__nodes'].update(missing) + object.__setattr__( + node, '_pytree__nodes', node._pytree__nodes.update(missing) + ) check_pytree(node) return node @@ -500,11 +502,10 @@ def _setattr(self, name, value: tp.Any) -> None: if name not in self._pytree__nodes or ( explicit and self._pytree__nodes[name] != data ): - vars(self)['_pytree__nodes'] = self._pytree__nodes.update({name: data}) - if isinstance(name, str): - object.__setattr__(self, name, value) - else: - vars(self)[name] = value + object.__setattr__( + self, '_pytree__nodes', self._pytree__nodes.update({name: data}) + ) + object.__setattr__(self, name, value) def _check_value(self, key, value, new_status: AttributeStatus | None): def _has_arrays(leaves): @@ -739,20 +740,26 @@ def __getstate__(self): return vars(self).copy() def __setstate__(self, state): - vars(self).update(state) + for key, value in state.items(): + object.__setattr__(self, key, value) # ------------------------- # Pytree Definition # ------------------------- - _pytree__key_sort_fn: tp.Callable | None = None + _pytree__has_int_keys: bool = False def _pytree__flatten_with_paths(self): - obj_vars = vars(self) + obj_items = vars(self).items() + if self._pytree__has_int_keys: + obj_items = ((_maybe_int(name), value) for name, value in obj_items) + key_fn = graph._type_aware_sort + else: + key_fn = None node_attributes = self._pytree__nodes node_names: list[str] = [] node_attrs: list[tuple[tp.Any, tp.Any]] = [] static_attrs: list[tuple[str, tp.Any]] = [] - for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn): + for name, value in sorted(obj_items, key=key_fn): if name in node_attributes and node_attributes[name]: node_names.append(name) node_attrs.append(( @@ -767,12 +774,17 @@ def _pytree__flatten_with_paths(self): return node_attrs, (tuple(node_names), tuple(static_attrs)) def _pytree__flatten(self): - obj_vars = vars(self) + obj_items = vars(self).items() + if self._pytree__has_int_keys: + obj_items = ((_maybe_int(name), value) for name, value in obj_items) + key_fn = graph._type_aware_sort + else: + key_fn = None node_attributes = self._pytree__nodes node_names: list[str] = [] node_attrs: list[tp.Any] = [] static_attrs: list[tuple[str, tp.Any]] = [] - for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn): + for name, value in sorted(obj_items, key=key_fn): if name in node_attributes and node_attributes[name]: node_names.append(name) node_attrs.append(value) @@ -790,34 +802,40 @@ def _pytree__unflatten( node_names, static_attrs = static obj = object.__new__(cls) vars_obj = vars(obj) - vars_obj.update(zip(node_names, node_attrs, strict=True)) - vars_obj.update(static_attrs) + if cls._pytree__has_int_keys: + node_names = [ + str(name) if isinstance(name, int) else name for name in node_names + ] + for name, value in zip(node_names, node_attrs, strict=True): + object.__setattr__(obj, name, value) + for name, value in static_attrs: + object.__setattr__(obj, name, value) return obj # ------------------------- # Graph Definition # ------------------------- def _graph_node_flatten(self): - nodes = vars(self) - nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn) + obj_items = vars(self).items() + if self._pytree__has_int_keys: + obj_items = ((_maybe_int(name), value) for name, value in obj_items) + key_fn = graph._type_aware_sort + else: + key_fn = None + nodes = sorted(obj_items, key=key_fn) return nodes, type(self) - def _graph_node_set_key(self, key: str, value: tp.Any): - if not isinstance(key, str): - raise KeyError(f'Invalid key: {key!r}') - elif ( - hasattr(self, key) - and isinstance(variable := getattr(self, key), Variable) - and isinstance(value, Variable) - ): - variable.update_from_state(value) - else: - setattr(self, key, value) + def _graph_node_set_key(self, key, value: tp.Any): + if self._pytree__has_int_keys and isinstance(key, int): + key = str(key) + setattr(self, key, value) - def _graph_node_pop_key(self, key: str): - if not isinstance(key, str): - raise KeyError(f'Invalid key: {key!r}') - return vars(self).pop(key) + def _graph_node_pop_key(self, key): + if self._pytree__has_int_keys and isinstance(key, int): + key = str(key) + value = getattr(self, key) + delattr(self, key) + return value @staticmethod def _graph_node_create_empty(node_type: tp.Type[P]) -> P: @@ -825,10 +843,17 @@ def _graph_node_create_empty(node_type: tp.Type[P]) -> P: return node def _graph_node_clear(self): - vars(self).clear() + for name in list(vars(self)): + delattr(self, name) def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]): - vars(self).update(attributes) + if self._pytree__has_int_keys: + attributes = ( + (str(name) if isinstance(name, int) else name, value) + for name, value in attributes + ) + for name, value in attributes: + object.__setattr__(self, name, value) if tp.TYPE_CHECKING: def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ... @@ -845,3 +870,9 @@ def __init_subclass__(cls, **kwargs): f'{pytree!r} for type {cls}.' ) super().__init_subclass__(pytree=pytree, **kwargs) + +def _maybe_int(x): + try: + return int(x) + except (ValueError, TypeError): + return x \ No newline at end of file diff --git a/tests/nnx/partitioning_test.py b/tests/nnx/partitioning_test.py index 50f8d8ee7..a2c70d9c8 100644 --- a/tests/nnx/partitioning_test.py +++ b/tests/nnx/partitioning_test.py @@ -150,9 +150,6 @@ def test_get_paritition(self): d=5.0, ) - # test Variables not shared - self.assertIsNot(vars(m.a)[0], vars(m)['b']) - state = nnx.state(m, nnx.Variable) self.assertEqual(state['a'][0][...], m.a[0][...]) self.assertEqual(state['a'][1][...], m.a[1][...])