-
-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ShuffleNet model #258
base: master
Are you sure you want to change the base?
Conversation
ShuffleNet model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! This is a great start, and the next steps would be adding tests + better matching the code style of the rest of the repo.
src/convnets/shufflenet.jl
Outdated
function ChannelShuffle(x::Array{Float32, 4}, g::Int) | ||
width, height, channels, batch = size(x) | ||
channels_per_group = channels÷g | ||
if (channels % g) == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (channels % g) == 0 | |
if channels % g == 0 |
We have a JuliaFormatter config in this repo, so make sure to run that before pushing your code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Applied.
src/convnets/shufflenet.jl
Outdated
- `channels`: number of channels | ||
- `groups`: number of groups | ||
""" | ||
function ChannelShuffle(x::Array{Float32, 4}, g::Int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function ChannelShuffle(x::Array{Float32, 4}, g::Int) | |
function channel_shuffle(x::AbstractArray{Float32, 4}, g::Int) |
This type constraint is too restrictive. If ChannelShuffle
works for all number types than it should reflect that. Generally all utility functions in Metalhead need to be GPU-compatible too. The renaming is a suggestion for how to make this function more "Julian", since it's not a callable type (which would be PascalCase) but a plain function. Lastly, how does this handle 3D inputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't think about it when writing the function, so for a 3D inputs, a batch of grey images, would be necessary to artificially create a channel dimension.
src/convnets/shufflenet.jl
Outdated
BatchNorm(mid_channels), | ||
NNlib.relu, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BatchNorm(mid_channels), | |
NNlib.relu, | |
BatchNorm(mid_channels, relu), |
relu
is already in scope because of using NNlib
and fusing it into the preceeding norm is slightly more efficient. Also, is the activation function not configurable for ShuffleNet?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed.
src/convnets/shufflenet.jl
Outdated
m = Chain(Conv((1,1), in_channels => mid_channels; groups,pad=SamePad()), | ||
BatchNorm(mid_channels), | ||
NNlib.relu, | ||
x -> ChannelShuffle(x, groups), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x -> ChannelShuffle(x, groups), | |
Base.Fix2(channel_shuffle, groups), |
Will be easier on the compiler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed.
src/convnets/shufflenet.jl
Outdated
NNlib.relu) | ||
|
||
if downsample | ||
m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) | |
m = Parallel(cat_channels, m, MeanPool((3,3); pad=SamePad(), stride=2)) |
We have cat_channels
for this exact case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed.
src/convnets/shufflenet.jl
Outdated
|
||
model = Chain(features...) | ||
|
||
return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) | |
return Chain(model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) |
The general modus operandi of this library has been to create named types for the top-level model and wrap the underlying Chain with them. You can see this pattern in the files for any of the other exported models.
For the suggestion. flatten
is only imported and not defined in Flux. It's preferable to use a symbol from the library that actually defined when that library is available (which MLUtils is, being a dep of Metalhead).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could I see an example? Sorry, I'm still a newbie using Julia, I looked to the rest of convnets and tried to code with a similar style, but there are still things I that still haven't fully understood.
Better matching the code style of the rest of Metalhead
I made the suggested changes |
Thanks for the updates. On a quick skim nothing stands out to me, can you add it to the test suite to finish off the PR? |
corrected typo
added missing includes
I'm working on this implementation of ShuffleNet from https://arxiv.org/abs/1707.01083.