@@ -29,6 +29,7 @@
 from flax.nnx.nn import dtypes, initializers
 from flax.typing import (
   Dtype,
+  Optional,
   Shape,
   Initializer,
   PrecisionLike,
@@ -38,6 +39,7 @@
   LaxPadding,
   PromoteDtypeFn,
   EinsumT,
+  Sharding
 )
 
 Array = jax.Array
@@ -246,7 +248,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Optional[Sharding] = None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
 
     Args:
@@ -288,7 +290,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
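A side note on this hunk (an observation, not part of the commit): the kwargs dict was originally built conditionally so that user-supplied dot_general callables lacking a preferred_element_type parameter keep working; seeding it with 'out_sharding' unconditionally means such callables must now also accept an out_sharding keyword. A minimal compatibility sketch, using the hypothetical name my_dot_general and assuming a JAX version whose jax.lax.dot_general supports out_sharding:

import jax

# Hypothetical custom dot_general that stays compatible with the new
# keyword by accepting it and forwarding it to jax.lax.dot_general.
def my_dot_general(lhs, rhs, dimension_numbers, precision=None,
                   preferred_element_type=None, out_sharding=None):
  return jax.lax.dot_general(
    lhs, rhs, dimension_numbers,
    precision=precision,
    preferred_element_type=preferred_element_type,
    out_sharding=out_sharding,
  )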
@@ -393,7 +395,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Optional[Sharding] = None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
@@ -412,7 +414,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom self.dot_general method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
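For orientation, a minimal usage sketch of the new parameter on nnx.Linear (not from the commit; it assumes a JAX version where lax.dot_general accepts out_sharding, and depending on the version the mesh may need explicit sharding enabled):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import nnx

# 1-D mesh over whatever devices are available.
mesh = Mesh(jax.devices(), ('model',))

layer = nnx.Linear(4, 8, rngs=nnx.Rngs(0))
x = jnp.ones((16, 4))

# out_sharding=None (the default) preserves the previous behavior;
# passing a sharding constrains how the output activation is laid out.
y = layer(x, out_sharding=NamedSharding(mesh, P(None, 'model')))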
@@ -521,7 +523,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type
 
   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding: Optional[Sharding] = None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
@@ -557,7 +559,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type
 
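The same pattern applies to Einsum: its einsum_op (jnp.einsum by default) now always receives out_sharding. A minimal usage sketch, with illustrative shapes, assuming a JAX version whose jnp.einsum accepts out_sharding:

from flax import nnx
import jax.numpy as jnp

# 'btd,df->btf': contract the feature axis d against a (4, 8) kernel.
layer = nnx.Einsum('btd,df->btf', (4, 8), rngs=nnx.Rngs(0))
x = jnp.ones((2, 3, 4))
y = layer(x)                                # (2, 3, 8); out_sharding defaults to None
# y = layer(x, out_sharding=some_sharding)  # optionally pin the output layout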
@@ -1065,7 +1067,7 @@ def __init__(
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Optional[Sharding] = None) -> Array:
     """Applies a transposed convolution to the inputs.
 
     Behaviour mirrors of ``jax.lax.conv_transpose``.
@@ -1142,6 +1144,7 @@ def maybe_broadcast(
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
       preferred_element_type=self.preferred_element_type,
+      out_sharding=out_sharding,
     )
 
     if self.padding == 'CIRCULAR':
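Unlike the dot_general and einsum paths, ConvTranspose forwards out_sharding straight to jax.lax.conv_transpose with no kwargs guard, so this hunk assumes a JAX version whose conv_transpose accepts that keyword. A minimal usage sketch (not from the commit):

from flax import nnx
import jax.numpy as jnp

# 1-D transposed convolution: 3 input features -> 8 output features.
layer = nnx.ConvTranspose(3, 8, kernel_size=(3,), rngs=nnx.Rngs(0))
x = jnp.ones((1, 16, 3))
y = layer(x)  # out_sharding defaults to None, so existing callers see no change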