   LaxPadding,
   PromoteDtypeFn,
   EinsumT,
+  Sharding
 )

 Array = jax.Array
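Judging by its neighbours, the new ``Sharding`` name comes from the same ``flax.typing`` import block. A minimal sketch of constructing a sharding value to pass as ``out_sharding`` (the mesh layout and axis names are illustrative, not part of this change):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all local devices; 'model' is an arbitrary name.
mesh = Mesh(np.array(jax.devices()), axis_names=('model',))
# Replicate the batch axis, split the feature axis across the mesh.
out_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))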
@@ -246,7 +247,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Sharding = None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.

     Args:
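A hedged usage sketch for the new ``LinearGeneral`` parameter (feature sizes, mesh, and axis names are illustrative; running it assumes a JAX version whose dot_general accepts ``out_sharding``):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from flax import nnx

mesh = Mesh(np.array(jax.devices()), axis_names=('model',))
spec = NamedSharding(mesh, PartitionSpec(None, 'model'))

layer = nnx.LinearGeneral(in_features=128, out_features=256, rngs=nnx.Rngs(0))
x = jnp.ones((8, 128))
y = layer(x, out_sharding=spec)  # placement request forwarded to dot_general
y_default = layer(x)             # out_sharding=None, behaviour unchanged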
@@ -288,7 +289,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
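A standalone sketch of the forwarding pattern this hunk sets up; ``apply_dot`` is a hypothetical helper, and the ``out_sharding`` keyword on ``lax.dot_general`` assumes a sufficiently recent JAX:

import jax.numpy as jnp
from jax import lax

def apply_dot(lhs, rhs, *, preferred_element_type=None, out_sharding=None):
  # Collect optional keywords in a dict and splat them into the call,
  # mirroring how dot_general_kwargs is built above.
  kwargs = {'out_sharding': out_sharding}
  if preferred_element_type is not None:
    kwargs['preferred_element_type'] = preferred_element_type
  dims = (((lhs.ndim - 1,), (0,)), ((), ()))  # plain matmul contraction
  return lax.dot_general(lhs, rhs, dims, **kwargs)

y = apply_dot(jnp.ones((8, 128)), jnp.ones((128, 256)))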
@@ -393,7 +394,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Sharding = None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.

     Args:
@@ -412,7 +413,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom self.dot_general method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
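One consequence of this hunk: ``'out_sharding'`` is now always present in the kwargs, so a user-supplied ``dot_general`` must be able to accept it. A hedged sketch of a compatible custom implementation (the name is illustrative):

from jax import lax
from flax import nnx

def my_dot_general(lhs, rhs, dimension_numbers, out_sharding=None, **kwargs):
  # Accept out_sharding so the unconditional forwarding does not raise a
  # TypeError; it is dropped here, but on a new-enough JAX it could be
  # passed straight through to lax.dot_general.
  return lax.dot_general(lhs, rhs, dimension_numbers, **kwargs)

layer = nnx.Linear(128, 256, dot_general=my_dot_general, rngs=nnx.Rngs(0))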
@@ -521,7 +522,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type

   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding: Sharding = None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.

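A hedged usage sketch for ``Einsum``; the einsum string and shapes follow the nnx.Einsum docstring example, the sharding value is illustrative:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from flax import nnx

mesh = Mesh(np.array(jax.devices()), axis_names=('model',))
spec = NamedSharding(mesh, PartitionSpec())  # fully replicated, for example

layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
x = jnp.ones((16, 11, 2))
y = layer(x, out_sharding=spec)  # output shape (16, 11, 8, 4)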
@@ -557,7 +558,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
     if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type

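For reference, a sketch of what these kwargs feed into when the default ``einsum_op`` (``jnp.einsum``) is used. Recent JAX versions accept ``out_sharding`` on ``jnp.einsum``, typically under a mesh with explicit axis types; older versions raise a TypeError. The axis names here are hypothetical:

import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

x = jnp.ones((16, 11, 2))
kernel = jnp.ones((8, 2, 4))
# 'data' and 'model' must be axis names of the active mesh.
y = jnp.einsum('nta,hab->nthb', x, kernel,
               out_sharding=P('data', None, None, 'model'))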
@@ -1065,7 +1066,7 @@ def __init__(
     else:
       self.bias = nnx.data(None)

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding: Sharding = None) -> Array:
     """Applies a transposed convolution to the inputs.

     Behaviour mirrors of ``jax.lax.conv_transpose``.
@@ -1142,6 +1143,7 @@ def maybe_broadcast(
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
       preferred_element_type=self.preferred_element_type,
+      out_sharding=out_sharding,
     )

     if self.padding == 'CIRCULAR':
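Finally, a hedged usage sketch for ``ConvTranspose``; whether the underlying ``jax.lax.conv_transpose`` honours ``out_sharding`` depends on the JAX version this change targets, and the shapes and mesh are illustrative:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from flax import nnx

mesh = Mesh(np.array(jax.devices()), axis_names=('model',))
spec = NamedSharding(mesh, PartitionSpec(None, None, 'model'))

layer = nnx.ConvTranspose(in_features=8, out_features=16, kernel_size=(3,),
                          rngs=nnx.Rngs(0))
x = jnp.ones((4, 32, 8))         # (batch, length, features)
y = layer(x, out_sharding=spec)  # features axis placement requested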