Skip to content

Commit 75f4368

Browse files
gitttt-1234gitttt-1234
authored andcommitted
Fix up block computation for swint and convnext
1 parent 9844ff7 commit 75f4368

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

sleap_nn/architectures/convnext.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from functools import partial
77
from typing import Any, Callable, List, Optional, Dict, Tuple
8-
8+
import numpy as np
99
import torch
1010
from torch import nn, Tensor
1111
from torchvision.ops.misc import Conv2dNormActivation
@@ -192,7 +192,7 @@ def __init__(
192192
else:
193193
self.arch = arch_types["tiny"]
194194

195-
self.up_blocks = len(self.arch["channels"]) - 1
195+
self.up_blocks = np.log2(self.max_stride / output_stride).astype(int) - 1
196196
self.convs_per_block = convs_per_block
197197
self.stem_patch_kernel = stem_patch_kernel
198198
self.stem_patch_stride = stem_patch_stride

sleap_nn/architectures/swint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from functools import partial
77
from typing import Any, Callable, List, Optional, Dict, Tuple
8-
8+
import numpy as np
99
import torch
1010
from torch import nn
1111
from sleap_nn.architectures.encoder_decoder import Decoder
@@ -227,7 +227,7 @@ def __init__(
227227
else:
228228
self.arch = arch_types["tiny"]
229229

230-
self.up_blocks = len(self.arch["depths"]) - 1
230+
self.up_blocks = np.log2(self.max_stride / output_stride).astype(int) - 1
231231
self.convs_per_block = convs_per_block
232232
self.stem_patch_stride = stem_patch_stride
233233
self.down_blocks = len(self.arch["depths"]) - 1

0 commit comments

Comments
 (0)