-
Notifications
You must be signed in to change notification settings - Fork 763
Open
Description
Hey folks, it seems that flax.nnx mishandles an empty mesh / empty logical axis rules, which leads to a very confusing error raised deep inside JAX internals that is hard to read. I'm not 100% sure what a meaningful error message should look like in this case, but it definitely shouldn't look like this.
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import flax.linen as nn
from jax.interpreters import pxla
def get_global_mesh():
    """Return the physical mesh of the ambient JAX resource environment.

    Reads ``pxla.thread_resources.env.physical_mesh``. In the run shown below
    there is no enclosing ``with mesh:`` context, and this prints as ``Mesh()``.
    """
    return pxla.thread_resources.env.physical_mesh
def get_global_logical_rules():
    """Return the logical-axis rules currently installed in ``flax.linen``.

    Thin wrapper over ``nn.get_logical_axis_rules()``. In the run shown below
    no rules were installed, so this returned ``()``.
    """
    active_rules = nn.get_logical_axis_rules()
    return active_rules
class ModuleTest(nnx.Module):
    """Minimal reproduction: one partitioned variable built under the ambient mesh/rules.

    In the run shown below the ambient mesh prints as ``Mesh()`` and the
    logical rules as ``()``. Constructing this module then fails inside
    ``flax.core.spmd.shard_value`` with
    ``AttributeError: 'NoneType' object has no attribute 'id'``.
    """

    def __init__(self, dim, rngs):
        # Snapshot the ambient sharding configuration and echo it.
        mesh = get_global_mesh()
        logical_rules = get_global_logical_rules()
        print(f"Mesh: {mesh}")
        print(f"Logical Rules: {logical_rules}")

        with mesh, nn.logical_axis_rules(rules=logical_rules):
            # Zero initializer whose single dimension is annotated ``None``.
            # The mesh and rules are also passed to it explicitly.
            zeros_init = nnx.with_partitioning(
                nnx.initializers.zeros,
                (None,),
                mesh=mesh,
                sharding_rules=logical_rules,
            )
            initial_value = zeros_init(rngs.params(), (dim,), jnp.float32)
            # Per the traceback, nnx.Variable.__init__ calls
            # core_spmd.shard_value here, which is where the error is raised.
            self.e_score_correction_bias = nnx.Variable(initial_value)
# Build the module. This is the call that raises in the traceback below.
# NOTE(review): the "moe_router" RNG stream is never read by ModuleTest.
module = ModuleTest(12, nnx.Rngs(params=0, moe_router=0))
# Output:
> python3 test_cm.py
Mesh: Mesh()
Logical Rules: ()
Traceback (most recent call last):
File "/papyrax/test_cm.py", line 29, in <module>
module = ModuleTest(12, nnx.Rngs(params=0, moe_router=0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 400, in __call__
return _graph_node_meta_call(cls, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
cls._pytree_meta_construct(node, *args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
self.__init__(*args, **kwargs)
File "/papyrax/test_cm.py", line 26, in __init__
self.e_score_correction_bias = nnx.Variable(init_fn(rngs.params(), (dim,), jnp.float32))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 904, in __call__
return cls._variable_meta_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 907, in _variable_meta_call
variable = super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 1108, in __init__
value = core_spmd.shard_value(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 49, in shard_value
return _apply_sharding(value, NamedSharding(mesh, pspec))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 37, in _apply_sharding
return jax.jit(lambda x: x, out_shardings=sharding)(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/numpy/lib/_function_base_impl.py", line 2470, in __call__
return self._call_as_normal(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/numpy/lib/_function_base_impl.py", line 2463, in _call_as_normal
return self._vectorize_call(func=func, args=vargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/numpy/lib/_function_base_impl.py", line 2553, in _vectorize_call
outputs = ufunc(*inputs)
^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'id'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Versions:
Name: flax
Version: 0.12.1
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /usr/local/lib/python3.11/dist-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: transformer_engine_jax
---
Name: jax
Version: 0.8.1
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: chex, flax, jax-triton, jaxpp, optax, orbax-checkpoint, rax, transformer_engine_jax
vfdev-5
Metadata
Metadata
Assignees
Labels
No labels