@@ -1072,7 +1072,7 @@ def unflatten( # type: ignore[invalid-annotation]
10721072 * ,
10731073 index_ref : IndexMap | None = None ,
10741074 outer_index_outer_ref : IndexMap | None = None ,
1075- copy_variables : bool = True ,
1075+ copy_variables : bool = False ,
10761076) -> Node :
10771077 """Unflattens a graphdef into a node with the given state.
10781078
@@ -1835,6 +1835,7 @@ def merge( # type: ignore[invalid-annotation]
18351835 _state ,
18361836 index_ref = self .index_ref ,
18371837 outer_index_outer_ref = outer_index_outer_ref ,
1838+ copy_variables = True ,
18381839 )
18391840 return node
18401841
@@ -2307,6 +2308,7 @@ def merge( # type: ignore[invalid-annotation]
23072308 state : tp .Any ,
23082309 / ,
23092310 * states : tp .Any ,
2311+ copy : bool = False ,
23102312) -> A :
23112313 """The inverse of :func:`flax.nnx.split`.
23122314
@@ -2348,6 +2350,7 @@ def merge( # type: ignore[invalid-annotation]
23482350 graphdef: A :class:`flax.nnx.GraphDef` object.
23492351 state: A :class:`flax.nnx.State` object.
23502352 *states: Additional :class:`flax.nnx.State` objects.
2353+ copy: Whether to create new copies of the Variables in the states, defaults to ``False``.
23512354 Returns:
23522355 The merged :class:`flax.nnx.Module`.
23532356 """
@@ -2357,7 +2360,7 @@ def merge( # type: ignore[invalid-annotation]
23572360 _state = state
23582361 else :
23592362 _state = _merge_to_flat_state ((state , * states ))
2360- node = unflatten (graphdef , _state )
2363+ node = unflatten (graphdef , _state , copy_variables = copy )
23612364 return node
23622365
23632366
@@ -2592,7 +2595,7 @@ def clone(node: Node) -> Node:
25922595 A deep copy of the :class:`Module` object.
25932596 """
25942597 graphdef , state = split (node )
2595- return merge (graphdef , state )
2598+ return merge (graphdef , state , copy = True )
25962599
25972600
25982601def _mutable_like (path , x ):
0 commit comments