
Commit 1195087

Add out_sharding argument to call methods for standard layers
1 parent 5109e2c commit 1195087
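
For context, a minimal usage sketch of the new argument (not part of the commit itself; it assumes a recent JAX with explicit-sharding support and mirrors the mesh and shapes used in the new test below):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, reshard
from flax import nnx

# 2x2 mesh with explicitly-typed axes, as in tests/nnx/spmd_test.py below.
mesh = jax.make_mesh((2, 2), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(mesh):
  x = reshard(jnp.arange(4).reshape(2, 2), P("X", None))
  layer = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
  # Without out_sharding the output layout is inferred from the inputs ...
  print(jax.typeof(layer(x)))                            # float32[2@X,4]
  # ... with out_sharding the requested spec is forwarded to dot_general.
  print(jax.typeof(layer(x, out_sharding=P("X", "Y"))))  # float32[2@X,4@Y]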

File tree

flax/nnx/nn/linear.py
tests/nnx/spmd_test.py

2 files changed: +20 -8 lines

flax/nnx/nn/linear.py

Lines changed: 9 additions & 7 deletions
@@ -38,6 +38,7 @@
   LaxPadding,
   PromoteDtypeFn,
   EinsumT,
+  Sharding
 )

 Array = jax.Array
@@ -246,7 +247,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Sharding = None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.

     Args:
@@ -288,7 +289,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
@@ -393,7 +394,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Sharding = None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.

     Args:
@@ -412,7 +413,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom self.dot_general method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
@@ -521,7 +522,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type

   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding: Sharding = None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.

@@ -557,7 +558,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type

@@ -1065,7 +1066,7 @@ def __init__(
     else:
       self.bias = nnx.data(None)

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Sharding = None) -> Array:
     """Applies a transposed convolution to the inputs.

     Behaviour mirrors of ``jax.lax.conv_transpose``.
@@ -1142,6 +1143,7 @@ def maybe_broadcast(
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
       preferred_element_type=self.preferred_element_type,
+      out_sharding=out_sharding,
     )

     if self.padding == 'CIRCULAR':
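
Each changed call site simply adds the new keyword to the kwargs dict that is already threaded through (dot_general_kwargs, einsum_op_kwargs) or passes it directly to jax.lax.conv_transpose, relying on the out_sharding parameter that recent JAX versions accept on these primitives. A rough standalone sketch of that underlying mechanism, with illustrative mesh names and shapes (not taken from the commit):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType

mesh = jax.make_mesh((2, 2), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(mesh):
  x = jnp.ones((8, 4))
  w = jnp.ones((4, 16))
  # Contract x's last axis with w's first axis and request an output
  # laid out over both mesh axes.
  y = jax.lax.dot_general(
      x, w, (((1,), (0,)), ((), ())), out_sharding=P("X", "Y"))
  print(jax.typeof(y))  # expected: float32[8@X,16@Y]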

tests/nnx/spmd_test.py

Lines changed: 11 additions & 1 deletion
@@ -20,7 +20,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec as P, NamedSharding
+from jax.sharding import PartitionSpec as P, NamedSharding, AxisType, reshard
 import optax


@@ -211,6 +211,16 @@ def test_eager_sharding_context(self, use_eager_sharding):
     else:
       assert not has_sharding_spec(w)

+  def test_out_sharding(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"),
+                         axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(4).reshape(2, 2)
+      sharded_array = reshard(replicated_array, P("X", None))
+      l = nnx.Linear(2,4, rngs=nnx.Rngs(0))
+      assert 'float32[2@X,4]' in str(jax.typeof(l(sharded_array)))
+      assert 'float32[2@X,4@Y]' in str(jax.typeof(l(sharded_array, out_sharding=P("X", "Y"))))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.use_hijax(use_hijax))
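
The new test only exercises nnx.Linear, but the diff adds the same keyword to LinearGeneral, Einsum, and ConvTranspose as well. A hypothetical sketch of the Einsum path (the einsum string, kernel shape, and expected aval here are illustrative; nothing in this commit asserts them):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, reshard
from flax import nnx

mesh = jax.make_mesh((2, 2), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(mesh):
  x = reshard(jnp.ones((2, 2)), P("X", None))
  layer = nnx.Einsum('bd,df->bf', (2, 4), rngs=nnx.Rngs(0))
  # out_sharding is forwarded to jnp.einsum (the default einsum_op).
  print(jax.typeof(layer(x, out_sharding=P("X", "Y"))))  # expected: float32[2@X,4@Y]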
