13 changes: 7 additions & 6 deletions flax/nnx/nn/linear.py
@@ -246,7 +246,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding = None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
 
     Args:
@@ -288,7 +288,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
@@ -393,7 +393,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding = None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
@@ -413,6 +413,7 @@ def __call__(self, inputs: Array) -> Array:
     # preferred_element_type argument to avoid breaking
     # existing code
     dot_general_kwargs = {}
+    dot_general_kwargs['out_sharding'] = out_sharding

Review comment on lines 415 to +416 (Contributor): Can we update this to be done in one line?
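For reference, a one-line form matching the LinearGeneral and Einsum call sites elsewhere in this diff would be:

    dot_general_kwargs = {'out_sharding': out_sharding}
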
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
@@ -521,7 +522,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type
 
   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding = None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
@@ -557,7 +558,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type
 
@@ -1141,7 +1142,7 @@ def maybe_broadcast(
       rhs_dilation=kernel_dilation,
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
-      preferred_element_type=self.preferred_element_type,
+      preferred_element_type=self.preferred_element_type
     )
 
     if self.padding == 'CIRCULAR':
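
Taken together, the linear.py changes add an out_sharding argument to LinearGeneral.__call__, Linear.__call__ and Einsum.__call__ and forward it to the underlying dot_general / einsum op. A minimal usage sketch, mirroring the new test added to spmd_test.py below (the mesh shape, layer sizes and input values here are illustrative):

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P, AxisType, reshard

# A small explicit-axes mesh; assumes at least 4 devices are visible.
mesh = jax.make_mesh((2, 2), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(mesh):
  layer = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
  x = reshard(jnp.arange(4, dtype=jnp.float32).reshape(2, 2), P("X", None))

  y_default = layer(x)                             # float32[2@X,4]: batch dim stays on X
  y_explicit = layer(x, out_sharding=P("X", "Y"))  # float32[2@X,4@Y]: features also split over Y
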
4 changes: 2 additions & 2 deletions flax/nnx/nn/lora.py
@@ -209,7 +209,7 @@ def __init__(
       b_metadata=b_metadata,
     )
 
-  def __call__(self, x: jax.Array):
-    y = super().__call__(x)
+  def __call__(self, x: jax.Array, out_sharding = None):
+    y = super().__call__(x, out_sharding=out_sharding)
     y += self.lora(x)
     return y
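
In lora.py the new keyword is simply forwarded to the wrapped Linear call, so the requested sharding applies to the base matmul while the low-rank update self.lora(x) is added on top as before. A hypothetical call site, assuming lora_layer is an nnx.LoRALinear instance and x is sharded as in the sketch above:

  y = lora_layer(x, out_sharding=P("X", "Y"))  # base Linear output gets the requested sharding
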
12 changes: 11 additions & 1 deletion tests/nnx/spmd_test.py
@@ -20,7 +20,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec as P, NamedSharding
+from jax.sharding import PartitionSpec as P, NamedSharding, AxisType, reshard
 import optax
 
 
@@ -211,6 +211,16 @@ def test_eager_sharding_context(self, use_eager_sharding):
     else:
       assert not has_sharding_spec(w)
 
+  def test_out_sharding(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"),
+                         axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(4).reshape(2, 2)
+      sharded_array = reshard(replicated_array, P("X", None))
+      l = nnx.Linear(2,4, rngs=nnx.Rngs(0))

Review comment (Collaborator): nit: the variable name l is hard to read.

+      assert 'float32[2@X,4]' in str(jax.typeof(l(sharded_array)))
+      assert 'float32[2@X,4@Y]' in str(jax.typeof(l(sharded_array, out_sharding=P("X", "Y"))))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.use_hijax(use_hijax))
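
One environment assumption behind the new test (not stated in the diff): the (2, 2) mesh needs at least four devices, so on a CPU-only machine the host platform can be told to expose fake devices before JAX is first imported, e.g.:

import os
# Must run before jax is imported; gives the test's (2, 2) mesh four CPU devices to map onto.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import jax
assert jax.device_count() >= 4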