Skip to content

Commit

Permalink
Update unet.jl
Browse files Browse the repository at this point in the history
3rd try
  • Loading branch information
vinayakjeet authored Mar 24, 2024
1 parent 6c7cfaa commit 689c3b1
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/convnets/unet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,26 @@ return layers
end
function modify_first_conv_layer(encoder_backbone, inchannels)
for (index, layer) in enumerate(encoder_backbone.layers)
if isa(layer, Flux.Conv)
if isa(layer, Flux.Conv)
# Correctly infer the number of output channels from the layer's weight dimensions
outchannels = size(layer.weight, 1) # The first dimension for Flux.Conv weight is the number of output channels
outchannels = size(layer.weight, 1)

kernel_size = (size(layer.weight, 3), size(layer.weight, 4)) # height and width of the kernel
# Recreate the convolutional layer with the new number of input channels
kernel_size = (size(layer.weight, 3), size(layer.weight, 4))
stride = layer.stride
pad = layer.pad
activation = layer.activation
new_conv_layer = Flux.Conv(kernel_size, inchannels => outchannels, stride=stride, pad=pad)

# Create a new convolutional layer with the adjusted number of input channels
new_conv_layer = Flux.Conv(kernel_size, inchannels => outchannels, stride=stride, pad=pad, activation=activation)
# Update the layer in the backbone
encoder_backbone.layers[index] = new_conv_layer
break
break # Assume only the first Conv layer needs adjustment
end
end
return encoder_backbone
end



"""
UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
Expand Down

0 comments on commit 689c3b1

Please sign in to comment.