From 1f61d6b77ea0a9bbe0b87713446cbf2bf8580207 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Wed, 19 Nov 2025 12:49:32 -0600
Subject: [PATCH] Add out_sharding argument to call methods for standard layers

---
 flax/nnx/nn/linear.py  | 13 +++++++------
 flax/nnx/nn/lora.py    |  4 ++--
 tests/nnx/spmd_test.py | 12 +++++++++++-
 3 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index 6be554441..4f78f8399 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -246,7 +246,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding=None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
 
     Args:
@@ -288,7 +288,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
@@ -393,7 +393,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding=None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
@@ -413,6 +413,7 @@ def __call__(self, inputs: Array) -> Array:
     # preferred_element_type argument to avoid breaking
     # existing code
     dot_general_kwargs = {}
+    dot_general_kwargs['out_sharding'] = out_sharding
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
@@ -521,7 +522,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type
 
   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding=None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
@@ -557,7 +558,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type
 
@@ -1141,7 +1142,7 @@ def maybe_broadcast(
       rhs_dilation=kernel_dilation,
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
-      preferred_element_type=self.preferred_element_type,
+      preferred_element_type=self.preferred_element_type
     )
 
     if self.padding == 'CIRCULAR':
diff --git a/flax/nnx/nn/lora.py b/flax/nnx/nn/lora.py
index 16a584604..d347b80d1 100644
--- a/flax/nnx/nn/lora.py
+++ b/flax/nnx/nn/lora.py
@@ -209,7 +209,7 @@ def __init__(
       b_metadata=b_metadata,
     )
 
-  def __call__(self, x: jax.Array):
-    y = super().__call__(x)
+  def __call__(self, x: jax.Array, out_sharding=None):
+    y = super().__call__(x, out_sharding=out_sharding)
     y += self.lora(x)
     return y
diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py
index b13334e0f..f3360d0a2 100644
--- a/tests/nnx/spmd_test.py
+++ b/tests/nnx/spmd_test.py
@@ -20,7 +20,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec as P, NamedSharding
+from jax.sharding import PartitionSpec as P, NamedSharding, AxisType, reshard
 import optax
 
 
@@ -211,6 +211,16 @@ def test_eager_sharding_context(self, use_eager_sharding):
     else:
       assert not has_sharding_spec(w)
 
+  def test_out_sharding(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"),
+                         axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(4).reshape(2, 2)
+      sharded_array = reshard(replicated_array, P("X", None))
+      l = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
+      assert 'float32[2@X,4]' in str(jax.typeof(l(sharded_array)))
+      assert 'float32[2@X,4@Y]' in str(jax.typeof(l(sharded_array, out_sharding=P("X", "Y"))))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.use_hijax(use_hijax))
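
Usage sketch (not part of the patch): under an explicit-axis mesh, the new
out_sharding argument forwards a PartitionSpec to the underlying dot_general
or einsum call, so the layer output comes back with the requested sharding.
This mirrors the test added above; the mesh shape, layer sizes, and variable
names below are illustrative only, and the JAX APIs used (jax.make_mesh,
jax.set_mesh, jax.sharding.reshard, jax.typeof) are the same ones exercised
by that test.

  import jax
  import jax.numpy as jnp
  from jax.sharding import AxisType, PartitionSpec as P, reshard
  from flax import nnx

  mesh = jax.make_mesh((2, 2), ("X", "Y"),
                       axis_types=(AxisType.Explicit, AxisType.Explicit))
  with jax.set_mesh(mesh):
    x = reshard(jnp.ones((4, 8)), P("X", None))   # batch dim sharded over X
    layer = nnx.Linear(8, 16, rngs=nnx.Rngs(0))
    y = layer(x)                                  # float32[4@X,16]
    y2 = layer(x, out_sharding=P("X", "Y"))       # float32[4@X,16@Y]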