Skip to content

Commit 8496a5c

Browse files
talmo and claude committed
Fix skip connection channel mismatch in ConvNext/SwinT decoders
The decoder incorrectly assumed skip connection channels match computed decoder filters (refine_convs_filters). For ConvNext/SwinT, actual encoder channels differ from computed filters, causing a RuntimeError during training.

Changes:
- Add skip_channels parameter to SimpleUpsamplingBlock
- Add encoder_channels parameter to Decoder
- Pass actual encoder channels from ConvNextWrapper and SwinTWrapper

Fixes training with ConvNext/SwinT backbones when output_stride != 1.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 359af3c commit 8496a5c

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

sleap_nn/architectures/convnext.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,10 @@ def __init__(
281281
# Keep the block output filters the same
282282
x_in_shape = int(self.arch["channels"][-1] * filters_rate)
283283

284+
# Encoder channels for skip connections (reversed to match decoder order)
285+
# The forward pass uses enc_output[::2][::-1] for skip features
286+
encoder_channels = self.arch["channels"][::-1]
287+
284288
self.dec = Decoder(
285289
x_in_shape=x_in_shape,
286290
current_stride=self.current_stride,
@@ -293,6 +297,7 @@ def __init__(
293297
block_contraction=self.block_contraction,
294298
output_stride=self.output_stride,
295299
up_interpolate=up_interpolate,
300+
encoder_channels=encoder_channels,
296301
)
297302

298303
if len(self.dec.decoder_stack):

sleap_nn/architectures/encoder_decoder.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
See the `EncoderDecoder` base class for requirements for creating new architectures.
2626
"""
2727

28-
from typing import List, Text, Tuple, Union
28+
from typing import List, Optional, Text, Tuple, Union
2929
from collections import OrderedDict
3030
import torch
3131
from torch import nn
@@ -391,10 +391,18 @@ def __init__(
391391
transpose_convs_activation: Text = "relu",
392392
feat_concat: bool = True,
393393
prefix: Text = "",
394+
skip_channels: Optional[int] = None,
394395
) -> None:
395396
"""Initialize the class."""
396397
super().__init__()
397398

399+
# Determine skip connection channels
400+
# If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
401+
# This allows ConvNext/SwinT to specify actual encoder channels
402+
self.skip_channels = (
403+
skip_channels if skip_channels is not None else refine_convs_filters
404+
)
405+
398406
self.x_in_shape = x_in_shape
399407
self.current_stride = current_stride
400408
self.upsampling_stride = upsampling_stride
@@ -469,13 +477,13 @@ def __init__(
469477
first_conv_in_channels = refine_convs_filters
470478
else:
471479
if self.up_interpolate:
472-
# With interpolation, input is x_in_shape + feature channels
473-
# The feature channels are the same as x_in_shape since they come from the same level
474-
first_conv_in_channels = x_in_shape + refine_convs_filters
480+
# With interpolation, input is x_in_shape + skip_channels
481+
# skip_channels may differ from refine_convs_filters for ConvNext/SwinT
482+
first_conv_in_channels = x_in_shape + self.skip_channels
475483
else:
476-
# With transpose conv, input is transpose_conv_output + feature channels
484+
# With transpose conv, input is transpose_conv_output + skip_channels
477485
first_conv_in_channels = (
478-
refine_convs_filters + transpose_convs_filters
486+
self.skip_channels + transpose_convs_filters
479487
)
480488
else:
481489
if not self.feat_concat:
@@ -582,6 +590,7 @@ def __init__(
582590
block_contraction: bool = False,
583591
up_interpolate: bool = True,
584592
prefix: str = "dec",
593+
encoder_channels: Optional[List[int]] = None,
585594
) -> None:
586595
"""Initialize the class."""
587596
super().__init__()
@@ -598,6 +607,7 @@ def __init__(
598607
self.block_contraction = block_contraction
599608
self.prefix = prefix
600609
self.stride_to_filters = {}
610+
self.encoder_channels = encoder_channels
601611

602612
self.current_strides = []
603613
self.residuals = 0
@@ -624,6 +634,13 @@ def __init__(
624634

625635
next_stride = current_stride // 2
626636

637+
# Determine skip channels for this decoder block
638+
# If encoder_channels provided, use actual encoder channels
639+
# Otherwise fall back to computed filters (for UNet compatibility)
640+
skip_channels = None
641+
if encoder_channels is not None and block < len(encoder_channels):
642+
skip_channels = encoder_channels[block]
643+
627644
if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
628645
# This accounts for the case where we don't have any more down block features to concatenate with.
629646
# In this case, add a simple upsampling block with a conv layer and with no concatenation
@@ -642,6 +659,7 @@ def __init__(
642659
transpose_convs_batch_norm=False,
643660
feat_concat=False,
644661
prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
662+
skip_channels=skip_channels,
645663
)
646664
)
647665
else:
@@ -659,6 +677,7 @@ def __init__(
659677
transpose_convs_filters=block_filters_out,
660678
transpose_convs_batch_norm=False,
661679
prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
680+
skip_channels=skip_channels,
662681
)
663682
)
664683

sleap_nn/architectures/swint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,13 @@ def __init__(
309309
self.stem_patch_stride * (2**3) * 2
310310
) # stem_stride * down_blocks_stride * final_max_pool_stride
311311

312+
# Encoder channels for skip connections (reversed to match decoder order)
313+
# SwinT channels: embed * 2^i for each stage i, then reversed
314+
num_stages = len(self.arch["depths"])
315+
encoder_channels = [
316+
self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
317+
]
318+
312319
self.dec = Decoder(
313320
x_in_shape=block_filters,
314321
current_stride=self.current_stride,
@@ -321,6 +328,7 @@ def __init__(
321328
block_contraction=self.block_contraction,
322329
output_stride=output_stride,
323330
up_interpolate=up_interpolate,
331+
encoder_channels=encoder_channels,
324332
)
325333

326334
if len(self.dec.decoder_stack):

0 commit comments

Comments
 (0)