diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl
index 5c3da9ba..5586fdd1 100644
--- a/src/convnets/unet.jl
+++ b/src/convnets/unet.jl
@@ -88,25 +88,26 @@
     return layers
 end
 
 function modify_first_conv_layer(encoder_backbone, inchannels)
     for (index, layer) in enumerate(encoder_backbone.layers)
         if isa(layer, Flux.Conv)
-            outchannels = size(layer.weight, 1) # The first dimension for Flux.Conv weight is the number of output channels
-            kernel_size = (size(layer.weight, 3), size(layer.weight, 4)) # height and width of the kernel
+            # Correctly infer the number of output channels from the layer's weight dimensions
+            outchannels = size(layer.weight, 1)
+            # Recreate the convolutional layer with the new number of input channels
+            kernel_size = (size(layer.weight, 3), size(layer.weight, 4))
             stride = layer.stride
             pad = layer.pad
-            activation = layer.activation
-            # Create a new convolutional layer with the adjusted number of input channels
-            new_conv_layer = Flux.Conv(kernel_size, inchannels => outchannels, stride=stride, pad=pad, activation=activation)
+            new_conv_layer = Flux.Conv(kernel_size, inchannels => outchannels, stride=stride, pad=pad)
+            # Update the layer in the backbone
            encoder_backbone.layers[index] = new_conv_layer
-            break
+            break # Assume only the first Conv layer needs adjustment
        end
    end
    return encoder_backbone
 end
+
 """
     UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
          encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
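
For context on the change above: Flux stores a `Conv` layer's weights with the kernel dimensions first and the channel dimensions last, i.e. `(kernel..., inchannels, outchannels)` for an ungrouped convolution, and it keeps the activation function in the layer's `σ` field rather than an `activation` field (which is why the `activation=activation` keyword is dropped in this diff). A `Chain`'s `layers` are also typically an immutable `Tuple`, so in-place assignment into them may not work. The sketch below is not the code in this PR; it is a minimal illustration, under those assumptions, of rebuilding the first `Conv` for a new number of input channels. The helper name `rebuild_first_conv` and the usage lines are purely hypothetical, not Metalhead API.

```julia
using Flux

# Illustrative helper, not part of Metalhead's API: rebuild the first Conv layer
# of a plain Flux.Chain so it accepts `inchannels` input channels.
function rebuild_first_conv(encoder_backbone::Chain, inchannels::Integer)
    layers = collect(encoder_backbone.layers)        # Chain stores layers as a Tuple; copy to a Vector
    for (index, layer) in enumerate(layers)
        if layer isa Flux.Conv
            w = layer.weight
            kernel_size = size(w)[1:(ndims(w) - 2)]  # leading dimensions are the kernel
            outchannels = size(w, ndims(w))          # last dimension is the number of output channels
            layers[index] = Flux.Conv(kernel_size, inchannels => outchannels, layer.σ;
                                      stride = layer.stride, pad = layer.pad,
                                      dilation = layer.dilation, groups = layer.groups)
            break                                    # only the first Conv needs new input channels
        end
    end
    return Chain(layers...)
end

# Hypothetical usage: adapt a DenseNet backbone to single-channel (grayscale) input.
# backbone = Metalhead.backbone(DenseNet(121))
# backbone_1ch = rebuild_first_conv(backbone, 1)
```

Rebuilding the `Chain` rather than mutating it keeps the helper working even when `layers` is a `Tuple`; the freshly initialized weights would still need retraining or re-initialization from the pretrained ones if that matters for the use case.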