@@ -8,55 +8,55 @@ jupytext:
88 jupytext_version : 1.13.8
99---
1010
11- # Mutable Arrays (experimental)
11+ # Array Refs (experimental)
1212
1313``` {code-cell} ipython3
14+ import jax.experimental
1415from flax import nnx
1516import jax
1617import jax.numpy as jnp
17- import jax.experimental
1818import optax
1919```
2020
2121## Basics
2222
2323+++
2424
25- ### Mutable Arrays 101
25+ ### Array Refs 101
2626
2727``` {code-cell} ipython3
28- m_array = jax.experimental.mutable_array (jnp.array([1, 2, 3]))
28+ a_ref = nnx.array_ref (jnp.array([1, 2, 3]))
2929
3030@jax.jit
31- def increment(m_array: jax.experimental.MutableArray ): # no return!
32- array: jax.Array = m_array [...] # access
33- m_array [...] = array + 1 # update
31+ def increment(a_ref: nnx.ArrayRef ): # no return!
32+ array: jax.Array = a_ref [...] # access
33+ a_ref [...] = array + 1 # update
3434
35- print("[1] =", m_array ); increment(m_array ); print("[2] =", m_array )
35+ print("[1] =", a_ref ); increment(a_ref ); print("[2] =", a_ref )
3636```
3737
3838``` {code-cell} ipython3
3939@jax.jit
4040def inc(x):
4141 x[...] += 1
4242
43- print(increment.lower(m_array ).as_text())
43+ print(increment.lower(a_ref ).as_text())
4444```
4545
46- ### Mutable Variables
46+ ### Variables Refs
4747
4848``` {code-cell} ipython3
49- variable = nnx.Variable(jnp.array([1, 2, 3]), mutable =True)
50- print(f"{variable.mutable = }\n")
49+ variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref =True)
50+ print(f"{variable.has_ref = }\n")
5151
5252print("[1] =", variable); increment(variable); print("[2] =", variable)
5353```
5454
5555``` {code-cell} ipython3
56- with nnx.use_mutable_arrays (True):
56+ with nnx.use_refs (True):
5757 variable = nnx.Variable(jnp.array([1, 2, 3]))
5858
59- print(f"{variable.mutable = }")
59+ print(f"{variable.has_ref = }")
6060```
6161
6262### Changing Status
@@ -70,12 +70,12 @@ class Linear(nnx.Module):
7070 def __call__(self, x):
7171 return x @ self.kernel + self.bias[None]
7272
73- model = Linear(1, 3, rngs=nnx.Rngs(0)) # without mutable arrays
74- mutable_model = nnx.mutable (model) # convert to mutable arrays
75- frozen_model = nnx.freeze(mutable_model ) # freeze mutable arrays again
73+ model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs
74+ refs_model = nnx.to_refs (model) # convert to array refs
75+ arrays_model = nnx.to_arrays(refs_model ) # convert to regular arrays
7676
77- print("nnx.mutable (model) =", mutable_model )
78- print("nnx.freeze(mutable_model ) =", frozen_model )
77+ print("nnx.to_refs (model) =", refs_model )
78+ print("nnx.to_arrays(refs_model ) =", arrays_model )
7979```
8080
8181## Examples
@@ -96,7 +96,7 @@ class Block(nnx.Module):
9696### Training Loop
9797
9898``` {code-cell} ipython3
99- with nnx.use_mutable_arrays (True):
99+ with nnx.use_refs (True):
100100 model = Block(2, 64, 3, rngs=nnx.Rngs(0))
101101 optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
102102
@@ -107,7 +107,7 @@ def train_step(model, optimizer, x, y):
107107 model = nnx.merge(graphdef, params, nondiff)
108108 return ((model(x) - y) ** 2).mean()
109109
110- loss, grads = jax.value_and_grad(loss_fn)(nnx.freeze (params)) # freeze MutableArrays for jax.grad
110+ loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays (params)) # freeze ArrayRefs for jax.grad
111111 optimizer.update(model, grads)
112112
113113 return loss
@@ -122,7 +122,7 @@ train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
122122def create_stack(rngs):
123123 return Block(2, 64, 2, rngs=rngs)
124124
125- with nnx.use_mutable_arrays (True):
125+ with nnx.use_refs (True):
126126 block_stack = create_stack(nnx.Rngs(0).fork(split=8))
127127
128128def scan_fn(x, block):
@@ -147,17 +147,17 @@ def create_model(rngs):
147147 return Block(2, 64, 3, rngs=rngs)
148148
149149try:
150- with nnx.use_mutable_arrays (True):
150+ with nnx.use_refs (True):
151151 model = create_model(nnx.Rngs(0))
152152except Exception as e:
153153 print(f"Error:", e)
154154```
155155
156156``` {code-cell} ipython3
157- with nnx.use_mutable_arrays (False): # <-- disable mutable arrays
157+ with nnx.use_refs (False): # <-- disable array refs
158158 model = create_model(nnx.Rngs(0))
159159
160- model = nnx.mutable (model) # convert to mutable after creation
160+ model = nnx.to_refs (model) # convert to mutable after creation
161161
162162print("model.linear =", model.linear)
163163```
@@ -167,7 +167,7 @@ print("model.linear =", model.linear)
167167def create_model(rngs):
168168 return Block(2, 64, 3, rngs=rngs)
169169
170- with nnx.use_mutable_arrays (True):
170+ with nnx.use_refs (True):
171171 model = create_model(nnx.Rngs(0))
172172
173173print("model.linear =", model.linear)
@@ -182,7 +182,7 @@ def get_error(f, *args):
182182 except Exception as e:
183183 return f"{type(e).__name__}: {e}"
184184
185- x = jax.experimental.mutable_array (jnp.array(0))
185+ x = nnx.array_ref (jnp.array(0))
186186
187187@jax.jit
188188def f(a, b):
@@ -192,12 +192,12 @@ print(get_error(f, x, x))
192192```
193193
194194``` {code-cell} ipython3
195- class SharedVariables(nnx.Object ):
195+ class SharedVariables(nnx.Pytree ):
196196 def __init__(self):
197197 self.a = nnx.Variable(jnp.array(0))
198198 self.b = self.a
199199
200- class SharedModules(nnx.Object ):
200+ class SharedModules(nnx.Pytree ):
201201 def __init__(self):
202202 self.a = Linear(1, 1, rngs=nnx.Rngs(0))
203203 self.b = self.a
@@ -206,7 +206,7 @@ class SharedModules(nnx.Object):
206206def g(pytree):
207207 ...
208208
209- with nnx.use_mutable_arrays (True):
209+ with nnx.use_refs (True):
210210 shared_variables = SharedVariables()
211211 shared_modules = SharedModules()
212212
0 commit comments