Commit 38c058b

Merge pull request #4863 from google:nnx-obj-pytree

PiperOrigin-RevId: 789850911

Author: Flax Authors
2 parents: 3285ba0 + a540028

31 files changed: +508 −536 lines

.github/workflows/flax_test.yml

Lines changed: 0 additions & 23 deletions
@@ -79,29 +79,6 @@ jobs:
       - name: Test importing Flax
         run: |
           uv run python -c "import flax"
-  test-mutable-array:
-    name: Run MutableArray tests
-    needs: [pre-commit, commit-count, test-import]
-    runs-on: ubuntu-24.04-16core
-    steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Set up Python 3.11
-        id: setup_python
-        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
-        with:
-          python-version: 3.11
-      - name: Setup uv
-        uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
-        with:
-          version: "0.3.0"
-      - name: Install dependencies
-        run: |
-          uv sync --extra all --extra testing --extra docs
-          uv pip install -U git+https://github.com/jax-ml/jax.git
-      - name: Run MutableArray tests
-        run: |
-          source .venv/bin/activate
-          FLAX_MUTABLE_ARRAY=true pytest tests/nnx/mutable_array_test.py
 
   tests:
     name: Run Tests

docs_nnx/guides/mutable_array.ipynb renamed to docs_nnx/guides/array_ref.ipynb

Lines changed: 52 additions & 52 deletions
Large diffs are not rendered by default.

docs_nnx/guides/mutable_array.md renamed to docs_nnx/guides/array_ref.md

Lines changed: 30 additions & 30 deletions
@@ -8,55 +8,55 @@ jupytext:
     jupytext_version: 1.13.8
 ---
 
-# Mutable Arrays (experimental)
+# Array Refs (experimental)
 
 ```{code-cell} ipython3
+import jax.experimental
 from flax import nnx
 import jax
 import jax.numpy as jnp
-import jax.experimental
 import optax
 ```
 
 ## Basics
 
 +++
 
-### Mutable Arrays 101
+### Array Refs 101
 
 ```{code-cell} ipython3
-m_array = jax.experimental.mutable_array(jnp.array([1, 2, 3]))
+a_ref = nnx.array_ref(jnp.array([1, 2, 3]))
 
 @jax.jit
-def increment(m_array: jax.experimental.MutableArray):  # no return!
-  array: jax.Array = m_array[...]  # access
-  m_array[...] = array + 1  # update
+def increment(a_ref: nnx.ArrayRef):  # no return!
+  array: jax.Array = a_ref[...]  # access
+  a_ref[...] = array + 1  # update
 
-print("[1] =", m_array); increment(m_array); print("[2] =", m_array)
+print("[1] =", a_ref); increment(a_ref); print("[2] =", a_ref)
 ```
 
 ```{code-cell} ipython3
 @jax.jit
 def inc(x):
   x[...] += 1
 
-print(increment.lower(m_array).as_text())
+print(increment.lower(a_ref).as_text())
 ```
 
-### Mutable Variables
+### Variable Refs
 
 ```{code-cell} ipython3
-variable = nnx.Variable(jnp.array([1, 2, 3]), mutable=True)
-print(f"{variable.mutable = }\n")
+variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)
+print(f"{variable.has_ref = }\n")
 
 print("[1] =", variable); increment(variable); print("[2] =", variable)
 ```
 
 ```{code-cell} ipython3
-with nnx.use_mutable_arrays(True):
+with nnx.use_refs(True):
   variable = nnx.Variable(jnp.array([1, 2, 3]))
 
-print(f"{variable.mutable = }")
+print(f"{variable.has_ref = }")
 ```
 
 ### Changing Status
@@ -70,12 +70,12 @@ class Linear(nnx.Module):
   def __call__(self, x):
     return x @ self.kernel + self.bias[None]
 
-model = Linear(1, 3, rngs=nnx.Rngs(0))    # without mutable arrays
-mutable_model = nnx.mutable(model)        # convert to mutable arrays
-frozen_model = nnx.freeze(mutable_model)  # freeze mutable arrays again
+model = Linear(1, 3, rngs=nnx.Rngs(0))    # without array refs
+refs_model = nnx.to_refs(model)           # convert to array refs
+arrays_model = nnx.to_arrays(refs_model)  # convert to regular arrays
 
-print("nnx.mutable(model) =", mutable_model)
-print("nnx.freeze(mutable_model) =", frozen_model)
+print("nnx.to_refs(model) =", refs_model)
+print("nnx.to_arrays(refs_model) =", arrays_model)
 ```
 
 ## Examples
@@ -96,7 +96,7 @@ class Block(nnx.Module):
 ### Training Loop
 
 ```{code-cell} ipython3
-with nnx.use_mutable_arrays(True):
+with nnx.use_refs(True):
   model = Block(2, 64, 3, rngs=nnx.Rngs(0))
   optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
 
@@ -107,7 +107,7 @@ def train_step(model, optimizer, x, y):
     model = nnx.merge(graphdef, params, nondiff)
     return ((model(x) - y) ** 2).mean()
 
-  loss, grads = jax.value_and_grad(loss_fn)(nnx.freeze(params))  # freeze MutableArrays for jax.grad
+  loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params))  # freeze ArrayRefs for jax.grad
   optimizer.update(model, grads)
 
   return loss
@@ -122,7 +122,7 @@ train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
 def create_stack(rngs):
   return Block(2, 64, 2, rngs=rngs)
 
-with nnx.use_mutable_arrays(True):
+with nnx.use_refs(True):
   block_stack = create_stack(nnx.Rngs(0).fork(split=8))
 
 def scan_fn(x, block):
@@ -147,17 +147,17 @@ def create_model(rngs):
   return Block(2, 64, 3, rngs=rngs)
 
 try:
-  with nnx.use_mutable_arrays(True):
+  with nnx.use_refs(True):
     model = create_model(nnx.Rngs(0))
 except Exception as e:
   print(f"Error:", e)
 ```
 
 ```{code-cell} ipython3
-with nnx.use_mutable_arrays(False):  # <-- disable mutable arrays
+with nnx.use_refs(False):  # <-- disable array refs
  model = create_model(nnx.Rngs(0))
 
-model = nnx.mutable(model)  # convert to mutable after creation
+model = nnx.to_refs(model)  # convert to array refs after creation
 
 print("model.linear =", model.linear)
 ```
@@ -167,7 +167,7 @@ print("model.linear =", model.linear)
 def create_model(rngs):
   return Block(2, 64, 3, rngs=rngs)
 
-with nnx.use_mutable_arrays(True):
+with nnx.use_refs(True):
   model = create_model(nnx.Rngs(0))
 
 print("model.linear =", model.linear)
@@ -182,7 +182,7 @@ def get_error(f, *args):
   except Exception as e:
     return f"{type(e).__name__}: {e}"
 
-x = jax.experimental.mutable_array(jnp.array(0))
+x = nnx.array_ref(jnp.array(0))
 
 @jax.jit
 def f(a, b):
@@ -192,12 +192,12 @@ print(get_error(f, x, x))
 ```
 
 ```{code-cell} ipython3
-class SharedVariables(nnx.Object):
+class SharedVariables(nnx.Pytree):
   def __init__(self):
     self.a = nnx.Variable(jnp.array(0))
     self.b = self.a
 
-class SharedModules(nnx.Object):
+class SharedModules(nnx.Pytree):
  def __init__(self):
    self.a = Linear(1, 1, rngs=nnx.Rngs(0))
    self.b = self.a
@@ -206,7 +206,7 @@ class SharedModules(nnx.Object):
 def g(pytree):
   ...
 
-with nnx.use_mutable_arrays(True):
+with nnx.use_refs(True):
   shared_variables = SharedVariables()
   shared_modules = SharedModules()
 
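As a quick reference for the renames in this guide diff, here is a minimal sketch that exercises the new names exactly as they appear above (`nnx.array_ref`, `nnx.ArrayRef`, `nnx.use_refs`, `nnx.to_refs`, `nnx.to_arrays`). It is an illustration based only on the usage shown in this diff, not on a released Flax API reference.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Previously: jax.experimental.mutable_array(...)
a_ref = nnx.array_ref(jnp.array([1, 2, 3]))

@jax.jit
def increment(ref: nnx.ArrayRef):  # updated in place, nothing is returned
  ref[...] = ref[...] + 1

increment(a_ref)

# Modules created under nnx.use_refs(True) hold refs instead of plain arrays;
# nnx.to_arrays / nnx.to_refs convert an existing object either way
# (previously nnx.freeze / nnx.mutable).
with nnx.use_refs(True):
  model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

plain_model = nnx.to_arrays(model)   # regular arrays, e.g. before jax.grad
refs_model = nnx.to_refs(plain_model)
```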

docs_nnx/key_concepts.ipynb

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@
     " [ 0.6772455 0.2807398 ]\n",
     " [ 0.16276604 0.16813846]\n",
     " [ 0.310975 -0.43336964]]\n",
-    "treedef = PyTreeDef(CustomNode(Linear[(('_object__state', 'bias', 'kernel'), (('_object__nodes', frozenset({'kernel', '_object__state', 'bias'})), ('bias_init', <function zeros at 0x117826700>), ('dot_general', <function dot_general at 0x1172aa480>), ('dtype', None), ('in_features', 4), ('kernel_init', <function variance_scaling.<locals>.init at 0x120f45260>), ('out_features', 2), ('param_dtype', <class 'jax.numpy.float32'>), ('precision', None), ('promote_dtype', <function promote_dtype at 0x120f45440>), ('use_bias', True)))], [CustomNode(ObjectState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])]))\n"
+    "treedef = PyTreeDef(CustomNode(Linear[(('_pytree__state', 'bias', 'kernel'), (('_object__nodes', frozenset({'kernel', '_pytree__state', 'bias'})), ('bias_init', <function zeros at 0x117826700>), ('dot_general', <function dot_general at 0x1172aa480>), ('dtype', None), ('in_features', 4), ('kernel_init', <function variance_scaling.<locals>.init at 0x120f45260>), ('out_features', 2), ('param_dtype', <class 'jax.numpy.float32'>), ('precision', None), ('promote_dtype', <function promote_dtype at 0x120f45440>), ('use_bias', True)))], [CustomNode(ObjectState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])]))\n"
     ]
    }
   ],

docs_nnx/key_concepts.md

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ linear = jax.tree.unflatten(treedef, [value for _, value in arrays])
  [ 0.6772455 0.2807398 ]
  [ 0.16276604 0.16813846]
  [ 0.310975 -0.43336964]]
-treedef = PyTreeDef(CustomNode(Linear[(('_object__state', 'bias', 'kernel'), (('_object__nodes', frozenset({'kernel', '_object__state', 'bias'})), ('bias_init', <function zeros at 0x117826700>), ('dot_general', <function dot_general at 0x1172aa480>), ('dtype', None), ('in_features', 4), ('kernel_init', <function variance_scaling.<locals>.init at 0x120f45260>), ('out_features', 2), ('param_dtype', <class 'jax.numpy.float32'>), ('precision', None), ('promote_dtype', <function promote_dtype at 0x120f45440>), ('use_bias', True)))], [CustomNode(ObjectState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])]))
+treedef = PyTreeDef(CustomNode(Linear[(('_pytree__state', 'bias', 'kernel'), (('_pytree__nodes', frozenset({'kernel', '_pytree__state', 'bias'})), ('bias_init', <function zeros at 0x117826700>), ('dot_general', <function dot_general at 0x1172aa480>), ('dtype', None), ('in_features', 4), ('kernel_init', <function variance_scaling.<locals>.init at 0x120f45260>), ('out_features', 2), ('param_dtype', <class 'jax.numpy.float32'>), ('precision', None), ('promote_dtype', <function promote_dtype at 0x120f45440>), ('use_bias', True)))], [CustomNode(PytreeState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])]))
 
 
 
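The only visible change here is the `_object__*` → `_pytree__*` naming inside the treedef repr; the flatten/unflatten round trip shown in the guide works the same way. A small sketch (the printed treedef string will differ by Flax/JAX version and machine):

```python
import jax
from flax import nnx

linear = nnx.Linear(4, 2, rngs=nnx.Rngs(0))

# NNX modules are pytrees, so the standard tree utilities apply directly.
leaves, treedef = jax.tree.flatten(linear)
print(treedef)  # repr now refers to '_pytree__state' rather than '_object__state'

# Rebuilding from the same leaves yields an equivalent module.
linear2 = jax.tree.unflatten(treedef, leaves)
```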

docs_nnx/migrating/nnx_010_to_nnx_011.rst

Lines changed: 1 addition & 1 deletion
@@ -226,5 +226,5 @@ use the ``is_leaf`` argument to specify that NNX modules and other NNX objects s
   type_names = jax.tree.map(
     lambda x: type(x).__name__,
     modules,
-    is_leaf=lambda x: isinstance(x, nnx.Object)  # <-- specify that NNX objects are leaves
+    is_leaf=lambda x: isinstance(x, nnx.Pytree)  # <-- specify that NNX objects are leaves
   )
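For context, a runnable sketch of the migrated snippet with a small stand-in `modules` dict (the dict contents are illustrative, not from the guide); the 0.11 change is only the `nnx.Object` → `nnx.Pytree` check in `is_leaf`:

```python
import jax
from flax import nnx

# Illustrative container of NNX modules.
modules = {
  'linear': nnx.Linear(2, 3, rngs=nnx.Rngs(0)),
  'dropout': nnx.Dropout(rate=0.1),
}

# Without is_leaf, jax.tree.map would recurse into each module's parameters;
# treating nnx.Pytree instances as leaves maps over the modules themselves.
type_names = jax.tree.map(
  lambda x: type(x).__name__,
  modules,
  is_leaf=lambda x: isinstance(x, nnx.Pytree),  # <-- NNX objects are leaves
)
print(type_names)  # e.g. {'dropout': 'Dropout', 'linear': 'Linear'}
```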

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ class SGDState(nnx.Variable):
   pass
 
 
-class SGD(nnx.Object):
+class SGD(nnx.Pytree):
   def __init__(self, params: nnx.State, lr, decay=0.9):
     def init_optimizer_state(variable: nnx.Variable):
       return SGDState(

examples/nnx_toy_examples/mutable_array_basic.py

Lines changed: 9 additions & 15 deletions
@@ -13,14 +13,11 @@
 # limitations under the License.
 
 # %%
-import os
-
-os.environ['FLAX_MUTABLE_ARRAY'] = 'true'
-
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpy as np
+import optax
 
 from flax import nnx
 
@@ -57,24 +54,21 @@ def __call__(self, x):
     self.count[...] += 1
     return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5)
 
-
-model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
+with nnx.use_refs(True):
+  model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
+  optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param)
 
 
 @jax.jit
-def train_step(model, x, y):
-  graphdef, params, counts = nnx.pure(nnx.split(model, nnx.Param, Count))
+def train_step(model, optimizer, x, y):
+  graphdef, params, counts = nnx.split(model, nnx.Param, Count)
 
   def loss_fn(params):
     model = nnx.merge(graphdef, params, counts)
     return jnp.mean((y - model(x)) ** 2)
 
-  grads = jax.grad(loss_fn)(nnx.freeze(params))
-
-  def sgd(w, g):
-    w[...] -= 0.1 * g[...]
-
-  jax.tree.map(sgd, params, grads)
+  grads = jax.grad(loss_fn)(nnx.to_arrays(params))
+  optimizer.update(model, grads)
 
 
 @jax.jit
@@ -84,7 +78,7 @@ def test_step(model: MLP, x, y):
 
 total_steps = 10_000
 for step, (x, y) in enumerate(dataset(32)):
-  train_step(model, x, y)
+  train_step(model, optimizer, x, y)
 
   if step % 1000 == 0:
     logs = test_step(model, X, Y)
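Condensing the updated example, the new pattern is: build the model and `nnx.Optimizer` under `nnx.use_refs(True)`, then convert the ref-backed params with `nnx.to_arrays` only where `jax.grad` needs plain arrays. A minimal sketch using a bare `nnx.Linear` instead of the example's `MLP` (names follow this diff and may not match other Flax versions):

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

with nnx.use_refs(True):
  model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
  optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, x, y):
  graphdef, params, rest = nnx.split(model, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(graphdef, params, rest)
    return jnp.mean((y - model(x)) ** 2)

  # jax.grad does not take ref-backed params yet, so convert them first.
  grads = jax.grad(loss_fn)(nnx.to_arrays(params))
  optimizer.update(model, grads)  # in-place update, nothing returned

train_step(model, optimizer, x=jnp.ones((8, 1)), y=jnp.zeros((8, 1)))
```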

examples/nnx_toy_examples/mutable_array_demo.py

Lines changed: 11 additions & 22 deletions
@@ -40,12 +40,6 @@ def dataset(batch_size):
 # so we use a new Pytree type as the base. The main difference with current NNX is that
 # attributes that contain arrays or other pytrees now need to be explicitly marked as
 # using `nnx.data` to be included in the pytree.
-#
-# Variable changes in a couple of ways:
-# * its now implements the pytree protocol
-# * it can only hold arrays
-# * it has a mutable attribute, when True it will hold a MutableArray
-# * [...] is used to access & mutate underlying array
 class Linear(nnx.Module):
   def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
     self.din, self.dout = din, dout
@@ -103,7 +97,7 @@ def __call__(
     mean = jnp.mean(x, axis=0)
     var = jnp.var(x, axis=0)
     # ema updates
-    # stop gradient is used until a MutableArray supports updates from grad tracers
+    # stop gradient is used until an ArrayRef supports updates from grad tracers
     sg = jax.lax.stop_gradient
     self.mean[...] = sg(self.mu * self.mean[...] + (1 - self.mu) * mean)
     self.var[...] = sg(self.mu * self.var[...] + (1 - self.mu) * var)
@@ -131,7 +125,7 @@ def __init__(
     use_scan: bool = True,
     rngs: nnx.Rngs,
   ):
-    self.count: nnx.MutableArray = nnx.mutable_array(jnp.array(0))
+    self.count: nnx.ArrayRef = nnx.array_ref(jnp.array(0))
     self.block_in = Block(din, dhidden, rngs=rngs)
     self.linear_out = Linear(dhidden, dout, rngs=rngs)
 
@@ -142,9 +136,9 @@ def __init__(
 
       @jax.vmap
       def create_block(rngs, /):
-        return nnx.freeze(Block(dhidden, dhidden, rngs=rngs))
+        return nnx.to_arrays(Block(dhidden, dhidden, rngs=rngs))
 
-      self.blocks = nnx.mutable(create_block(rngs.fork(split=num_blocks)))
+      self.blocks = nnx.to_refs(create_block(rngs.fork(split=num_blocks)))
     else:
       self.blocks = [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
 
@@ -175,11 +169,11 @@ class OptState(nnx.Variable): ...
 
 
 # Optimizer are an interesting case as they are inherently stateful and
-# pose a good use case for MutableArray. Here we implement SGD with
+# pose a good use case for ArrayRef. Here we implement SGD with
 # momentum. The optimizer receives the params as constructor arguments but doesn't
 # hold a reference to them, it only uses the params to initialize its state
 # by creating new OptState Variables that reuse the param's metadata.
-class SGD(nnx.Object):
+class SGD(nnx.Pytree):
   def __init__(self, params, lr: float, decay: float = 0.9):
     self.lr = lr
     self.decay = decay
@@ -205,29 +199,24 @@ def update(self, params, grads):
     momentum = nnx.pure(self.momentum)
 
     def update_fn(
-      param: nnx.MutableArray, momentum: nnx.MutableArray, grad: jax.Array
+      param: nnx.ArrayRef, momentum: nnx.ArrayRef, grad: jax.Array
     ):
       momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...]
       param[...] -= self.lr * momentum[...]
 
     jax.tree.map(update_fn, params, momentum, grads)
 
 # ## Training
-# To setup the training loop we first instantiate the model and optimizer.
-# Variables are immutable (only contain Arrays) by default as it can make
-# initialization easier, however this means we have to use 'mutable' to
-# create the MutableArrays that will be updated during training.
 
-# activate mutable arrays
-with nnx.use_mutable_arrays(True):
+with nnx.use_refs(True):
   rngs = nnx.Rngs(params=0, dropout=1)
   model = Model(
     num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
   )
   optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)
 
 # Create a copy of the model structure and set its attributes to eval model.
-# This works because they share the underlying MutableArrays so both models
+# This works because they share the underlying ArrayRefs so both models
 # will always be in sync.
 eval_model = nnx.merge(*nnx.split(model))
 eval_model.set_attributes(use_stats=True, deterministic=True)
@@ -249,8 +238,8 @@ def loss_fn(params):
   return loss
 
 # For the time being we have to use 'freeze' make the Variables immutable
-# as 'jax.grad' doesn't support MutableArrays yet.
-grads = jax.grad(loss_fn)(nnx.freeze(params))
+# as 'jax.grad' doesn't support ArrayRefs yet.
+grads = jax.grad(loss_fn)(nnx.to_arrays(params))
 # 'update' mutates the optimizer's state and the params in place
 # so we don't need to return anything 🚀
 optimizer.update(params, grads)
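The core idea behind both the deleted hand-rolled `sgd` in the basic example and the `SGD.update` above is the same: grads are plain arrays, params and momentum are refs, and `jax.tree.map` applies the update leaf by leaf, mutating in place. A standalone sketch with toy stand-in trees (assuming, as the demo does, that refs are treated as pytree leaves):

```python
import jax
import jax.numpy as jnp
from flax import nnx

lr, decay = 3e-3, 0.99

# Toy stand-ins for the demo's param / momentum / grad trees.
params = {'w': nnx.array_ref(jnp.ones((3,)))}
momentum = {'w': nnx.array_ref(jnp.zeros((3,)))}
grads = {'w': jnp.full((3,), 0.5)}

def update_fn(param, mom, grad):
  # Both refs are mutated in place; nothing is returned.
  mom[...] = decay * mom[...] + (1 - decay) * grad
  param[...] -= lr * mom[...]

jax.tree.map(update_fn, params, momentum, grads)
print(params['w'])  # updated in place
```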

0 commit comments