Skip to content

Commit

Permalink
TimmResnet26D
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Jan 21, 2024
1 parent 37f517a commit 594dbb8
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pytorch_toolbelt/modules/encoders/timm/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"TResNetMEncoder",
"TimmResnet152D",
"TimmSEResnet152D",
"TimmResnet26D"
"TimmResnet50D",
"TimmResnet101D",
"TimmResnet200D",
Expand Down Expand Up @@ -161,6 +162,22 @@ def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
return self


class TimmResnet26D(GenericTimmEncoder):
def __init__(
self, pretrained=True, layers=None, activation=ACT_RELU, first_conv_stride_one: bool = False, **kwargs
):
from timm.models.resnet import resnet26d

act_layer = get_activation_block(activation)
encoder = resnet50d(features_only=True, pretrained=pretrained, act_layer=act_layer, **kwargs)
if first_conv_stride_one:
encoder.conv1[0].stride = (1, 1)
for info in encoder.feature_info:
info["reduction"] = info["reduction"] // 2

super().__init__(encoder, layers)


class TimmResnet50D(GenericTimmEncoder):
def __init__(
self, pretrained=True, layers=None, activation=ACT_RELU, first_conv_stride_one: bool = False, **kwargs
Expand Down

0 comments on commit 594dbb8

Please sign in to comment.