Skip to content

Commit

Permalink
Added upsample_kwargs argument to UNetDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Jan 9, 2024
1 parent ec571e8 commit 504259e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 8 additions & 1 deletion pytorch_toolbelt/modules/decoders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ def __init__(
out_channels: Union[Tuple[int, ...], List[int]],
block_type: Union[Type[UnetBlock], Type[UnetResidualBlock]] = UnetBlock,
upsample_block: Union[UpsampleLayerType, Type[AbstractResizeLayer]] = UpsampleLayerType.BILINEAR,
upsample_kwargs: Union[None, Mapping] = None,
activation: str = ACT_RELU,
normalization: str = NORM_BATCH,
block_kwargs=None,
unet_block=None,
):
if upsample_kwargs is None:
upsample_kwargs = {}

if unet_block is not None:
logger.warning("unet_block argument is deprecated, use block_type instead", DeprecationWarning)
block_type = unet_block
Expand All @@ -59,7 +63,10 @@ def __init__(

scale_factor = input_spec.strides[block_index + 1] // input_spec.strides[block_index]
upsample_layer: AbstractResizeLayer = instantiate_upsample_block(
upsample_block, in_channels=in_channels_for_upsample_block, scale_factor=scale_factor
upsample_block,
in_channels=in_channels_for_upsample_block,
scale_factor=scale_factor,
**upsample_kwargs,
)

upsamples.append(upsample_layer)
Expand Down
4 changes: 2 additions & 2 deletions pytorch_toolbelt/modules/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def forward(self, x: Tensor, output_size: Optional[List[int]]) -> Tensor: # ski


def instantiate_upsample_block(
block: Union[str, Type[AbstractResizeLayer]], in_channels, scale_factor: int
block: Union[str, Type[AbstractResizeLayer]], in_channels, scale_factor: int, **kwargs
) -> AbstractResizeLayer:
if isinstance(block, str):
block = UpsampleLayerType(block)
Expand All @@ -243,4 +243,4 @@ def instantiate_upsample_block(
UpsampleLayerType.RESIDUAL_DECONV: ResidualDeconvolutionUpsample2d,
}[block]

return block(in_channels, scale_factor=scale_factor)
return block(in_channels, scale_factor=scale_factor, **kwargs)

0 comments on commit 504259e

Please sign in to comment.