
Commit 96f0a64

Add out_sharding argument to call methods for standard layers
1 parent 5109e2c commit 96f0a64

3 files changed: +23 additions, -10 deletions

flax/nnx/nn/linear.py

Lines changed: 10 additions & 7 deletions
@@ -29,6 +29,7 @@
 from flax.nnx.nn import dtypes, initializers
 from flax.typing import (
   Dtype,
+  Optional,
   Shape,
   Initializer,
   PrecisionLike,
@@ -38,6 +39,7 @@
   LaxPadding,
   PromoteDtypeFn,
   EinsumT,
+  Sharding
 )
 
 Array = jax.Array
@@ -246,7 +248,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Optional[Sharding] = None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
 
     Args:
@@ -288,7 +290,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
@@ -393,7 +395,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Optional[Sharding] = None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
@@ -412,7 +414,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom self.dot_general method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
@@ -521,7 +523,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type
 
   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding: Optional[Sharding] = None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
@@ -557,7 +559,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type
 
@@ -1065,7 +1067,7 @@ def __init__(
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Optional[Sharding] = None) -> Array:
     """Applies a transposed convolution to the inputs.
 
     Behaviour mirrors of ``jax.lax.conv_transpose``.
@@ -1142,6 +1144,7 @@ def maybe_broadcast(
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
       preferred_element_type=self.preferred_element_type,
+      out_sharding=out_sharding,
     )
 
     if self.padding == 'CIRCULAR':
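
Taken together, the linear.py changes thread the new keyword straight through to the underlying dot_general, einsum, or conv_transpose call, so a caller running under an explicitly sharded mesh can request the output layout per call. A minimal usage sketch, not part of the commit, assuming a JAX release with explicit sharding support (AxisType.Explicit, jax.sharding.reshard, jax.typeof) and a Flax build that includes this change; it mirrors the test added further down:

  # Sketch only: request the output layout of nnx.Linear at call time.
  import jax
  import jax.numpy as jnp
  from flax import nnx
  from jax.sharding import PartitionSpec as P, AxisType, reshard

  mesh = jax.make_mesh((2, 2), ("X", "Y"),
                       axis_types=(AxisType.Explicit, AxisType.Explicit))
  with jax.set_mesh(mesh):
    layer = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
    x = reshard(jnp.ones((2, 2)), P("X", None))     # batch dimension sharded over "X"
    # Without out_sharding, the output keeps the layout inferred from the inputs.
    y_default = layer(x)                            # float32[2@X,4]
    # With out_sharding, the requested PartitionSpec is forwarded to dot_general.
    y_sharded = layer(x, out_sharding=P("X", "Y"))  # float32[2@X,4@Y]
    print(jax.typeof(y_default), jax.typeof(y_sharded))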

flax/typing.py

Lines changed: 2 additions & 2 deletions
@@ -132,7 +132,7 @@ class Out(Generic[T]):
 LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name
 PartitionSpecPytree = Any # pylint: disable=invalid-name
 
-Sharding = tuple[AxisName, ...]
+Sharding = Union[tuple[AxisName, ...], jax.sharding.PartitionSpec, jax.sharding.Sharding]
 
 A = TypeVar('A')
 
@@ -233,4 +233,4 @@ def from_any(cls, x):
 class PromoteDtypeFn(Protocol):
   def __call__(
     self, args: TupleArg, /, *, dtype: Any = None, inexact: bool = True
-  ) -> TupleArg: ...
+  ) -> TupleArg: ...
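
Widening the alias keeps the legacy axis-name tuples used for partitioning metadata valid while also admitting PartitionSpecs and concrete jax.sharding.Sharding objects, which is what the new out_sharding parameters are annotated with. A small illustration, not part of the commit, of the three forms the alias now accepts (the names and the "data" axis are made up for the example):

  # Illustration only: values that now satisfy the flax.typing.Sharding annotation.
  import jax
  from jax.sharding import NamedSharding, PartitionSpec
  from flax.typing import Sharding

  tuple_form: Sharding = ("data", None)                  # legacy tuple[AxisName, ...]
  spec_form: Sharding = PartitionSpec("data", None)      # jax.sharding.PartitionSpec
  mesh = jax.make_mesh((jax.device_count(),), ("data",))
  named_form: Sharding = NamedSharding(mesh, spec_form)  # a concrete jax.sharding.Sharding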

tests/nnx/spmd_test.py

Lines changed: 11 additions & 1 deletion
@@ -20,7 +20,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec as P, NamedSharding
+from jax.sharding import PartitionSpec as P, NamedSharding, AxisType, reshard
 import optax
 
 
@@ -211,6 +211,16 @@ def test_eager_sharding_context(self, use_eager_sharding):
     else:
       assert not has_sharding_spec(w)
 
+  def test_out_sharding(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"),
+                         axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(4).reshape(2, 2)
+      sharded_array = reshard(replicated_array, P("X", None))
+      l = nnx.Linear(2,4, rngs=nnx.Rngs(0))
+      assert 'float32[2@X,4]' in str(jax.typeof(l(sharded_array)))
+      assert 'float32[2@X,4@Y]' in str(jax.typeof(l(sharded_array, out_sharding=P("X", "Y"))))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.use_hijax(use_hijax))
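
Note that the (2, 2) mesh in test_out_sharding needs at least four devices. A hedged sketch of one way to get them on a CPU-only machine; the XLA flag is standard, but the exact device count and where you set it are up to the reader:

  # Sketch only: fake four host CPU devices so jax.make_mesh((2, 2), ...) succeeds.
  # This must run before jax initializes its backend (e.g. in a conftest.py or
  # exported in the shell environment before invoking pytest).
  import os
  os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
  )

  import jax
  assert jax.device_count() >= 4  # the test's mesh can now be built

With the fake devices in place, running the spmd test file with pytest and -k test_out_sharding exercises the new argument end to end.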
