jax.eval_shape unexpectedly modifies the internal state of an nnx.Module when nnx.vmap and nnx.scan are used in the model definition. Specifically, the parameter values in the module's state change type from jaxlib.xla_extension.ArrayImpl to jax._src.interpreters.partial_eval.DynamicJaxprTracer. This happens even though eval_shape is meant to be free of side effects and only compute shapes.
Code to reproduce:
Colab link
Expected behavior:
jax.eval_shape should not modify the internal state of the nnx.Module. The parameter values should remain of type jaxlib.xla_extension.ArrayImpl after calling eval_shape.
Actual behavior:
After calling eval_shape, the parameter values in the nnx.Module's state have type jax._src.interpreters.partial_eval.DynamicJaxprTracer instead of jaxlib.xla_extension.ArrayImpl.
Environment:
JAX version: 0.4.33
Flax version: 0.10.2
Python version: 3.11.11
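The effect can be sketched without Flax. This is a minimal plain-JAX illustration, not the actual NNX internals: the dict `captured` is a hypothetical stand-in for a module's state that the traced function closes over and writes to. Writing to a captured Python object inside a function traced by jax.eval_shape leaves a tracer behind, which mirrors the ArrayImpl to DynamicJaxprTracer change reported above. Passing the state as an explicit argument instead keeps the original arrays untouched.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a module's mutable state: a plain dict
# captured by closure (not how nnx.Module stores its variables).
captured = {"w": jnp.ones((3,))}

def impure_forward(x):
    # Writing back into a captured object during tracing stores a tracer in it.
    captured["w"] = captured["w"] * 2.0
    return x + captured["w"]

def pure_forward(params, x):
    # Same computation, but state enters as an explicit argument and is never written back.
    return x + params["w"] * 2.0

x = jnp.zeros((3,))

# Pure version: eval_shape returns abstract shapes and leaves `params` concrete.
params = {"w": jnp.ones((3,))}
pure_out = jax.eval_shape(pure_forward, params, x)
print(pure_out.shape, pure_out.dtype)
print(type(params["w"]).__name__)

# Impure version: after eval_shape, the captured dict now holds a tracer.
impure_out = jax.eval_shape(impure_forward, x)
print(type(captured["w"]).__name__)
```

Any later use of `captured["w"]` outside the trace would be using a leaked tracer, which is why such mutations are fragile.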
To clarify: mutating Modules that are passed in as captures (closed over by the traced function rather than passed as arguments) will become an error again in the future. It was briefly made valid so the JAX team could carry out a refactor.