Skip to content

Commit 0e76365

Browse files
committed
Add missing line to build_tree_from_paths
1 parent 5f97f53 commit 0e76365

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flax/jax_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,6 @@ def build_tree_from_paths(paths_and_leaves: list[tuple[jax.tree_util.KeyPath, An
352352
current = current[k]
353353

354354
# Set the leaf value
355-
final_key = path[-1]
356-
current[_path_ix(final_key)] = leaf
355+
current.is_list = isinstance(path[-1], jax.tree_util.SequenceKey)
356+
current[_path_ix(path[-1])] = leaf
357357
return _to_pytree(root)

0 commit comments

Comments
 (0)