Adding ShuffleNet model #258
base: master
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,160 @@ | ||||||||
using Flux | ||||||||
|
||||||||
""" | ||||||||
ChannelShuffle(x, g) | ||||||||
|
||||||||
Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices' | ||||||||
([reference](https://arxiv.org/abs/1707.01083)). | ||||||||
|
||||||||
# Arguments | ||||||||
|
||||||||
- `x`: 4D input array with dimensions (width, height, channels, batch) | ||||||||
- `g`: number of groups; `x` is returned unchanged if the channel count is not divisible by `g` | ||||||||
""" | ||||||||
function ChannelShuffle(x::Array{Float32, 4}, g::Int) | ||||||||
width, height, channels, batch = size(x) | ||||||||
channels_per_group = channels÷g | ||||||||
if (channels % g) == 0 | ||||||||
Suggested change
We have a JuliaFormatter config in this repo, so make sure to run that before pushing your code.
Applied. |
||||||||
x = reshape(x, (width, height, g, channels_per_group, batch)) | ||||||||
x = permutedims(x,(1,2,4,3,5)) | ||||||||
x = reshape(x, (width, height, channels, batch)) | ||||||||
end | ||||||||
return x | ||||||||
end | ||||||||
|
||||||||
""" | ||||||||
ShuffleUnit(in_channels::Integer, out_channels::Integer, groups::Integer, downsample::Bool, ignore_group::Bool) | ||||||||
|
||||||||
Shuffle Unit from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices' | ||||||||
([reference](https://arxiv.org/abs/1707.01083)). | ||||||||
|
||||||||
# Arguments | ||||||||
|
||||||||
- `in_channels`: number of input channels | ||||||||
- `out_channels`: number of output channels | ||||||||
- `groups`: number of groups | ||||||||
- `downsample`: apply downsampling if true | ||||||||
- `ignore_group`: ignore group convolution if true | ||||||||
""" | ||||||||
function ShuffleUnit(in_channels::Integer, out_channels::Integer, groups::Integer, downsample::Bool, ignore_group::Bool) | ||||||||
mid_channels = out_channels ÷ 4 | ||||||||
groups = ignore_group ? 1 : groups | ||||||||
strd = downsample ? 2 : 1 | ||||||||
|
||||||||
if downsample | ||||||||
out_channels -= in_channels | ||||||||
end | ||||||||
|
||||||||
m = Chain(Conv((1,1), in_channels => mid_channels; groups,pad=SamePad()), | ||||||||
BatchNorm(mid_channels), | ||||||||
NNlib.relu, | ||||||||
Suggested change
Changed. |
||||||||
x -> ChannelShuffle(x, groups), | ||||||||
Suggested change
Will be easier on the compiler.
Changed. |
||||||||
DepthwiseConv((3,3), mid_channels => mid_channels; bias=false, stride=strd, pad=SamePad()), | ||||||||
BatchNorm(mid_channels), | ||||||||
NNlib.relu, | ||||||||
Conv((1,1), mid_channels => out_channels; groups, pad=SamePad()), | ||||||||
BatchNorm(out_channels), | ||||||||
NNlib.relu) | ||||||||
|
||||||||
if downsample | ||||||||
m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) | ||||||||
Suggested change
We have
Changed. |
||||||||
else | ||||||||
m = SkipConnection(m, +) | ||||||||
end | ||||||||
return m | ||||||||
end | ||||||||
|
||||||||
""" | ||||||||
ShuffleNet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) | ||||||||
|
||||||||
ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices' | ||||||||
([reference](https://arxiv.org/abs/1707.01083)). | ||||||||
|
||||||||
# Arguments | ||||||||
|
||||||||
- `channels`: list of channels per layer | ||||||||
- `init_block_channels`: number of output channels from the first layer | ||||||||
- `groups`: number of groups | ||||||||
- `num_classes`: number of classes | ||||||||
- `in_channels`: number of input channels | ||||||||
""" | ||||||||
function ShuffleNet(channels, init_block_channels::Integer, groups, num_classes; in_channels=3) | ||||||||
features = [] | ||||||||
|
||||||||
append!(features, [Conv((3,3), in_channels => init_block_channels; stride=2, pad=SamePad()), | ||||||||
BatchNorm(init_block_channels), | ||||||||
NNlib.relu, | ||||||||
MaxPool((3,3); stride=2, pad=SamePad())]) | ||||||||
|
||||||||
in_channels::Integer = init_block_channels | ||||||||
|
||||||||
for (i, num_channels) in enumerate(channels) | ||||||||
stage = [] | ||||||||
for (j, out_channels) in enumerate(num_channels) | ||||||||
downsample = j==1 | ||||||||
ignore_group = i==1 && j==1 | ||||||||
out_ch::Integer = trunc(out_channels) | ||||||||
push!(stage, ShuffleUnit(in_channels, out_ch, groups, downsample, ignore_group)) | ||||||||
in_channels = out_ch | ||||||||
end | ||||||||
append!(features, stage) | ||||||||
end | ||||||||
|
||||||||
model = Chain(features...) | ||||||||
|
||||||||
return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) | ||||||||
Suggested change
The general modus operandi of this library has been to create named types for the top-level model and wrap the underlying Chain with them. You can see this pattern in the files for any of the other exported models. For the suggestion.
Could I see an example? Sorry, I'm still a newbie using Julia. I looked at the rest of the convnets and tried to code in a similar style, but there are still things that I haven't fully understood. |
||||||||
end | ||||||||
|
||||||||
""" | ||||||||
shufflenet(groups, width_scale, num_classes; in_channels=3) | ||||||||
|
||||||||
Wrapper for ShuffleNet. Create a ShuffleNet model from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices' | ||||||||
([reference](https://arxiv.org/abs/1707.01083)). | ||||||||
|
||||||||
# Arguments | ||||||||
|
||||||||
- `groups`: number of groups | ||||||||
- `width_scale`: scaling factor for number of channels | ||||||||
- `num_classes`: number of classes | ||||||||
- `in_channels`: number of input channels | ||||||||
""" | ||||||||
function shufflenet(groups, width_scale, num_classes; in_channels=3) | ||||||||
init_block_channels = 24 | ||||||||
layers = [4, 8, 4] | ||||||||
|
||||||||
if groups == 1 | ||||||||
channels_per_layers = [144, 288, 576] | ||||||||
elseif groups == 2 | ||||||||
channels_per_layers = [200, 400, 800] | ||||||||
elseif groups == 3 | ||||||||
channels_per_layers = [240, 480, 960] | ||||||||
elseif groups == 4 | ||||||||
channels_per_layers = [272, 544, 1088] | ||||||||
elseif groups == 8 | ||||||||
channels_per_layers = [384, 768, 1536] | ||||||||
else | ||||||||
return error("The number of groups is not supported. Groups = ", groups) | ||||||||
end | ||||||||
|
||||||||
channels = [] | ||||||||
for i in eachindex(layers) | ||||||||
char = [channels_per_layers[i]] | ||||||||
new = repeat(char, layers[i]) | ||||||||
push!(channels, new) | ||||||||
end | ||||||||
|
||||||||
if width_scale != 1.0 | ||||||||
channels = channels*width_scale | ||||||||
|
||||||||
init_block_channels::Integer = trunc(init_block_channels * width_scale) | ||||||||
end | ||||||||
|
||||||||
net = ShuffleNet( | ||||||||
channels, | ||||||||
init_block_channels, | ||||||||
groups, | ||||||||
num_classes; # positional in the ShuffleNet signature, not a keyword | ||||||||
in_channels) | ||||||||
|
||||||||
return net | ||||||||
end |
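One review comment in this diff asks for a named type that wraps the underlying `Chain`, and the author asks for an example. A minimal sketch of that pattern might look like the following. The type name `ShuffleNetModel`, the field name `layers`, the toy layer sizes, and the `backbone`/`classifier` helpers are all hypothetical and not taken from this PR. It assumes a Flux version that provides `Flux.@functor`; newer Flux releases offer `Flux.@layer` for the same purpose.

```julia
using Flux

# Hypothetical named wrapper around the model's Chain (names are assumptions).
struct ShuffleNetModel
    layers::Any
end

# Register the struct with Flux so parameters inside `layers` are visible
# to training utilities and to device movement such as `gpu`.
Flux.@functor ShuffleNetModel

# The forward pass simply delegates to the wrapped Chain.
(m::ShuffleNetModel)(x) = m.layers(x)

# Hypothetical accessors for the two halves of the wrapped Chain.
backbone(m::ShuffleNetModel) = m.layers[1]
classifier(m::ShuffleNetModel) = m.layers[2]

# Toy constructor: a tiny stand-in backbone, NOT the ShuffleNet architecture.
function ShuffleNetModel(; in_channels = 3, num_classes = 10)
    features = Chain(Conv((3, 3), in_channels => 8, relu; pad = SamePad()))
    head = Chain(GlobalMeanPool(), Flux.flatten, Dense(8 => num_classes))
    return ShuffleNetModel(Chain(features, head))
end
```

With this shape, the function in the diff that currently returns a bare `Chain` would instead build the `Chain` and hand it to the struct constructor, so that users get a value of a named type while the forward pass stays unchanged.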
This type constraint is too restrictive. If `ChannelShuffle` works for all number types, then it should reflect that. Generally, all utility functions in Metalhead need to be GPU-compatible too. The renaming is a suggestion for how to make this function more "Julian", since it's not a callable type (which would be PascalCase) but a plain function. Lastly, how does this handle 3D inputs?
I didn't think about it when writing the function. For 3D inputs, such as a batch of grey images, it would be necessary to artificially create a channel dimension.
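The points raised in this thread could be sketched as follows. This is only an illustration, not the version adopted in the PR: the lowercase name `channel_shuffle`, the relaxed `AbstractArray{T, 4}` signature, and the 3D method that inserts a singleton channel dimension are assumptions made here. It uses only `reshape` and `permutedims`, which are also defined for GPU array types, though GPU behaviour is not tested in this sketch.

```julia
# Sketch of a more generic channel shuffle (names and signatures are assumptions).
# Accepts any 4D array laid out as (width, height, channels, batch).
function channel_shuffle(x::AbstractArray{T, 4}, groups::Integer) where {T}
    width, height, channels, batch = size(x)
    g = Int(groups)
    # Same behaviour as the function in the diff: return the input unchanged
    # when the channels cannot be split evenly into groups.
    channels % g == 0 || return x
    channels_per_group = channels ÷ g
    x = reshape(x, width, height, g, channels_per_group, batch)
    # Swap the two channel sub-dimensions, then flatten them back together.
    x = permutedims(x, (1, 2, 4, 3, 5))
    return reshape(x, width, height, channels, batch)
end

# Assumed handling of 3D inputs (width, height, batch), e.g. grey images:
# add a singleton channel dimension first. Note that the result is 4D.
function channel_shuffle(x::AbstractArray{T, 3}, groups::Integer) where {T}
    width, height, batch = size(x)
    return channel_shuffle(reshape(x, width, height, 1, batch), groups)
end

# Usage: six integer "channels" shuffled with two groups.
x = reshape(collect(1:6), 1, 1, 6, 1)
shuffled = vec(channel_shuffle(x, 2))  # [1, 3, 5, 2, 4, 6]
```

Because a single-channel input can only be divided into one group, the 3D method is effectively a no-op apart from the added dimension; whether that is the desired behaviour is left open in the discussion above.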