Skip to content

Commit

Permalink
Update unet.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakjeet authored Mar 24, 2024
1 parent 1e892a4 commit 18478fe
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/convnets/unet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function modify_first_conv_layer_advanced(encoder_backbone, inchannels)
stride = layer.stride
pad = layer.pad

# Create a new convolutional layer with the updated input channels

new_conv_layer = Flux.Conv(kernel_size, inchannels => outchannels, stride=stride, pad=pad)
layers[index] = new_conv_layer # Replace the old layer with the new one

Expand Down Expand Up @@ -140,8 +140,7 @@ end
@functor UNet
function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
# Modify the first convolutional layer of the encoder backbone to have the correct `inchannels`.
# This is a conceptual step; the actual implementation will depend on the structure of your backbone.

if inchannels != 3
encoder_backbone = modify_first_conv_layer_advanced(encoder_backbone, inchannels)

Expand All @@ -150,8 +149,7 @@ function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::
layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
model = UNet(layers)
if pretrain
# Note: As per the original comment, pre-trained weights are not supported in this context.
# This block is left as-is from your original code for completeness.

artifact_name = "UNet"
loadpretrain!(model, artifact_name)
end
Expand Down

0 comments on commit 18478fe

Please sign in to comment.