-
-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathunet.jl
129 lines (104 loc) · 4.79 KB
/
unet.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
function pixel_shuffle_icnr(inplanes, outplanes; r = 2)
    # Upsampling stage: a 1x1 conv-bn expands `inplanes` to `outplanes * r^2` channels,
    # then `Flux.PixelShuffle(r)` rearranges those channels into an `r`-fold larger
    # spatial grid with `outplanes` channels.
    # NOTE(review): despite the name, no ICNR-specific weight initialisation is passed to
    # `basic_conv_bn` here — confirm whether that is intended.
    #
    # Splat the layers returned by `basic_conv_bn` into the inner `Chain`, as the sibling
    # blocks do; previously the splat sat on the scalar channel count, where it was a no-op.
    expand = Chain(basic_conv_bn((1, 1), inplanes, outplanes * r^2)...)
    return Chain(expand, Flux.PixelShuffle(r))
end
function unet_combine_layer(inplanes, outplanes)
    # Fuses the concatenated skip/upsampled features with two 3x3 conv-bn stages
    # (padding 1): `inplanes` -> `outplanes`, then `outplanes` -> `outplanes`.
    first_stage = Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...)
    second_stage = Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...)
    return Chain(first_stage, second_stage)
end
function unet_middle_block(inplanes)
    # Middle block (passed as `m_middle` by `unet`): a 3x3 conv-bn doubling the channels,
    # then a 3x3 conv-bn halving them again, so the block keeps `inplanes` channels overall.
    widening = Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...)
    narrowing = Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...)
    return Chain(widening, narrowing)
end
function unet_final_block(inplanes, outplanes)
    # Default final block used by `unet`: a `basicblock` keeping `inplanes` channels
    # (presumably a ResNet-style basic block — defined elsewhere), followed by a 1x1
    # conv-bn projecting to `outplanes` channels.
    refine = basicblock(inplanes, inplanes; reduction_factor = 1)
    project = Chain(basic_conv_bn((1, 1), inplanes, outplanes)...)
    return Chain(refine, project)
end
function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes)
    # One U-Net level. `m_child` is the deeper part of the network (producing `midplanes`
    # channels); its output is upsampled by a pixel shuffle.
    inner = Chain(m_child, pixel_shuffle_icnr(midplanes, midplanes))
    # `SkipConnection` hands (inner output, block input) to the connection; `Parallel`
    # passes the first through unchanged and batch-normalises the second, then
    # concatenates them along the channel dimension (`midplanes + inplanes` channels).
    fuse = Parallel(cat_channels, identity, BatchNorm(inplanes))
    return Chain(SkipConnection(inner, fuse), relu,
                 unet_combine_layer(inplanes + midplanes, outplanes))
end
function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
                    m_middle = _ -> (identity,))
    # Recursively turns the flat list of encoder `layers` into nested U-Net blocks.
    #
    # `sz` is the full input size of `layers[1]`, as accepted by `Flux.outputsize`.
    # - A layer that does not exactly halve the first spatial dimension is chained as-is.
    # - While `skip_upscale > 0`, a halving layer is chained as-is too, so it gets no
    #   matching decoder stage.
    # - Any other halving layer opens a `unet_block` whose child is the rest of the network.
    # - When no layers remain, `m_middle(channels)` supplies the innermost (middle) layers.
    # `outplanes` sets the output channels of the outermost `unet_block` (default: its
    # input channel count).
    #
    # `m_middle` must be forwarded on every recursive call. Otherwise the default
    # `_ -> (identity,)` is used at the bottom and the caller's middle block is dropped.
    isempty(layers) && return m_middle(sz[end - 1])
    layer, rest = layers[1], layers[2:end]
    outsz = Flux.outputsize(layer, sz)
    # Only exact halvings are treated as downscales, so that the 2x pixel-shuffle
    # upsampling restores the original size for the skip concatenation.
    does_downscale = sz[1] ÷ 2 == outsz[1]
    if !does_downscale
        return Chain(layer, unetlayers(rest, outsz; outplanes, skip_upscale, m_middle)...)
    elseif skip_upscale > 0
        return Chain(layer,
                     unetlayers(rest, outsz; skip_upscale = skip_upscale - 1,
                                outplanes, m_middle)...)
    else
        childunet = Chain(unetlayers(rest, outsz; skip_upscale, m_middle)...)
        childsz = Flux.outputsize(childunet, outsz)
        inplanes = sz[end - 1]
        midplanes = childsz[end - 1]
        return unet_block(Chain(layer, childunet), inplanes, midplanes,
                          something(outplanes, inplanes))
    end
end
"""
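    unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer,
         final::Any = unet_final_block, fdownscale::Integer = 0)

Creates a UNet model with the specified convolutional backbone as its encoder.
The backbone of any Metalhead ResNet-like model can be used as the encoder
([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

  - `encoder_backbone`: The backbone layers of the model to be used as the encoder.
    For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed
    to instantiate a UNet with the layers of ResNet-18 as its encoder.
  - `imgdims`: spatial size of the input image, e.g. `(256, 256)`
  - `inchannels`: number of channels in the input image
  - `outplanes`: number of output feature planes
  - `final`: function that builds the final block; it is called as
    `final(inplanes, outplanes)`, where `inplanes` is the number of channels produced
    by the decoder (default: `unet_final_block`)
  - `fdownscale`: number of downscaling stages of the encoder, counted from the input,
    that get no matching upsampling stage in the decoder. Each such stage halves the
    spatial size of the output relative to the input (default: `0`).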
"""
function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer,
              final::Any = unet_final_block, fdownscale::Integer = 0)
    # `Flux.outputsize` needs a complete size, so append the channel and batch (= 1) dims.
    insize = (imgdims..., inchannels, 1)
    flatlayers = collect(flatten_chains(encoder_backbone))
    body = unetlayers(flatlayers, insize; m_middle = unet_middle_block,
                      skip_upscale = fdownscale)
    # The final block maps the decoder's channel count to `outplanes`.
    decoderplanes = Flux.outputsize(body, insize)[end - 1]
    return Chain(body, final(decoderplanes, outplanes))
end
"""
    UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
         encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)

Creates a UNet model with an encoder built from the specified backbone. By default it uses
a [`DenseNet`](@ref) backbone (`DenseNet(121)`), but the backbone of any ResNet-like
Metalhead model can be used for the encoder ([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

  - `imsize`: spatial size of the input image
  - `inchannels`: number of channels in the input image
  - `outplanes`: number of output feature planes
  - `encoder_backbone`: The backbone layers of the model to be used as the encoder. For
    example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a
    UNet with the layers of ResNet-18 as its encoder.
  - `pretrain`: whether to load the pre-trained weights for ImageNet

!!! warning

    `UNet` does not currently support pretrained weights.

See also [`Metalhead.unet`](@ref).
"""
struct UNet
    # Wrapped network; the `UNet` constructor fills it with the result of `unet`.
    layers
end
@functor UNet
function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
              encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
    # Build the network with `unet` and wrap it.
    model = UNet(unet(encoder_backbone, imsize, inchannels, outplanes))
    # Weights are looked up under the "UNet" artifact name; the docstring warns that
    # pretrained weights are not currently available.
    pretrain && loadpretrain!(model, "UNet")
    return model
end
function (m::UNet)(x::AbstractArray)
    # Calling a `UNet` simply runs the wrapped network.
    return m.layers(x)
end