From 4bf74f5d3b497ed33722e9a8c4e33c9088c70a44 Mon Sep 17 00:00:00 2001
From: Vinayakjeet Singh Karki <139736674+vinayakjeet@users.noreply.github.com>
Date: Fri, 22 Mar 2024 16:04:56 +0530
Subject: [PATCH] Fixes #243 Fix UNet implementation to support input with channel sizes other than 3

---
 src/convnets/unet.jl | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl
index 9a06a040..49f44bf5 100644
--- a/src/convnets/unet.jl
+++ b/src/convnets/unet.jl
@@ -71,18 +71,21 @@ Backbone of any Metalhead ResNet-like model can be used as encoder
     - `final`: final block as described in original paper
     - `fdownscale`: downscale factor
 """
-function unet(encoder_backbone, imgdims, outplanes::Integer,
-              final::Any = unet_final_block, fdownscale::Integer = 0)
-    backbonelayers = collect(flatten_chains(encoder_backbone))
-    layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
-                        skip_upscale = fdownscale)
+function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer,
+    final::Any = unet_final_block, fdownscale::Integer = 0)
+backbonelayers = collect(flatten_chains(encoder_backbone))
 
-    outsz = Flux.outputsize(layers, imgdims)
-    layers = Chain(layers, final(outsz[end - 1], outplanes))
+# Adjusting input size to include channels
+adjusted_imgdims = (imgdims..., inchannels, 1)
 
-    return layers
-end
+layers = unetlayers(backbonelayers, adjusted_imgdims; m_middle = unet_middle_block,
+    skip_upscale = fdownscale)
+
+outsz = Flux.outputsize(layers, adjusted_imgdims)
+layers = Chain(layers, final(outsz[end - 1], outplanes))
 
+return layers
+end
 """
     UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
          encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
@@ -114,7 +117,7 @@ end
 function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
               encoder_backbone = Metalhead.backbone(DenseNet(121));
               pretrain::Bool = false)
-    layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
+    layers = unet(encoder_backbone, imsize, inchannels, outplanes)
     model = UNet(layers)
     if pretrain
         artifact_name = "UNet"