Skip to content

Commit

Permalink
Fixes FluxML#243
Browse files Browse the repository at this point in the history
Fix UNet implementation to support input  with channel sizes other than 3
  • Loading branch information
vinayakjeet authored Mar 22, 2024
1 parent 48930aa commit 4bf74f5
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/convnets/unet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,21 @@ Backbone of any Metalhead ResNet-like model can be used as encoder
- `final`: final block as described in original paper
- `fdownscale`: downscale factor
"""
function unet(encoder_backbone, imgdims, outplanes::Integer,
final::Any = unet_final_block, fdownscale::Integer = 0)
backbonelayers = collect(flatten_chains(encoder_backbone))
layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
skip_upscale = fdownscale)
function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer,
final::Any = unet_final_block, fdownscale::Integer = 0)
backbonelayers = collect(flatten_chains(encoder_backbone))

outsz = Flux.outputsize(layers, imgdims)
layers = Chain(layers, final(outsz[end - 1], outplanes))
# Adjusting input size to include channels
adjusted_imgdims = (imgdims..., inchannels, 1)

return layers
end
layers = unetlayers(backbonelayers, adjusted_imgdims; m_middle = unet_middle_block,
skip_upscale = fdownscale)

outsz = Flux.outputsize(layers, adjusted_imgdims)
layers = Chain(layers, final(outsz[end - 1], outplanes))

return layers
end
"""
UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
Expand Down Expand Up @@ -114,7 +117,7 @@ end

function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
layers = unet(encoder_backbone, imsize, inchannels, outplanes)
model = UNet(layers)
if pretrain
artifact_name = "UNet"
Expand Down

0 comments on commit 4bf74f5

Please sign in to comment.