diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 9a06a040..49f44bf5 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -71,18 +71,21 @@ Backbone of any Metalhead ResNet-like model can be used as encoder - `final`: final block as described in original paper - `fdownscale`: downscale factor """ -function unet(encoder_backbone, imgdims, outplanes::Integer, - final::Any = unet_final_block, fdownscale::Integer = 0) - backbonelayers = collect(flatten_chains(encoder_backbone)) - layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block, - skip_upscale = fdownscale) +function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer, + final::Any = unet_final_block, fdownscale::Integer = 0) +backbonelayers = collect(flatten_chains(encoder_backbone)) - outsz = Flux.outputsize(layers, imgdims) - layers = Chain(layers, final(outsz[end - 1], outplanes)) +# Adjusting input size to include channels +adjusted_imgdims = (imgdims..., inchannels, 1) - return layers -end +layers = unetlayers(backbonelayers, adjusted_imgdims; m_middle = unet_middle_block, + skip_upscale = fdownscale) + +outsz = Flux.outputsize(layers, adjusted_imgdims) +layers = Chain(layers, final(outsz[end - 1], outplanes)) +return layers +end """ UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) @@ -114,7 +117,7 @@ end function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) - layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes) + layers = unet(encoder_backbone, imsize, inchannels, outplanes) model = UNet(layers) if pretrain artifact_name = "UNet"