Skip to content

Commit e828319

Browse files
committed
merge no copy Variables
1 parent 7eba011 commit e828319

File tree

5 files changed

+15
-15
lines changed

5 files changed

+15
-15
lines changed

flax/nnx/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,4 @@ def __getattr__(name):
208208
DeprecationWarning,
209209
stacklevel=2,
210210
)
211-
if name not in globals():
212-
raise AttributeError(f"Module {__name__} has no attribute '{name}'")
213-
214-
return globals()[name]
211+
raise AttributeError(f"Module {__name__} has no attribute '{name}'")

flax/nnx/graph.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ def unflatten( # type: ignore[invalid-annotation]
10721072
*,
10731073
index_ref: IndexMap | None = None,
10741074
outer_index_outer_ref: IndexMap | None = None,
1075-
copy_variables: bool = True,
1075+
copy_variables: bool = False,
10761076
) -> Node:
10771077
"""Unflattens a graphdef into a node with the given state.
10781078
@@ -1835,6 +1835,7 @@ def merge( # type: ignore[invalid-annotation]
18351835
_state,
18361836
index_ref=self.index_ref,
18371837
outer_index_outer_ref=outer_index_outer_ref,
1838+
copy_variables=True,
18381839
)
18391840
return node
18401841

@@ -2307,6 +2308,7 @@ def merge( # type: ignore[invalid-annotation]
23072308
state: tp.Any,
23082309
/,
23092310
*states: tp.Any,
2311+
copy: bool = False,
23102312
) -> A:
23112313
"""The inverse of :func:`flax.nnx.split`.
23122314
@@ -2348,6 +2350,7 @@ def merge( # type: ignore[invalid-annotation]
23482350
graphdef: A :class:`flax.nnx.GraphDef` object.
23492351
state: A :class:`flax.nnx.State` object.
23502352
*states: Additional :class:`flax.nnx.State` objects.
2353+
copy: Whether to create new copies of the Variables in the states, defaults to ``False``.
23512354
Returns:
23522355
The merged :class:`flax.nnx.Module`.
23532356
"""
@@ -2357,7 +2360,7 @@ def merge( # type: ignore[invalid-annotation]
23572360
_state = state
23582361
else:
23592362
_state = _merge_to_flat_state((state, *states))
2360-
node = unflatten(graphdef, _state)
2363+
node = unflatten(graphdef, _state, copy_variables=copy)
23612364
return node
23622365

23632366

@@ -2592,7 +2595,7 @@ def clone(node: Node) -> Node:
25922595
A deep copy of the :class:`Module` object.
25932596
"""
25942597
graphdef, state = split(node)
2595-
return merge(graphdef, state)
2598+
return merge(graphdef, state, copy=True)
25962599

25972600

25982601
def _mutable_like(path, x):

tests/nnx/graph_utils_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def __init__(self):
292292

293293
assert isinstance(m2.a, nnx.Param)
294294
assert isinstance(state['a'], nnx.Param)
295-
assert m2.a is not state['a']
295+
assert m2.a is state['a']
296296
assert m2.a.value == state['a'].value
297297

298298
def test_shared_state_variables_shared_with_graph(self):
@@ -319,8 +319,8 @@ def __init__(self):
319319
assert isinstance(m2.a, nnx.Param)
320320
assert isinstance(m2.b, nnx.Param)
321321
assert isinstance(state['a'], nnx.Param)
322-
assert m2.a is not state['a']
323-
assert m2.b is not state['a']
322+
assert m2.a is state['a']
323+
assert m2.b is state['a']
324324
assert m2.a.value == state['a'].value
325325
assert m2.b.value == state['a'].value
326326
assert m2.a is m2.b
@@ -367,7 +367,7 @@ def __init__(self):
367367
assert isinstance(m2.tree, Tree)
368368
assert m2.tree.a.raw_value == 1
369369
assert m2.tree.b == 'a'
370-
assert m2.tree.a is not m.tree.a
370+
assert m2.tree.a is m.tree.a
371371
assert m2.tree is not m.tree
372372

373373
def test_cached_unflatten(self):

tests/nnx/integration_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ def __call__(self, x):
204204
@jax.jit
205205
def train_step(params, counts, x, y):
206206
def loss_fn(params):
207-
y_pred, (_, updates) = graphdef.apply(params, counts)(x)
208-
loss = jax.numpy.mean((y_pred - y) ** 2)
209-
return loss, nnx.filter_state(updates, Count)
207+
model = nnx.merge(graphdef, params, counts, copy=True)
208+
loss = jax.numpy.mean((model(x) - y) ** 2)
209+
return loss, nnx.state(model, Count)
210210

211211
# compute gradient
212212
grads, counts = jax.grad(loss_fn, has_aux=True)(params)

tests/nnx/mutable_array_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __init__(self):
200200
self.assertTrue(m3.a.has_ref)
201201
self.assertTrue(m3.b.has_ref)
202202
self.assertIsNot(m2, m3)
203-
self.assertIsNot(m.a, m3.a)
203+
self.assertIs(m.a, m3.a)
204204

205205
def test_freeze_duplicate_error(self):
206206
class Foo(nnx.Module):

0 commit comments

Comments
 (0)