@@ -71,18 +71,21 @@ Backbone of any Metalhead ResNet-like model can be used as encoder
71
71
- `final`: final block as described in original paper
72
72
- `fdownscale`: downscale factor
73
73
"""
74
- function unet (encoder_backbone, imgdims, outplanes:: Integer ,
75
- final:: Any = unet_final_block, fdownscale:: Integer = 0 )
76
- backbonelayers = collect (flatten_chains (encoder_backbone))
77
- layers = unetlayers (backbonelayers, imgdims; m_middle = unet_middle_block,
78
- skip_upscale = fdownscale)
74
+ function unet (encoder_backbone, imgdims, inchannels:: Integer , outplanes:: Integer ,
75
+ final:: Any = unet_final_block, fdownscale:: Integer = 0 )
76
+ backbonelayers = collect (flatten_chains (encoder_backbone))
79
77
80
- outsz = Flux . outputsize (layers, imgdims)
81
- layers = Chain (layers, final (outsz[ end - 1 ], outplanes) )
78
+ # Adjusting input size to include channels
79
+ adjusted_imgdims = (imgdims ... , inchannels, 1 )
82
80
83
- return layers
84
- end
81
+ layers = unetlayers (backbonelayers, adjusted_imgdims; m_middle = unet_middle_block,
82
+ skip_upscale = fdownscale)
83
+
84
+ outsz = Flux. outputsize (layers, adjusted_imgdims)
85
+ layers = Chain (layers, final (outsz[end - 1 ], outplanes))
85
86
87
+ return layers
88
+ end
86
89
"""
87
90
UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
88
91
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
114
117
115
118
function UNet (imsize:: Dims{2} = (256 , 256 ), inchannels:: Integer = 3 , outplanes:: Integer = 3 ,
116
119
encoder_backbone = Metalhead. backbone (DenseNet (121 )); pretrain:: Bool = false )
117
- layers = unet (encoder_backbone, ( imsize... , inchannels, 1 ) , outplanes)
120
+ layers = unet (encoder_backbone, imsize, inchannels, outplanes)
118
121
model = UNet (layers)
119
122
if pretrain
120
123
artifact_name = " UNet"
0 commit comments