From 504259e1a984540dba31b642a0eb2fd0bf38c9f8 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Tue, 9 Jan 2024 16:46:13 +0200 Subject: [PATCH] Added upsample_kwargs argument to UNetDecoder --- pytorch_toolbelt/modules/decoders/unet.py | 9 ++++++++- pytorch_toolbelt/modules/upsample.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index 2b7d3e809..e9d7f36d4 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -28,11 +28,15 @@ def __init__( out_channels: Union[Tuple[int, ...], List[int]], block_type: Union[Type[UnetBlock], Type[UnetResidualBlock]] = UnetBlock, upsample_block: Union[UpsampleLayerType, Type[AbstractResizeLayer]] = UpsampleLayerType.BILINEAR, + upsample_kwargs: Union[None, Mapping] = None, activation: str = ACT_RELU, normalization: str = NORM_BATCH, block_kwargs=None, unet_block=None, ): + if upsample_kwargs is None: + upsample_kwargs = {} + if unet_block is not None: logger.warning("unet_block argument is deprecated, use block_type instead", DeprecationWarning) block_type = unet_block @@ -59,7 +63,10 @@ def __init__( scale_factor = input_spec.strides[block_index + 1] // input_spec.strides[block_index] upsample_layer: AbstractResizeLayer = instantiate_upsample_block( - upsample_block, in_channels=in_channels_for_upsample_block, scale_factor=scale_factor + upsample_block, + in_channels=in_channels_for_upsample_block, + scale_factor=scale_factor, + **upsample_kwargs, ) upsamples.append(upsample_layer) diff --git a/pytorch_toolbelt/modules/upsample.py b/pytorch_toolbelt/modules/upsample.py index 9774b12eb..f75c94733 100644 --- a/pytorch_toolbelt/modules/upsample.py +++ b/pytorch_toolbelt/modules/upsample.py @@ -228,7 +228,7 @@ def forward(self, x: Tensor, output_size: Optional[List[int]]) -> Tensor: # ski def instantiate_upsample_block( - block: Union[str, Type[AbstractResizeLayer]], in_channels, scale_factor: int + block: Union[str, Type[AbstractResizeLayer]], in_channels, scale_factor: int, **kwargs ) -> AbstractResizeLayer: if isinstance(block, str): block = UpsampleLayerType(block) @@ -243,4 +243,4 @@ def instantiate_upsample_block( UpsampleLayerType.RESIDUAL_DECONV: ResidualDeconvolutionUpsample2d, }[block] - return block(in_channels, scale_factor=scale_factor) + return block(in_channels, scale_factor=scale_factor, **kwargs)