Skip to content

Commit

Permalink
Allow UNet decoder to use multiple blocks per stage
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Jan 12, 2024
1 parent 7014381 commit 37f517a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
40 changes: 33 additions & 7 deletions pytorch_toolbelt/modules/decoders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,31 @@ def __init__(
normalization: str = NORM_BATCH,
block_kwargs=None,
unet_block=None,
num_blocks_per_stage: Union[None, int, Tuple[int, ...]] = None,
):
num_stages = len(input_spec) - 1 # Number of outputs is one less than encoder layers

if upsample_kwargs is None:
upsample_kwargs = {}

if unet_block is not None:
logger.warning("unet_block argument is deprecated, use block_type instead", DeprecationWarning)
block_type = unet_block

if num_blocks_per_stage is None:
num_blocks_per_stage = 1

if isinstance(num_blocks_per_stage, int):
num_blocks_per_stage = (num_blocks_per_stage,) * num_stages

num_blocks_per_stage = tuple(num_blocks_per_stage)

if len(num_blocks_per_stage) != num_stages:
raise ValueError(f"num_blocks_per_stage must have length of {num_stages}")

if len(out_channels) != num_stages:
raise ValueError(f"decoder_features must have length of {num_stages}")

super().__init__(input_spec)
if block_kwargs is None:
block_kwargs = {
Expand All @@ -51,14 +68,9 @@ def __init__(
blocks = []
upsamples = []

num_blocks = len(input_spec) - 1 # Number of outputs is one less than encoder layers

if len(out_channels) != num_blocks:
raise ValueError(f"decoder_features must have length of {num_blocks}")

in_channels_for_upsample_block = input_spec.channels[-1]

for block_index in reversed(range(num_blocks)):
for block_index in reversed(range(num_stages)):
features_from_encoder = input_spec.channels[block_index]

scale_factor = input_spec.strides[block_index + 1] // input_spec.strides[block_index]
Expand All @@ -74,14 +86,28 @@ def __init__(

in_channels = features_from_encoder + out_channels_from_upsample_block

blocks.append(block_type(in_channels, out_channels[block_index], **block_kwargs))
stage = self._build_stage(
in_channels, out_channels[block_index], block_type, block_kwargs, num_blocks_per_stage[block_index]
)
blocks.append(stage)

in_channels_for_upsample_block = out_channels[block_index]

self.blocks = nn.ModuleList(blocks)
self.upsamples = nn.ModuleList(upsamples)
self.output_spec = FeatureMapsSpecification(channels=out_channels, strides=input_spec.strides[:-1])

def _build_stage(
self, in_channels: int, out_channels: int, block_type: Type, block_kwargs: Mapping, num_blocks: int
):
blocks = []
for _ in range(num_blocks):
blocks.append(block_type(in_channels, out_channels, **block_kwargs))
in_channels = out_channels
if num_blocks == 1:
return blocks[0]
return nn.Sequential(*blocks)

@torch.jit.unused
def get_output_spec(self) -> FeatureMapsSpecification:
    """Return the spec (channels and strides) of this decoder's output feature maps.

    Excluded from TorchScript compilation via ``torch.jit.unused``.
    """
    return self.output_spec
Expand Down
10 changes: 5 additions & 5 deletions pytorch_toolbelt/modules/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def __init__(
self,
in_channels: int,
out_channels: int,
activation=ACT_RELU,
normalization=NORM_BATCH,
activation: str = ACT_RELU,
normalization: str = NORM_BATCH,
normalization_kwargs=None,
activation_kwargs=None,
drop_path_rate=0.0,
drop_path_rate: float = 0.0,
):
super().__init__()

Expand Down Expand Up @@ -92,6 +92,6 @@ def forward(self, x):

x = self.conv2(x)
x = self.norm2(x)
x = self.act2(x)
x = self.act2(self.drop_path(x) + residual)

return self.drop_path(x) + residual
return x

0 comments on commit 37f517a

Please sign in to comment.