Commit 1f61d6b

Add out_sharding argument to call methods for standard layers
1 parent 5109e2c commit 1f61d6b

3 files changed: +20 −9 lines changed

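In practice, the new argument lets a caller request an output sharding at call time when running under an explicit-axes mesh. A minimal usage sketch, mirroring the test added below (axis names and shapes are illustrative, and at least four devices are assumed):

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P, AxisType, reshard

# Illustrative 2x2 mesh with explicit axis types.
mesh = jax.make_mesh((2, 2), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(mesh):
  x = reshard(jnp.arange(4.0).reshape(2, 2), P("X", None))  # rows sharded over X
  layer = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
  y = layer(x, out_sharding=P("X", "Y"))  # also shard the output columns over Y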

flax/nnx/nn/linear.py

Lines changed: 7 additions & 6 deletions
@@ -246,7 +246,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding = None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
 
     Args:
@@ -288,7 +288,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
@@ -393,7 +393,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding = None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
@@ -413,6 +413,7 @@ def __call__(self, inputs: Array) -> Array:
     # preferred_element_type argument to avoid breaking
     # existing code
     dot_general_kwargs = {}
+    dot_general_kwargs['out_sharding'] = out_sharding
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
@@ -521,7 +522,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type
 
   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding = None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
@@ -557,7 +558,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type
 
@@ -1141,7 +1142,7 @@ def maybe_broadcast(
       rhs_dilation=kernel_dilation,
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
-      preferred_element_type=self.preferred_element_type,
+      preferred_element_type=self.preferred_element_type
     )
 
     if self.padding == 'CIRCULAR':
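The kwargs dict in these hunks is built dynamically so that user-supplied dot_general implementations without the newer keyword arguments keep working. A rough standalone sketch of that forwarding pattern (not the Flax code verbatim, and assuming a JAX version whose jax.lax.dot_general accepts out_sharding under explicit sharding, as this commit relies on):

import jax

def _apply_kernel(inputs, kernel, out_sharding=None, dot_general=jax.lax.dot_general):
  # Only forward the optional keyword when the caller actually requested it.
  kwargs = {}
  if out_sharding is not None:
    kwargs['out_sharding'] = out_sharding
  # Contract the last axis of `inputs` with the first axis of `kernel`.
  return dot_general(
    inputs, kernel,
    (((inputs.ndim - 1,), (0,)), ((), ())),
    **kwargs,
  )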

flax/nnx/nn/lora.py

Lines changed: 2 additions & 2 deletions
@@ -209,7 +209,7 @@ def __init__(
       b_metadata=b_metadata,
     )
 
-  def __call__(self, x: jax.Array):
-    y = super().__call__(x)
+  def __call__(self, x: jax.Array, out_sharding = None):
+    y = super().__call__(x, out_sharding=out_sharding)
     y += self.lora(x)
     return y
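With this change, LoRALinear forwards out_sharding to the underlying Linear call, while the low-rank correction self.lora(x) is added on top unchanged. A hypothetical usage sketch (the constructor arguments in_features, out_features, and lora_rank are assumed, not part of this diff; only the out_sharding call argument comes from this commit):

from flax import nnx
from jax.sharding import PartitionSpec as P

# Assumed constructor signature for illustration only.
lora_layer = nnx.LoRALinear(in_features=2, out_features=4, lora_rank=1, rngs=nnx.Rngs(0))
y = lora_layer(x, out_sharding=P("X", "Y"))  # x: an array living under an explicit mesh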

tests/nnx/spmd_test.py

Lines changed: 11 additions & 1 deletion
@@ -20,7 +20,7 @@
 from flax import nnx
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec as P, NamedSharding
+from jax.sharding import PartitionSpec as P, NamedSharding, AxisType, reshard
 import optax
 
 
@@ -211,6 +211,16 @@ def test_eager_sharding_context(self, use_eager_sharding):
     else:
       assert not has_sharding_spec(w)
 
+  def test_out_sharding(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"),
+                         axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(4).reshape(2, 2)
+      sharded_array = reshard(replicated_array, P("X", None))
+      l = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
+      assert 'float32[2@X,4]' in str(jax.typeof(l(sharded_array)))
+      assert 'float32[2@X,4@Y]' in str(jax.typeof(l(sharded_array, out_sharding=P("X", "Y"))))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.use_hijax(use_hijax))
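The assertions rely on jax.typeof rendering explicitly sharded axes with an @<axis> suffix: without out_sharding the output keeps only the row sharding inherited from the input (float32[2@X,4]), while out_sharding=P("X", "Y") additionally splits the output columns over Y (float32[2@X,4@Y]). A small illustration under the same mesh and variables as the test above:

y_default = l(sharded_array)                             # float32[2@X,4]
y_explicit = l(sharded_array, out_sharding=P("X", "Y"))  # float32[2@X,4@Y]
print(jax.typeof(y_default), jax.typeof(y_explicit))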
