Skip to content

Commit accc442

Browse files
authored
Merge pull request #1 from vinayakjeet/vinayakjeet-patch-1
Fixes #243
2 parents 48930aa + 4bf74f5 commit accc442

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

Diff for: src/convnets/unet.jl

+13-10
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,21 @@ Backbone of any Metalhead ResNet-like model can be used as encoder
7171
- `final`: final block as described in original paper
7272
- `fdownscale`: downscale factor
7373
"""
74-
function unet(encoder_backbone, imgdims, outplanes::Integer,
75-
final::Any = unet_final_block, fdownscale::Integer = 0)
76-
backbonelayers = collect(flatten_chains(encoder_backbone))
77-
layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
78-
skip_upscale = fdownscale)
74+
function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer,
75+
final::Any = unet_final_block, fdownscale::Integer = 0)
76+
backbonelayers = collect(flatten_chains(encoder_backbone))
7977

80-
outsz = Flux.outputsize(layers, imgdims)
81-
layers = Chain(layers, final(outsz[end - 1], outplanes))
78+
# Adjusting input size to include channels
79+
adjusted_imgdims = (imgdims..., inchannels, 1)
8280

83-
return layers
84-
end
81+
layers = unetlayers(backbonelayers, adjusted_imgdims; m_middle = unet_middle_block,
82+
skip_upscale = fdownscale)
83+
84+
outsz = Flux.outputsize(layers, adjusted_imgdims)
85+
layers = Chain(layers, final(outsz[end - 1], outplanes))
8586

87+
return layers
88+
end
8689
"""
8790
UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
8891
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
@@ -114,7 +117,7 @@ end
114117

115118
function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
116119
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
117-
layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
120+
layers = unet(encoder_backbone, imsize, inchannels, outplanes)
118121
model = UNet(layers)
119122
if pretrain
120123
artifact_name = "UNet"

0 commit comments

Comments
 (0)