nnx.Optimizer doesn't respect extra sharding axis added by nnx.scan/nnx.vmap #5112

Description

@qGentry

Hi folks, me again.

I keep playing around with nnx, and it seems like nnx.Optimizer, when creating optimizer state for models intended to be used with scan, uses the sharding information of the original, non-stacked tensor, without taking into account the extra dimension added by vmap.
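
The mismatch can be distilled to plain JAX (a sketch, assuming 8 devices for the same (2, 4) mesh as below): a 2-axis PartitionSpec that fits the original (16, 16) param no longer fits once vmap prepends a layer axis, because the spec's first entry now pins the new axis 0 to mesh axis "a".

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

mesh = jax.make_mesh((2, 4), ("a", "b"))

# The original (16, 16) param shards fine: dim 0 (16) is divisible by
# mesh axis "a" (size 2).
jax.device_put(jnp.ones((16, 16)), NamedSharding(mesh, PartitionSpec("a", "b")))

# The vmap-stacked value of shape (1, 16, 16) fails under the same spec:
# the new layer axis (size 1) is now what gets pinned to mesh axis "a".
jax.device_put(jnp.ones((1, 16, 16)), NamedSharding(mesh, PartitionSpec("a", "b")))
# ValueError: ... dimension 0 should be divisible by 2, but it is equal to 1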

Repro script:

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))


class Model(nnx.Module):
    def __init__(self, num_layers, rngs: nnx.Rngs):
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_linear(rngs: nnx.Rngs):
            return nnx.Param(
                jnp.ones((16, 16)), 
                sharding=("A", "B"), 
                mesh=mesh1,
                sharding_rules=rules1,
            )
        self.linears = create_linear(rngs=rngs)


@nnx.jit
def init():
    model = Model(num_layers=1, rngs=nnx.Rngs(params=0))
    optimizer = nnx.Optimizer(
        model,
        optax.adam(learning_rate=0.001),
        wrt=nnx.Param,
    )
    return model, optimizer

model, optimizer = init()

Output:

Traceback (most recent call last):
  File "/papyrax/test_scan_axis.py", line 37, in <module>
    model, optimizer = init()
                       ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 474, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
                                               ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 135, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/papyrax/test_scan_axis.py", line 30, in init
    optimizer = nnx.Optimizer(
                ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 400, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
    self.__init__(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 88, in _check_wrt_wrapper
    return f(*args, wrt=wrt, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 160, in __init__
    to_opt_state(tx.init(nnx.state(model, wrt)))
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 57, in to_opt_state
    tree = jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 52, in _to_opt_state
    opt_state = OptVariable(x.get_value(), **x.get_metadata())  # type: ignore
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 904, in __call__
    return cls._variable_meta_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 907, in _variable_meta_call
    variable = super().__call__(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 1108, in __init__
    value = core_spmd.shard_value(
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 49, in shard_value
    return _apply_sharding(value, NamedSharding(mesh, pspec))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 37, in _apply_sharding
    return jax.jit(lambda x: x, out_shardings=sharding)(value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 16, 16))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
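
A possible workaround (untested sketch, based on my reading of the nnx transforms docs): nnx.vmap and nnx.scan accept transform_metadata with nnx.PARTITION_NAME, which makes the transform insert an axis name into each Param's sharding metadata at the stacked axis, so the metadata rank matches the stacked value. The "layers" axis name and the extra sharding rule below are mine, not from the repro:

rules1 = (("A", "a"), ("B", "b"), ("layers", None))  # "layers" -> replicated

@nnx.split_rngs(splits=num_layers)
@nnx.vmap(
    in_axes=(0,),
    out_axes=0,
    # Record the stacked layer axis in the sharding metadata so it becomes
    # ("layers", "A", "B") instead of the rank-2 ("A", "B").
    transform_metadata={nnx.PARTITION_NAME: "layers"},
)
def create_linear(rngs: nnx.Rngs):
    ...

Even if that sidesteps the ValueError, nnx.Optimizer should arguably account for the stacked axis on its own when building the optimizer state.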
