From 205ed408d20fe52a676e9e3b16d3ba0becaed2a6 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 1 Nov 2025 12:18:48 -0700 Subject: [PATCH] add support for hijax Variables in nnx transforms --- flax/nnx/graph.py | 28 +++++++++++++++++++++++++++- tests/nnx/mutable_array_test.py | 12 ++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6fb14fa69..47e7e1dd4 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -623,6 +623,7 @@ def flatten( # type: ignore[invalid-annotation] *, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, + convert_to_lojax: bool = False, ) -> tuple[GraphDef[Node], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] @@ -632,6 +633,7 @@ def flatten( # type: ignore[invalid-annotation] with_paths: tp.Literal[True], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, + convert_to_lojax: bool = False, ) -> tuple[ GraphDef[Node], FlatState[tp.Any], @@ -644,6 +646,7 @@ def flatten( # type: ignore[invalid-annotation] with_paths: tp.Literal[False], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, + convert_to_lojax: bool = False, ) -> tuple[ GraphDef[Node], list[tp.Any], @@ -656,6 +659,7 @@ def flatten( # type: ignore[invalid-annotation] with_paths: bool, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, + convert_to_lojax: bool = False, ) -> tuple[ GraphDef[Node], FlatState[tp.Any] | list[tp.Any], @@ -667,6 +671,7 @@ def flatten( # type: ignore[invalid-annotation] with_paths: bool = True, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, + convert_to_lojax: bool = False, ) -> tuple[ GraphDef[Node], FlatState[tp.Any] | list[tp.Any], @@ -700,6 +705,7 @@ def flatten( # type: ignore[invalid-annotation] attributes, leaves, paths, + convert_to_lojax, ) graphdef: GraphDef = GraphDef( nodes=nodes, attributes=attributes, num_leaves=len(leaves) @@ -721,6 +727,7 @@ def _graph_flatten( attributes: list[tuple[Key, AttrType]], leaves: list[tp.Any], paths: list[PathParts] | None, + convert_to_lojax: bool, ) -> None: is_pytree_node_ = type(node_impl) is PytreeNodeImpl @@ -777,6 +784,8 @@ def make_mutable_arraydef(value: variablelib.Ref): leaf = node # type: ignore[assignment] if inner_value is not prev_inner_value: leaf.set_raw_value(inner_value) + if convert_to_lojax and leaf.is_hijax: + leaf = variablelib._get_hijax_state(leaf) variabledef = VariableDef( type=node.var_type, # type: ignore @@ -842,6 +851,7 @@ def make_mutable_arraydef(value: variablelib.Ref): attributes, leaves, paths, + convert_to_lojax, ) elif variablelib.is_array_ref(value): attributes.append((key, MUTABLE_ARRAY_ATTR)) @@ -1092,6 +1102,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref: IndexMap | None = None, outer_index_outer_ref: IndexMap | None = None, copy_variables: bool = False, + convert_to_hijax: bool = False, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -1150,6 +1161,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref, outer_index_outer_ref, copy_variables, + convert_to_hijax, ) try: @@ -1171,6 +1183,7 @@ def _graph_unflatten( index_ref: IndexMap, outer_index_outer_ref: IndexMap | None, copy_variables: bool, + convert_to_hijax: bool, ) -> Node: """Recursive helper for graph_unflatten. @@ -1271,6 +1284,8 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): variable = variabledef.type.from_metadata( value, dict(variabledef.metadata) ) + if convert_to_hijax and variable.is_hijax: + variable = variablelib._new_hijax_from_variable(variable) index_ref[variabledef.index] = variable return variable # type: ignore[return-value] @@ -1326,6 +1341,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: index_ref, outer_index_outer_ref, copy_variables, + convert_to_hijax, ) children.append((key, subnode)) else: @@ -1696,7 +1712,10 @@ def split( ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None ) graphdef, flat_state = flatten( - node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index + node, + ref_index=self.ref_index, + ref_outer_index=inner_ref_outer_index, + convert_to_lojax=True, ) flat_states = _split_state(flat_state, filters) states = _to_nested_state(graphdef, flat_states) @@ -1772,6 +1791,7 @@ def flatten( # type: ignore[invalid-annotation] ref_index=self.ref_index, ref_outer_index=ref_outer_index, with_paths=with_paths, + convert_to_lojax=True, ) if with_paths: assert isinstance(flat_state, FlatState) @@ -1801,6 +1821,7 @@ def flatten( # type: ignore[invalid-annotation] ref_index=self.ref_index, ref_outer_index=ref_outer_index, with_paths=with_paths, + convert_to_lojax=True, ) if with_paths: assert isinstance(flat_state, FlatState) @@ -1864,6 +1885,7 @@ def merge( # type: ignore[invalid-annotation] index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, copy_variables=True, + convert_to_hijax=True, ) return node @@ -1896,6 +1918,7 @@ def unflatten( # type: ignore[invalid-annotation] graphdef, state, index_ref=self.index_ref, + convert_to_hijax=True, ) elif static_cache is not None: @@ -1938,6 +1961,7 @@ def unflatten( # type: ignore[invalid-annotation] state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, + convert_to_hijax=True, ) else: # graphdef.outer_index is None # its a new node, create it @@ -1945,6 +1969,7 @@ def unflatten( # type: ignore[invalid-annotation] graphdef, state, index_ref=self.index_ref, + convert_to_hijax=True, ) else: outer_index_outer_ref = ( @@ -1955,6 +1980,7 @@ def unflatten( # type: ignore[invalid-annotation] state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, + convert_to_hijax=True, ) return node diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index 02945e625..c0c54afb1 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -913,6 +913,18 @@ def f(v): self.assertEqual(y.shape, ()) + @nnx.use_hijax(True) + def test_nnx_jit(self): + v = nnx.Param(jnp.array([1, 2, 3])) + + @nnx.vmap(in_axes=(0,)) + def f(v): + v[...] += 1 + + f(v) + + self.assertEqual(v[...], 1) + if __name__ == '__main__': absltest.main()