diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 4f3592d54..467932bee 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -25,6 +25,7 @@ from flax.typing import ( Array, ConvGeneralDilatedT, + ConvTransposeT, DotGeneralT, Dtype, Initializer, @@ -948,6 +949,7 @@ class ConvTranspose(Module): kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() transpose_kernel: bool = False + conv_transpose: ConvTransposeT = lax.conv_transpose promote_dtype: PromoteDtypeFn = promote_dtype preferred_element_type: Dtype | None = None @@ -1037,7 +1039,7 @@ def maybe_broadcast( assert inputs is not None assert kernel is not None - y = lax.conv_transpose( + y = self.conv_transpose( inputs, kernel, strides, diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index 869710e3d..595cc89ef 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -34,6 +34,7 @@ PrecisionLike, DotGeneralT, ConvGeneralDilatedT, + ConvTransposeT, PaddingLike, LaxPadding, PromoteDtypeFn, @@ -969,6 +970,7 @@ def __init__( kernel_init: Initializer = default_kernel_init, bias_init: Initializer = default_bias_init, transpose_kernel: bool = False, + conv_transpose: ConvTransposeT = lax.conv_transpose, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, preferred_element_type: Dtype | None = None, rngs: rnglib.Rngs, @@ -992,6 +994,7 @@ def __init__( self.kernel_init = kernel_init self.bias_init = bias_init self.transpose_kernel = transpose_kernel + self.conv_transpose = conv_transpose self.promote_dtype = promote_dtype self.preferred_element_type = preferred_element_type @@ -1081,7 +1084,7 @@ def maybe_broadcast( (inputs, kernel, bias), dtype=self.dtype ) - y = lax.conv_transpose( + y = self.conv_transpose( inputs, kernel, strides, diff --git a/flax/typing.py b/flax/typing.py index 350be2e36..7e717446c 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -68,6 +68,7 @@ def is_key_like(x: Any) -> TypeGuard[Key]: ] DotGeneralT = Callable[..., Array] ConvGeneralDilatedT = Callable[..., Array] +ConvTransposeT = Callable[..., Array] EinsumT = Callable[..., Array] PaddingLike = Union[str, int, Sequence[Union[int, tuple[int, int]]]]