jax.eval_shape modifies nnx.Module state when using nnx.vmap and nnx.scan #4520

Open
@ytsmiling

Description

jax.eval_shape unexpectedly modifies the internal state of an nnx.Module when the model definition uses nnx.vmap and nnx.scan. Specifically, the parameter values of the submodule created with nnx.vmap and iterated over with nnx.scan change type from jaxlib.xla_extension.ArrayImpl to jax._src.interpreters.partial_eval.DynamicJaxprTracer. This happens even though eval_shape is meant to be free of side effects and to compute only output shapes and dtypes.

Code to reproduce:

import jax
import jax.numpy as jnp
from flax import nnx

class Model(nnx.Module):
  def __init__(self, rngs):
    self.stem = nnx.Linear(16, 32, rngs=rngs)
    @nnx.split_rngs(splits=3)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return nnx.Linear(32, 32, rngs=rngs)
    
    self.backbone = create_block(rngs)
    self.head = nnx.Linear(32, 10, rngs=rngs)

  def __call__(self, x):
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x: jax.Array, model: nnx.Module):
      return model(x)

    return self.head(forward(self.stem(x), self.backbone))

model = Model(nnx.Rngs(0))
print(jax.tree.map(lambda x: str(type(x)), nnx.split(model)[1]))  # Initial state
_ = jax.eval_shape(model, jax.ShapeDtypeStruct(shape=(8, 16), dtype=jnp.float32))
print(jax.tree.map(lambda x: str(type(x)), nnx.split(model)[1]))  # State after eval_shape

Output:

State({
  'backbone': {
    'bias': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    ),
    'kernel': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    )
  },
  'head': {
    'bias': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    ),
    'kernel': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    )
  },
  'stem': {
    'bias': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    ),
    'kernel': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    )
  }
})
State({
  'backbone': {
    'bias': VariableState(
      type=Param,
      value="<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>"
    ),
    'kernel': VariableState(
      type=Param,
      value="<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>"
    )
  },
  'head': {
    'bias': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    ),
    'kernel': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    )
  },
  'stem': {
    'bias': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    ),
    'kernel': VariableState(
      type=Param,
      value="<class 'jaxlib.xla_extension.ArrayImpl'>"
    )
  }
})

Expected behavior:

jax.eval_shape should not modify the internal state of the nnx.Module. The types of the parameter values should remain jaxlib.xla_extension.ArrayImpl after calling eval_shape.

Actual behavior:

After calling eval_shape, the parameter values of backbone (the submodule created with nnx.vmap and run through nnx.scan) are jax._src.interpreters.partial_eval.DynamicJaxprTracer instances. The parameters of stem and head are unchanged.
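
The same kind of leak can be illustrated in plain JAX, without Flax. The sketch below is only an analogy for the behavior reported here, not a trace of what nnx.scan does internally. A function writes a traced value into external state while jax.eval_shape traces it, and a small helper then reports which leaves are tracers. The names find_tracers, leaky_forward, and the state dict are hypothetical.

```python
import jax
import jax.numpy as jnp


def find_tracers(tree):
  """Return the key paths of all leaves in `tree` that are JAX tracers."""
  leaves = jax.tree_util.tree_leaves_with_path(tree)
  return [
      jax.tree_util.keystr(path)
      for path, leaf in leaves
      if isinstance(leaf, jax.core.Tracer)
  ]


# Hypothetical stand-in for module state: a plain dict of concrete arrays.
state = {"kernel": jnp.ones((16, 4)), "bias": jnp.zeros((4,))}


def leaky_forward(x):
  y = x @ state["kernel"] + state["bias"]
  # Side effect: a value derived from the traced input is written back
  # into the external state and outlives the trace.
  state["bias"] = y.mean(axis=0)
  return y


before = find_tracers(state)
out = jax.eval_shape(
    leaky_forward, jax.ShapeDtypeStruct(shape=(8, 16), dtype=jnp.float32)
)
after = find_tracers(state)
```

Here before is empty, while after contains the path of the bias entry. A helper like this could also be pointed at nnx.split(model)[1] to check a module for leaked tracers.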

Environment:

  • JAX version: 0.4.33
  • Flax version: 0.10.2
  • Python version: 3.11.11
