Hi folks, me again.
I keep playing around with nnx, and it seems like nnx.Optimizer, when creating optimizer state for models intended to be used with scan, uses the sharding information of the original, non-stacked tensor, without taking into account the extra dimension added by vmap.
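Boiled down, the failure (judging from the traceback below, where flax/core/spmd.py re-applies the stored spec via jax.jit(..., out_shardings=...)) amounts to this minimal sketch, assuming the same 8-device (2, 4) mesh as the repro:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

mesh = jax.make_mesh((2, 4), ("a", "b"))

# The optimizer state keeps the Param's original rank-2 spec ('a', 'b'),
# but the vmap-stacked value has an extra leading axis of size 1. JAX then
# tries to split that size-1 leading axis across the size-2 mesh axis 'a'.
x = jnp.ones((1, 16, 16))
jax.jit(lambda v: v, out_shardings=NamedSharding(mesh, PartitionSpec("a", "b")))(x)
# ValueError: ... dimension 0 should be divisible by 2, but it is equal to 1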
Repro script:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))


class Model(nnx.Module):
    def __init__(self, num_layers, rngs: nnx.Rngs):
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_linear(rngs: nnx.Rngs):
            return nnx.Param(
                jnp.ones((16, 16)),
                sharding=("A", "B"),
                mesh=mesh1,
                sharding_rules=rules1,
            )

        self.linears = create_linear(rngs=rngs)


@nnx.jit
def init():
    model = Model(num_layers=1, rngs=nnx.Rngs(params=0))
    optimizer = nnx.Optimizer(
        model,
        optax.adam(learning_rate=0.001),
        wrt=nnx.Param,
    )
    return model, optimizer


model, optimizer = init()
Output:
Traceback (most recent call last):
File "/papyrax/test_scan_axis.py", line 37, in <module>
model, optimizer = init()
^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 474, in __call__
pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 135, in __call__
out = self.f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/papyrax/test_scan_axis.py", line 30, in init
optimizer = nnx.Optimizer(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 400, in __call__
return _graph_node_meta_call(cls, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
cls._pytree_meta_construct(node, *args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
self.__init__(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 88, in _check_wrt_wrapper
return f(*args, wrt=wrt, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 160, in __init__
to_opt_state(tx.init(nnx.state(model, wrt)))
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 57, in to_opt_state
tree = jax.tree.map(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/tree.py", line 155, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 52, in _to_opt_state
opt_state = OptVariable(x.get_value(), **x.get_metadata()) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 904, in __call__
return cls._variable_meta_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 907, in _variable_meta_call
variable = super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 1108, in __init__
value = core_spmd.shard_value(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 49, in shard_value
return _apply_sharding(value, NamedSharding(mesh, pspec))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 37, in _apply_sharding
return jax.jit(lambda x: x, out_shardings=sharding)(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 16, 16))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
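From the traceback it looks like to_opt_state re-creates each OptVariable from the stacked value plus the Param's original metadata, so the stored spec ("A", "B") gets re-applied to the now rank-3 value. A possible workaround (unverified sketch; it assumes nnx.vmap's transform_metadata with nnx.PARTITION_NAME inserts the new axis into the Param's sharding metadata, the way the scan-over-layers examples do) is to declare the stacked axis on the transform:

@nnx.split_rngs(splits=num_layers)
@nnx.vmap(
    in_axes=(0,),
    out_axes=0,
    # Hypothetical fix: record the unnamed stacked axis in the sharding
    # metadata so the stored spec becomes (None, "A", "B") and matches the
    # rank-3 stacked value of shape (1, 16, 16).
    transform_metadata={nnx.PARTITION_NAME: None},
)
def create_linear(rngs: nnx.Rngs):
    return nnx.Param(
        jnp.ones((16, 16)),
        sharding=("A", "B"),
        mesh=mesh1,
        sharding_rules=rules1,
    )

Even if that works, I'd expect nnx.Optimizer to handle a stacked Param without the caller having to annotate the extra axis.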