
nnx.cached_partial errors on jax Arrays not wrapped in nnx.Variable #5109

@am001122

The Pytree docs say:

Arrays, Variables, ArrayRefs, and nnx.Pytrees are data.

which seems to imply that assigning raw Arrays as attributes of an nnx.Pytree is supported. For most nnx operations this works fine, but nnx.cached_partial fails unless every Array is wrapped in an nnx.Variable:

from flax import nnx
import jax.numpy as jnp


class Mod(nnx.Module):
    def __init__(self):
        self.a = jnp.arange(3)

    def __call__(self, x):
        return self.a * x


m = Mod()
print(m(jnp.ones(3)))

m_jit = nnx.jit(m)
print(m_jit(2 * jnp.ones(3)))


@nnx.jit
def f(mod: Mod, x):
    return mod(x) + 1


print(f(m, 3 * jnp.ones(3)))

f_partial = nnx.cached_partial(f, m)
print(f_partial(4 * jnp.ones(3)))

This outputs:

[0. 1. 2.]
[0. 2. 4.]
[1. 4. 7.]
...
File python3.13/site-packages/flax/nnx/graph.py:1796, in SplitContext.flatten(self, node, with_paths, *filters)
   1793   else:
   1794     paths = None
   1795     leaves = [
-> 1796       variable.get_raw_value() for variable in node_static_cache.variables
   1797     ]
   1798 else:
   1799   graphdef, flat_state = flatten(
   1800     node,
   1801     ref_index=self.ref_index,
   1802     ref_outer_index=ref_outer_index,
   1803     with_paths=with_paths,
   1804   )

AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'get_raw_value'
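Reading the traceback, the cached fast path in SplitContext.flatten appears to assume that every leaf recorded in node_static_cache.variables is an nnx.Variable, and unwraps it with get_raw_value(). A raw jax.Array has no such method, which would explain the AttributeError. The sketch below is a hypothetical, heavily simplified pure-Python model of that assumption, not Flax code: ToyVariable, cached_flatten and tolerant_flatten are illustrative names and are not Flax APIs. It shows why wrapped leaves succeed, why raw leaves fail, and what a more tolerant unwrap could look like.

```python
# Hypothetical, simplified model of the cached fast path suggested by the
# traceback above. None of these names exist in Flax.


class ToyVariable:
    """Stand-in for nnx.Variable: exposes its payload via get_raw_value()."""

    def __init__(self, value):
        self._value = value

    def get_raw_value(self):
        return self._value


def cached_flatten(cached_leaves):
    # Mirrors the list comprehension in the traceback: every cached leaf is
    # assumed to be Variable-like, so a raw array-like leaf raises AttributeError.
    return [leaf.get_raw_value() for leaf in cached_leaves]


def tolerant_flatten(cached_leaves):
    # One possible defensive variant (an assumption, not Flax's behavior):
    # unwrap Variable-like leaves and pass any other leaf through unchanged.
    return [
        leaf.get_raw_value() if hasattr(leaf, "get_raw_value") else leaf
        for leaf in cached_leaves
    ]


raw_leaf = [0, 1, 2]                  # plays the role of jnp.arange(3)
wrapped_leaf = ToyVariable([0, 1, 2])  # plays the role of nnx.Variable(jnp.arange(3))

print(cached_flatten([wrapped_leaf]))  # [[0, 1, 2]]
try:
    cached_flatten([raw_leaf])
except AttributeError as err:
    print("AttributeError:", err)      # 'list' object has no attribute 'get_raw_value'
print(tolerant_flatten([raw_leaf, wrapped_leaf]))  # [[0, 1, 2], [0, 1, 2]]
```

In this model, wrapping the leaf (the workaround described below) satisfies the cache's assumption, while a raw leaf, or a registered data type such as a struct.dataclass instance, would not.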

If I wrap the array in an nnx.Variable, everything works:

class Mod(nnx.Module):
    def __init__(self):
        self.a = nnx.Variable(jnp.arange(3))
# ...

and the same script prints:

[0. 1. 2.]
[0. 2. 4.]
[1. 4. 7.]
[1. 5. 9.]

Is this expected behavior? Is it bad practice, or unsupported, to keep data in a Pytree that is not wrapped in an nnx.Variable? (In practice, I ran into this with a @flax.struct.dataclass that had Array data fields. I registered the whole dataclass with nnx.register_data_type and assigned instances of it as attributes of some nnx.Modules, and got the same error.)
