The Pytree docs say:

> Arrays, Variables, ArrayRefs, and nnx.Pytrees are data.

This seems to imply that assigning raw Arrays as attributes of `nnx.Pytree`s is supported. For most nnx operations it works fine, but `nnx.cached_partial` fails unless every Array is wrapped in an `nnx.Variable`:
```python
from flax import nnx
import jax.numpy as jnp


class Mod(nnx.Module):
    def __init__(self):
        self.a = jnp.arange(3)

    def __call__(self, x):
        return self.a * x


m = Mod()
print(m(jnp.ones(3)))

m_jit = nnx.jit(m)
print(m_jit(2 * jnp.ones(3)))


@nnx.jit
def f(mod: Mod, x):
    return mod(x) + 1


print(f(m, 3 * jnp.ones(3)))

f_partial = nnx.cached_partial(f, m)
print(f_partial(4 * jnp.ones(3)))
```

This outputs:
```
[0. 1. 2.]
[0. 2. 4.]
[1. 4. 7.]
...
File python3.13/site-packages/flax/nnx/graph.py:1796, in SplitContext.flatten(self, node, with_paths, *filters)
1793 else:
1794 paths = None
1795 leaves = [
-> 1796 variable.get_raw_value() for variable in node_static_cache.variables
1797 ]
1798 else:
1799 graphdef, flat_state = flatten(
1800 node,
1801 ref_index=self.ref_index,
1802 ref_outer_index=ref_outer_index,
1803 with_paths=with_paths,
1804 )
AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'get_raw_value'
```
If I wrap it in an `nnx.Variable`, everything works fine:

```python
class Mod(nnx.Module):
    def __init__(self):
        self.a = nnx.Variable(jnp.arange(3))

    # ...
```
gives

```
[0. 1. 2.]
[0. 2. 4.]
[1. 4. 7.]
[1. 5. 9.]
```
Is this expected behavior? Is it bad practice, or unsupported, to have any data in a Pytree that is not wrapped in `nnx.Variable`? (In the real world, I ran into this with a `@flax.struct.dataclass` that had Array data fields. I registered the whole dataclass with `nnx.register_data_type` and assigned it as an attribute on some `nnx.Module`s, and got the same error.)
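To make the traceback concrete: the failing line iterates over `node_static_cache.variables` and calls `get_raw_value()` on every entry, so the cached fast path appears to assume each cached leaf is a Variable. The sketch below models only that assumption, using hypothetical stand-ins (`FakeVariable`, `FakeArray`, and `cached_leaves` are illustrative names, not Flax internals). It is a guess at the failure mode read off the traceback, not a description of how Flax actually builds its cache.

```python
# Hypothetical stand-ins, NOT Flax's real classes. They only mimic the
# shape of the list comprehension shown at flax/nnx/graph.py:1796.


class FakeVariable:
    """Stand-in for a Variable: wraps a value and exposes get_raw_value()."""

    def __init__(self, value):
        self._value = value

    def get_raw_value(self):
        return self._value


class FakeArray:
    """Stand-in for a raw array: it has no get_raw_value() method."""

    def __init__(self, data):
        self.data = list(data)


def cached_leaves(cached_variables):
    # Assumes every cached entry is Variable-like, as the traceback suggests.
    return [variable.get_raw_value() for variable in cached_variables]


# Wrapped values go through the cached path without error.
print(cached_leaves([FakeVariable([0, 1, 2])]))  # [[0, 1, 2]]

# A raw array reaching the same comprehension raises AttributeError.
try:
    cached_leaves([FakeArray([0, 1, 2])])
except AttributeError as err:
    print("AttributeError:", err)
```

Under this reading, wrapping in `nnx.Variable` works because every cached entry then has `get_raw_value()`, while a raw Array (or, presumably, a registered data type holding Arrays) reaches the comprehension without it.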