NNX incorrectly handles empty mesh/logical axis rules #5124

@qGentry

Hey folks, it seems that flax.nnx mishandles an empty mesh and/or empty logical axis rules. Instead of a clear message, it fails with a confusing error deep in JAX internals that is very hard to read. I'm not 100% sure what a meaningful output should look like in this case, but definitely not like this.

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import flax.linen as nn
from jax.interpreters import pxla


def get_global_mesh():
    mesh_env = pxla.thread_resources.env
    return mesh_env.physical_mesh


def get_global_logical_rules():
    return nn.get_logical_axis_rules()


class ModuleTest(nnx.Module):
    def __init__(self, dim, rngs):
        mesh = get_global_mesh()
        logical_rules = get_global_logical_rules()
        print(f"Mesh: {mesh}")
        print(f"Logical Rules: {logical_rules}")
        with mesh, nn.logical_axis_rules(rules=logical_rules):
            init_fn = nnx.with_partitioning(nnx.initializers.zeros, (None,), mesh=mesh, sharding_rules=logical_rules)
            self.e_score_correction_bias = nnx.Variable(init_fn(rngs.params(), (dim,), jnp.float32))


module = ModuleTest(12, nnx.Rngs(params=0, moe_router=0))

Output:

> python3 test_cm.py 
Mesh: Mesh()
Logical Rules: ()
Traceback (most recent call last):
  File "/papyrax/test_cm.py", line 29, in <module>
    module = ModuleTest(12, nnx.Rngs(params=0, moe_router=0))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 400, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
    self.__init__(*args, **kwargs)
  File "/papyrax/test_cm.py", line 26, in __init__
    self.e_score_correction_bias = nnx.Variable(init_fn(rngs.params(), (dim,), jnp.float32))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 904, in __call__
    return cls._variable_meta_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 907, in _variable_meta_call
    variable = super().__call__(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 1108, in __init__
    value = core_spmd.shard_value(
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 49, in shard_value
    return _apply_sharding(value, NamedSharding(mesh, pspec))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 37, in _apply_sharding
    return jax.jit(lambda x: x, out_shardings=sharding)(value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/lib/_function_base_impl.py", line 2470, in __call__
    return self._call_as_normal(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/lib/_function_base_impl.py", line 2463, in _call_as_normal
    return self._vectorize_call(func=func, args=vargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/lib/_function_base_impl.py", line 2553, in _vectorize_call
    outputs = ufunc(*inputs)
              ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'id'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
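
For illustration, a caller-side check along these lines would at least turn the failure into a readable error. This is only a sketch: `check_mesh` is a hypothetical helper, not part of flax or JAX, and it assumes the mesh object exposes an `empty` attribute the way `jax.sharding.Mesh` does. The stand-in meshes below are plain objects used only to exercise the helper without any devices.

```python
from types import SimpleNamespace


def check_mesh(mesh):
    """Return `mesh` unchanged, or raise a readable error if it is unusable.

    Hypothetical helper (not a flax or JAX API). Assumes `mesh` has an
    `empty` attribute, as jax.sharding.Mesh does.
    """
    if mesh is None or getattr(mesh, "empty", False):
        raise ValueError(
            "No active device mesh: enter a mesh context or pass a "
            "non-empty mesh before creating sharded variables."
        )
    return mesh


# Stand-ins for real meshes, used only to demonstrate the check.
empty_mesh = SimpleNamespace(empty=True)
non_empty_mesh = SimpleNamespace(empty=False)

check_mesh(non_empty_mesh)  # passes through unchanged

try:
    check_mesh(empty_mesh)
except ValueError as err:
    print(f"Readable failure: {err}")
```

In the reproduction above, such a check could run on the result of `get_global_mesh()` before `nnx.with_partitioning` is called. Whether flax itself should raise, warn, or silently skip sharding in this case is the open question of this issue.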

Versions:

Name: flax
Version: 0.12.1
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: 
Author: 
Author-email: Flax team <[email protected]>
License: 
Location: /usr/local/lib/python3.11/dist-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: transformer_engine_jax
---
Name: jax
Version: 0.8.1
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: chex, flax, jax-triton, jaxpp, optax, orbax-checkpoint, rax, transformer_engine_jax
