Hello, I have two questions about how to clone things in …

… But why does this implement a deepcopy? Is it in …

Thanks in advance for your help!
Replies: 2 comments
Hey @JINKEHE

`State` is a Pytree, so you can clone it using `jax.tree.map`:

    state = jax.tree.map(lambda x: x, state)  # clone

`split` and `merge` don't use `deepcopy`, but they do define a complete traversal of the object graph for all NNX objects and JAX pytrees. Think of `split` as `jax.tree.flatten` and `merge` as `jax.tree.unflatten`. In fact, NNX also has `nnx.graph.flatten` and `nnx.graph.unflatten`, which `split` and `merge` use under the hood.

Thank you!
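A minimal sketch of the `jax.tree.map` clone from the answer above. It assumes a recent JAX release (one with the `jax.tree` namespace) and uses a plain nested dict as a hypothetical stand-in for an NNX `State`, on the assumption that any pytree is cloned the same way. The names `state`, `clone`, `kernel`, and `bias` are illustrative only. Because the mapped function is the identity, the containers are rebuilt but the arrays themselves are shared, which is safe since JAX arrays are immutable.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an nnx.State: any nested pytree behaves the same way.
state = {"linear": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

# Clone: jax.tree.map rebuilds every container and applies the identity to each leaf.
clone = jax.tree.map(lambda x: x, state)

# The containers are new objects, so replacing an entry in the clone
# leaves the original untouched.
clone["linear"]["bias"] = jnp.ones((3,))
assert clone is not state
assert clone["linear"] is not state["linear"]
assert bool((state["linear"]["bias"] == 0).all())

# The leaves are the very same array objects; fine because JAX arrays are immutable.
assert clone["kernel"] if False else clone["linear"]["kernel"] is state["linear"]["kernel"]
```

If distinct leaf buffers are wanted as well, the lambda could return `x.copy()` instead of `x`.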
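To make the flatten/unflatten analogy concrete, here is the pytree half of it as a sketch, assuming a recent JAX with the `jax.tree` namespace; the `params` tree is hypothetical. `jax.tree.flatten` separates a tree into its leaves and a static `treedef`, and `jax.tree.unflatten` builds a fresh tree from those two pieces. By analogy, `nnx.split` returns a static `graphdef` plus a `State`, and `nnx.merge` rebuilds a new module from them, which is presumably why a split/merge round trip behaves like a copy without calling `deepcopy`. The NNX calls appear only in comments, as assumed usage.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree used only for illustration.
params = {"w": jnp.arange(3.0), "b": jnp.zeros(())}

# Like "split": a static structure description (treedef) plus the leaves.
leaves, treedef = jax.tree.flatten(params)

# Like "merge": rebuild a brand-new tree from the same structure and leaves.
rebuilt = jax.tree.unflatten(treedef, leaves)

assert rebuilt is not params                   # a new container object
assert jax.tree.structure(rebuilt) == treedef  # with the same structure
# ...holding the same leaf objects that were passed in
assert all(a is b for a, b in zip(jax.tree.leaves(rebuilt), leaves))

# Assumed NNX counterpart (not executed here):
#   graphdef, state = nnx.split(model)        # analogous to flatten
#   model_copy = nnx.merge(graphdef, state)   # analogous to unflatten
```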