Adding ShuffleNet model #258

Open · wants to merge 17 commits into master
Conversation

@RafaelT00 commented Nov 3, 2023

I'm working on this implementation of ShuffleNet from https://arxiv.org/abs/1707.01083.

@RafaelT00 changed the title from "Adding Shuffle" to "Adding ShuffleNet model" Nov 3, 2023
@ToucheSir (Member) left a comment

Thanks for the contribution! This is a great start, and the next steps would be adding tests + better matching the code style of the rest of the repo.

function ChannelShuffle(x::Array{Float32, 4}, g::Int)
width, height, channels, batch = size(x)
channels_per_group = channels÷g
if (channels % g) == 0
@ToucheSir (Member):
Suggested change
if (channels % g) == 0
if channels % g == 0

We have a JuliaFormatter config in this repo, so make sure to run that before pushing your code.
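
For reference, formatting the repository is typically a two-liner from the Julia REPL at the repo root; JuliaFormatter picks up the repo's .JuliaFormatter.toml automatically:

using JuliaFormatter
format(".")   # formats every .jl file under the current directory using the repo config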

@RafaelT00 (Author):
Applied.

- `channels`: number of channels
- `groups`: number of groups
"""
function ChannelShuffle(x::Array{Float32, 4}, g::Int)
@ToucheSir (Member):
Suggested change
function ChannelShuffle(x::Array{Float32, 4}, g::Int)
function channel_shuffle(x::AbstractArray{Float32, 4}, g::Int)

This type constraint is too restrictive. If ChannelShuffle works for all number types, then it should reflect that. Generally, all utility functions in Metalhead need to be GPU-compatible too. The renaming is a suggestion for how to make this function more "Julian", since it's not a callable type (which would be PascalCase) but a plain function. Lastly, how does this handle 3D inputs?
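
For illustration, a rough sketch (not this PR's actual code) of a more generic, GPU-friendly version that accepts any element type and relies only on reshape and permutedims:

function channel_shuffle(x::AbstractArray{T, 4}, g::Int) where {T}
    w, h, c, n = size(x)
    @assert c % g == 0 "number of channels must be divisible by the number of groups"
    # split channels into (channels_per_group, groups), swap the two, and flatten back
    x = reshape(x, w, h, c ÷ g, g, n)
    x = permutedims(x, (1, 2, 4, 3, 5))
    return reshape(x, w, h, c, n)
end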

@RafaelT00 (Author):
I didn't think about that when writing the function, so for 3D inputs (a batch of grayscale images) it would be necessary to artificially create a channel dimension.

Comment on lines 49 to 50
BatchNorm(mid_channels),
NNlib.relu,
@ToucheSir (Member):

Suggested change
BatchNorm(mid_channels),
NNlib.relu,
BatchNorm(mid_channels, relu),

relu is already in scope because of using NNlib, and fusing it into the preceding norm is slightly more efficient. Also, is the activation function not configurable for ShuffleNet?
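
As a sketch of both points (the helper name and signature are assumptions, not this PR's code), the activation can be threaded through as a keyword and fused into the norm layer:

using Flux

function conv_bn(kernel, inplanes, outplanes; activation = relu, kwargs...)
    # BatchNorm(ch, λ) applies λ after normalization, so no standalone relu layer is needed
    return Chain(Conv(kernel, inplanes => outplanes; kwargs...),
                 BatchNorm(outplanes, activation))
end

# e.g. conv_bn((1, 1), 64, 128; groups = 4, pad = SamePad())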

@RafaelT00 (Author):

Changed.

m = Chain(Conv((1,1), in_channels => mid_channels; groups,pad=SamePad()),
BatchNorm(mid_channels),
NNlib.relu,
x -> ChannelShuffle(x, groups),
@ToucheSir (Member):

Suggested change
x -> ChannelShuffle(x, groups),
Base.Fix2(channel_shuffle, groups),

Will be easier on the compiler.
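
A small illustration of the difference (the helper below is a stand-in, not channel_shuffle itself): Base.Fix2(f, y) builds a callable struct with a concrete, reusable type, whereas each anonymous closure gets its own fresh type, which is harder on the compiler inside a Chain.

scale(x, n) = x .* n                # stand-in for channel_shuffle(x, groups)
f_closure = x -> scale(x, 4)        # anonymous function: a new type at every definition site
f_fix2    = Base.Fix2(scale, 4)     # Base.Fix2{typeof(scale), Int}: a concrete, nameable type
f_closure([1, 2, 3]) == f_fix2([1, 2, 3])   # true, both compute scale(x, 4)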

@RafaelT00 (Author):

Changed.

NNlib.relu)

if downsample
m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2))
@ToucheSir (Member):

Suggested change
m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2))
m = Parallel(cat_channels, m, MeanPool((3,3); pad=SamePad(), stride=2))

We have cat_channels for this exact case.
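
If it helps, cat_channels is roughly the channel-dimension concatenation written out below (an approximation, not the exact definition), so Parallel can take it directly instead of an anonymous closure:

# roughly what Metalhead's cat_channels does for WHCN arrays
my_cat_channels(xs...) = cat(xs...; dims = 3)

x = rand(Float32, 8, 8, 2, 1)
size(my_cat_channels(x, x))   # (8, 8, 4, 1)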

@RafaelT00 (Author):

Changed.


model = Chain(features...)

return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes))
@ToucheSir (Member):

Suggested change
return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes))
return Chain(model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes))

The general modus operandi of this library has been to create named types for the top-level model and wrap the underlying Chain with them. You can see this pattern in the files for any of the other exported models.

Regarding the suggestion: flatten is only imported into Flux, not defined there. It's preferable to use the symbol from the library that actually defines it when that library is available (which MLUtils is, being a dependency of Metalhead).
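
For illustration, a minimal sketch of the wrapping pattern (type, field, and constructor names are placeholders, not necessarily this PR's final API; the backbone below is a dummy purely to keep the sketch runnable):

using Flux
import MLUtils

struct ShuffleNet
    layers::Chain
end
Flux.@functor ShuffleNet           # register the wrapper so Flux sees the inner parameters

(m::ShuffleNet)(x) = m.layers(x)   # forward pass simply delegates to the wrapped Chain

# keyword constructor builds the Chain and wraps it, as the other exported models do
function ShuffleNet(; inchannels = 3, nclasses = 1000)
    # dummy two-layer backbone; the real model would build the ShuffleNet stages here
    backbone = Chain(Conv((3, 3), inchannels => 16; pad = SamePad()), BatchNorm(16, relu))
    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dense(16 => nclasses))
    return ShuffleNet(Chain(backbone, head))
end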

@RafaelT00 (Author):

Could I see an example? Sorry, I'm still a newbie with Julia. I looked at the rest of the convnets and tried to code in a similar style, but there are still things I haven't fully understood.

@RafaelT00 (Author):

I made the suggested changes.

@ToucheSir (Member):

Thanks for the updates. On a quick skim nothing stands out to me; can you add it to the test suite to finish off the PR?
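
If useful, a rough sketch of the kind of smoke test the other models get (the default constructor and 1000-class output are assumptions; adapt to the test suite's existing groups):

using Test, Flux

@testset "ShuffleNet" begin
    m = ShuffleNet()                           # assumed default constructor from this PR
    x = rand(Float32, 224, 224, 3, 2)          # two 224×224 RGB images in WHCN order
    @test size(m(x)) == (1000, 2)              # assumed 1000 classes per batch element
    @test_nowarn gradient(m -> sum(m(x)), m)   # backward pass runs without error
end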
