Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ShuffleNet model #258

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ include("mixers/gmlp.jl")
# ViTs
include("vit-based/vit.jl")

## ShuffleNet
include("convnets/shufflenet.jl")

# Load pretrained weights
include("pretrain.jl")

Expand All @@ -81,7 +84,7 @@ export AlexNet, VGG, ResNet, WideResNet, ResNeXt, DenseNet,
SEResNet, SEResNeXt, Res2Net, Res2NeXt,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt,
MLPMixer, ResMLP, gMLP, ViT, UNet
MLPMixer, ResMLP, gMLP, ViT, UNet, ShuffleNet

# useful for feature extraction
export backbone, classifier
Expand All @@ -92,7 +95,7 @@ for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
:Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
:EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt,
:MLPMixer, :ResMLP, :gMLP, :ViT, :UNet)
:MLPMixer, :ResMLP, :gMLP, :ViT, :UNet, :ShuffleNet)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

Expand Down
200 changes: 200 additions & 0 deletions src/convnets/shufflenet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
using Flux, Metalhead, MLUtils, Functors

"""
channel_shuffle(channels, groups)

Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
([reference](https://arxiv.org/abs/1707.01083)).

# Arguments

- `channels`: number of channels
- `groups`: number of groups
"""
function channel_shuffle(x::AbstractArray{Float32, 4}, g::Int)
width, height, channels, batch = size(x)
channels_per_group = channels ÷ g
if channels % g == 0
x = reshape(x, (width, height, g, channels_per_group, batch))
x = permutedims(x, (1, 2, 4, 3, 5))
x = reshape(x, (width, height, channels, batch))
end
return x
end

"""
ShuffleUnit(in_channels::Integer, out_channels::Integer, grps::Integer, downsample::Bool, ignore_group::Bool)

Shuffle Unit from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
([reference](https://arxiv.org/abs/1707.01083)).

# Arguments

- `in_channels`: number of input channels
- `out_channels`: number of output channels
- `groups`: number of groups
- `downsample`: apply downsaple if true
- `ignore_group`: ignore group convolution if true
"""
function ShuffleUnit(in_channels::Integer, out_channels::Integer,
groups::Integer, downsample::Bool, ignore_group::Bool)
mid_channels = out_channels ÷ 4
groups = ignore_group ? 1 : groups
strd = downsample ? 2 : 1

if downsample
out_channels -= in_channels
end

m = Chain(Conv((1, 1), in_channels => mid_channels; groups, pad = SamePad()),
BatchNorm(mid_channels, NNlib.relu),
Base.Fix2(channel_shuffle, groups),
DepthwiseConv((3, 3), mid_channels => mid_channels;
bias = false, stride = strd, pad = SamePad()),
BatchNorm(mid_channels, NNlib.relu),
Conv((1, 1), mid_channels => out_channels; groups, pad = SamePad()),
BatchNorm(out_channels, NNlib.relu))

if downsample
m = Parallel(
Metalhead.cat_channels, m, MeanPool((3, 3); pad = SamePad(), stride = 2))
else
m = SkipConnection(m, +)
end
return m
end

"""
create_shufflenet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3)

ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
([reference](https://arxiv.org/abs/1707.01083)).

# Arguments

- `channels`: list of channels per layer
- `init_block_channels`: number of output channels from the first layer
- `groups`: number of groups
- `num_classes`: number of classes
- `in_channels`: number of input channels
"""
function create_shufflenet(
channels, init_block_channels::Integer, groups::Integer,
num_classes::Integer; in_channels::Integer = 3)
features = []

append!(features,
[Conv((3, 3), in_channels => init_block_channels; stride = 2, pad = SamePad()),
BatchNorm(init_block_channels, NNlib.relu),
MaxPool((3, 3); stride = 2, pad = SamePad())])

in_channels::Integer = init_block_channels

