@@ -246,7 +246,7 @@ def bias_init_wrap(rng, shape, dtype):
     else:
       self.bias = nnx.data(None)
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding=None) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
 
     Args:
@@ -288,7 +288,7 @@ def __call__(self, inputs: Array) -> Array:
     # user custom dot_general/dot_general_cls which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    dot_general_kwargs = {}
+    dot_general_kwargs = {'out_sharding': out_sharding}
    if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     out = dot_general(
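For context, a minimal sketch of what the forwarded kwarg means at the JAX level, assuming a recent JAX where lax.dot_general accepts out_sharding under explicit-sharding mode; the mesh, shapes, and names below are illustrative and not part of this patch:

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P

# Illustrative 1-axis mesh in explicit-sharding mode (assumes >= 2 devices).
mesh = jax.make_mesh((2,), ('model',), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
  x = jnp.ones((8, 16))
  w = jnp.ones((16, 32))
  # Contract x's dim 1 with w's dim 0; ask for the output's last axis
  # to be sharded over 'model' rather than letting the compiler choose.
  y = jax.lax.dot_general(
      x, w, (((1,), (0,)), ((), ())),
      out_sharding=P(None, 'model'),
  )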
@@ -393,7 +393,7 @@ def __init__(
     self.promote_dtype = promote_dtype
     self.preferred_element_type = preferred_element_type
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, out_sharding=None) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
@@ -413,6 +413,7 @@ def __call__(self, inputs: Array) -> Array:
     # preferred_element_type argument to avoid breaking
     # existing code
     dot_general_kwargs = {}
+    dot_general_kwargs['out_sharding'] = out_sharding
    if self.preferred_element_type is not None:
       dot_general_kwargs["preferred_element_type"] = self.preferred_element_type
     y = self.dot_general(
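A hypothetical caller-side view of the same change on Linear, assuming this patch plus a recent JAX with explicit sharding; mesh setup and sizes are again illustrative:

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P
from flax import nnx

mesh = jax.make_mesh((2,), ('model',), axis_types=(AxisType.Explicit,))
layer = nnx.Linear(16, 32, rngs=nnx.Rngs(0))

with jax.sharding.use_mesh(mesh):
  # The new kwarg is forwarded to the underlying dot_general, requesting
  # a column-sharded output.
  y = layer(jnp.ones((8, 16)), out_sharding=P(None, 'model'))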
@@ -521,7 +522,7 @@ def __init__(
     self.preferred_element_type = preferred_element_type
 
   def __call__(
-    self, inputs: Array, einsum_str: tp.Optional[str] = None
+    self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding=None
   ) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
@@ -557,7 +558,7 @@ def __call__(
     # user custom self.einsum_op method which may not have
     # preferred_element_type argument to avoid breaking
     # existing code
-    einsum_op_kwargs = {}
+    einsum_op_kwargs = {'out_sharding': out_sharding}
    if self.preferred_element_type is not None:
       einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type
 
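And the analogous hypothetical call on Einsum, where the kwarg is forwarded to the einsum op (same assumptions and illustrative mesh as above):

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P
from flax import nnx

mesh = jax.make_mesh((2,), ('model',), axis_types=(AxisType.Explicit,))
layer = nnx.Einsum('bd,df->bf', (16, 32), rngs=nnx.Rngs(0))

with jax.sharding.use_mesh(mesh):
  y = layer(jnp.ones((8, 16)), out_sharding=P(None, 'model'))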
@@ -1141,7 +1142,7 @@ def maybe_broadcast(
       rhs_dilation=kernel_dilation,
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
-      preferred_element_type=self.preferred_element_type,
+      preferred_element_type=self.preferred_element_type
     )
 
     if self.padding == 'CIRCULAR':