Skip to content

Commit db06488

Browse files
Cristian GarciaFlax Authors
authored andcommitted
Fix typo in unflatten docs
The docstring for `unflatten` was updated to state that `copy_variables` defaults to `False`. PiperOrigin-RevId: 800619618
1 parent 08068be commit db06488

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

flax/nnx/graph.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,15 +1081,15 @@ def unflatten( # type: ignore[invalid-annotation]
10811081
state: A State instance.
10821082
index_ref: A mapping from indexes to nodes references found during the graph
10831083
traversal, defaults to None. If not provided, a new empty dictionary is
1084-
created. This argument can be used to unflatten a sequence of (graphdef, state)
1085-
pairs that share the same index space.
1086-
index_ref_cache: A mapping from indexes to existing nodes that can be reused.
1087-
When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
1088-
object in an empty state and then filled by the unflatten process, as a result
1089-
existing graph nodes are mutated to have the new content/topology
1090-
specified by the graphdef.
1091-
copy_variables: If True (default), variables in the state will be copied onto
1092-
the new new structure, else variables will be shared.
1084+
created. This argument can be used to unflatten a sequence of (graphdef,
1085+
state) pairs that share the same index space.
1086+
index_ref_cache: A mapping from indexes to existing nodes that can be
1087+
reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to
1088+
leave the object in an empty state and then filled by the unflatten
1089+
process, as a result existing graph nodes are mutated to have the new
1090+
content/topology specified by the graphdef.
1091+
copy_variables: If True variables in the state will be copied onto the new
1092+
new structure, else variables will be shared. Default is False.
10931093
"""
10941094
if isinstance(state, (State, dict)):
10951095
leaves = _get_sorted_leaves(state)

0 commit comments

Comments
 (0)