for (i, num_channels) in enumerate(channels)
stage = []
for (j, out_channels) in enumerate(num_channels)
downsample = j == 1
ignore_group = i == 1 && j == 1
out_ch::Integer = trunc(out_channels)
push!(stage, ShuffleUnit(in_channels, out_ch, groups, downsample, ignore_group))
in_channels = out_ch
end
append!(features, stage)
end

model = Chain(features...)
classifier = Chain(GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes))

return Chain(model, classifier)
end

"""
shufflenet(groups, width_scale, num_classes; in_channels=3)

Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
([reference](https://arxiv.org/abs/1707.01083)).

# Arguments

- `groups`: number of groups
- `width_scale`: scaling factor for number of channels
- `num_classes`: number of classes
- `in_channels`: number of input channels
"""

function shufflenet(groups::Integer = 1, width_scale::Real = 1;
num_classes::Integer = 1000, in_channels::Integer = 3)
init_block_channels = 24
nlayers = [4, 8, 4]

if groups == 1
channels_per_layers = [144, 288, 576]
elseif groups == 2
channels_per_layers = [200, 400, 800]
elseif groups == 3
channels_per_layers = [240, 480, 960]
elseif groups == 4
channels_per_layers = [272, 544, 1088]
elseif groups == 8
channels_per_layers = [384, 768, 1536]
else
return error("The number of groups is not supported. Groups = ", groups)
end

channels = []
for i in eachindex(nlayers)
char = [channels_per_layers[i]]
new = repeat(char, nlayers[i])
push!(channels, new)
end

if width_scale != 1.0
channels = channels * width_scale

init_block_channels::Integer = trunc(init_block_channels * width_scale)
end

net = create_shufflenet(
channels,
init_block_channels,
groups,
num_classes;
in_channels)

return net
end

"""
ShuffleNet(groups, width_scale, num_classes; in_channels=3)

Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
([reference](https://arxiv.org/abs/1707.01083)).

# Arguments

- `groups`: number of groups
- `width_scale`: scaling factor for number of channels
- `num_classes`: number of classes
- `in_channels`: number of input channels
"""
struct ShuffleNet
layers::Any
end
@functor ShuffleNet

function ShuffleNet(groups::Integer = 1, width_scale::Real = 1;
num_classes::Integer = 1000, in_channels::Integer = 3)
layers = shufflenet(groups, width_scale; num_classes, in_channels)
model = ShuffleNet(layers)

return model
end

(m::ShuffleNet)(x) = m.layers(x)

backbone(m::ShuffleNet) = m.layers[1]
classifier(m::ShuffleNet) = m.layers[2:end]

im = rand32(224, 224, 3, 50); # a batch of 50 RGB images
m = ShuffleNet(1, 1;num_classes=10)
println(m(im) |> size)
16 changes: 14 additions & 2 deletions test/convnet_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ end
[2, 2, 2, 2],
[3, 4, 6, 3],
[3, 4, 23, 3],
[3, 8, 36, 3],
[3, 8, 36, 3]
]
@testset for layers in layer_list
drop_list = [
(dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1),
(dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5),
(dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8),
(dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8)
]
@testset for drop_probs in drop_list
m = Metalhead.resnet(block_fn, layers; drop_probs...) |> gpu
Expand Down Expand Up @@ -374,3 +374,15 @@ end
@test size(model(x_256)) == (256, 256, 3, 1)
_gc()
end

@testitem "ShuffleNet" setup=[TestModels] begin
configs = TEST_FAST ? [(1, 1)] :
[(1, 1), (2, 1), (3, 1), (4, 1), (8, 1), (1, 0.75),
(3, 0.75), (1, 0.5), (3, 0.5), (1, 0.25), (3, 0.25)]
@testset for (groups, width_scale) in configs
m = ShuffleNet(groups, width_scale) |> gpu
@test size(m(x_224)) == (1000, 1)
@test gradtest(m, x_224)
_gc()
end
end
Loading