
Updates to outdims #1305


Merged 38 commits on Dec 30, 2020
Commits (38)
5661580
Updates to outdims for normalisation and generic functions
darsnack Aug 5, 2020
e34111b
Added tests for normalisation outdims
darsnack Aug 5, 2020
0e36e61
Added tests
darsnack Aug 5, 2020
3f893f0
Refactor outdims code to outdims.jl
darsnack Aug 6, 2020
3b02621
Updated to use _handle_batch. Need to update testing.
darsnack Aug 7, 2020
09fc012
Added batch handling for Chain. Refactored outdims tests.
darsnack Aug 9, 2020
d087ca5
Added global and adaptive pooling outdims.
darsnack Aug 9, 2020
0d8f0d0
Added outdims(::SkipConnection)
darsnack Aug 9, 2020
33b00d4
Updated Chain outdims to work for vectors/tuples of layers too
darsnack Aug 9, 2020
7e0d274
Updated docs
darsnack Aug 9, 2020
13c0c70
Updated _handle_batch to avoid closures
darsnack Sep 25, 2020
615cc75
Updated with docs changes + doctests
darsnack Nov 15, 2020
e7fd419
Updates to docstrings, etc. for outdims
darsnack Nov 16, 2020
a4f4757
Remove "spatial dimensions" phrasing from docstrings for outdims.
darsnack Nov 16, 2020
87c6387
Added Nil-based outdims implementation
lorenzoh Sep 24, 2020
8c95fe5
Merge branch 'master' into outdims-nil
darsnack Sep 26, 2020
26462fc
Remove preserve_batch
darsnack Nov 16, 2020
0391ac0
Added docstring and doctests. Small bug fixes
darsnack Nov 19, 2020
657cf12
Updated docs and add some minor changes for normalization.
darsnack Nov 21, 2020
9433ff3
Removed Logging dependency
darsnack Dec 1, 2020
fddf75a
Removed callable tuple def
darsnack Dec 1, 2020
5217049
Group unary op defs for Nil
darsnack Dec 23, 2020
30d5cb8
Group binary op defs for Nil
darsnack Dec 23, 2020
afb4acd
Updated Nil to use promote_rule and added tests for activation functions
darsnack Dec 23, 2020
e105cc3
Removed complex batch handling for outdims in favor of a simple kwarg
darsnack Dec 23, 2020
0f73014
Updated to use Base.conj and Base.convert for Nil
darsnack Dec 23, 2020
e5866cb
Specialize outdims on tuple isize
darsnack Dec 26, 2020
971004e
Remove dangling outdims references in basic.jl
darsnack Dec 26, 2020
d095919
Rework example, remove export, padbatch=false default
darsnack Dec 26, 2020
ccca623
Rename outdims -> outputsize
darsnack Dec 26, 2020
5d47cfc
Add deprecation for outdims
darsnack Dec 26, 2020
3a3574d
Fix doctest for outputsize
darsnack Dec 26, 2020
2792559
Update docstring for outputsize
darsnack Dec 26, 2020
324ecde
Fix docs and deps for outputsize
darsnack Dec 26, 2020
998861a
Update src/deprecations.jl
darsnack Dec 26, 2020
8d66583
Added missing kwarg to specialized outputsize methods
darsnack Dec 26, 2020
a08bda1
Fix outputsize method ambiguity
darsnack Dec 26, 2020
438db24
Merge remote-tracking branch 'origin/master'
darsnack Dec 26, 2020
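
Commit 5d47cfc adds a deprecation for the old `outdims` name, and 998861a touches `src/deprecations.jl` again. The exact shim is not shown in this diff; below is a minimal sketch of what such a forwarder could look like, assuming a plain `Base.depwarn` call rather than Flux's actual definition.

```julia
# Hypothetical sketch of the deprecation shim (not the code in src/deprecations.jl):
# warn at the call site, then forward the old name to the new one.
function outdims(m, isize::Tuple; padbatch=false)
    Base.depwarn("`outdims` is deprecated; use `outputsize` instead", :outdims)
    return outputsize(m, isize; padbatch=padbatch)
end
```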
12 changes: 6 additions & 6 deletions docs/src/utilities.md
@@ -39,12 +39,12 @@ Flux.glorot_normal

Flux provides some utility functions to help you generate models in an automated fashion.

[`outdims`](@ref) enables you to calculate the output dimensions of layers like [`Conv`](@ref)
[`outputsize`](@ref) enables you to calculate the output dimensions of layers like [`Conv`](@ref)
when applied to input samples of a given size. This is achieved by passing a "dummy" array into
the model that preserves size information without running any computation.
`outdims(f, isize)` works for all layers (including custom layers) out of the box.
By default, `isize` expects the batch dimension,
but you can exclude the batch size with `outdims(f, isize; padbatch=true)` (assuming it to be one).
`outputsize(f, inputsize)` works for all layers (including custom layers) out of the box.
By default, `inputsize` expects the batch dimension,
but you can exclude the batch size with `outputsize(f, inputsize; padbatch=true)` (assuming it to be one).

Using this utility function lets you automate model building for various inputs like so:
```julia
@@ -71,7 +71,7 @@ function make_model(width, height, inchannels, nclasses;

# compute the output dimensions for the conv layers
# use padbatch=true to set the batch dimension to 1
conv_outsize = outdims(conv_layers, (width, height, nchannels); padbatch=true)
conv_outsize = Flux.outputsize(conv_layers, (width, height, nchannels); padbatch=true)

# the input dimension to Dense is programmatically calculated from
# width, height, and nchannels
@@ -80,7 +80,7 @@ end
```

```@docs
Flux.outdims
Flux.outputsize
```

## Model Abstraction
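The "dummy" array mentioned in the utilities docs above is an array of `Nil` numbers: a number type whose arithmetic always returns `nil`, so a forward pass propagates shapes without doing any real computation. The snippet below is a simplified, self-contained sketch of that idea; it mirrors the grouped operator definitions and `promote_rule` mentioned in commits 5217049, 30d5cb8, and afb4acd, but it is not the PR's actual `NilNumber` module.

```julia
# Simplified sketch of the shape-only evaluation trick behind outputsize.
struct Nil <: Real end
const nil = Nil()

# Group the binary operator definitions: any arithmetic on Nil yields nil.
for f in (:+, :-, :*, :/)
    @eval Base.$f(::Nil, ::Nil) = nil
end

# Mixing Nil with ordinary numbers promotes everything to Nil.
Base.promote_rule(::Type{Nil}, ::Type{<:Number}) = Nil
Base.convert(::Type{Nil}, ::Number) = nil

# Evaluating a function on a Nil-filled array yields its output shape for free.
shapeonly(f, insize::Tuple) = size(f(fill(nil, insize)))

shapeonly(x -> x .+ 1, (3, 4))  # (3, 4)
```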
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -41,7 +41,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")

include("outdims.jl")
include("outputsize.jl")

include("data/Data.jl")

32 changes: 16 additions & 16 deletions src/outdims.jl → src/outputsize.jl
@@ -45,57 +45,57 @@ end # module
using .NilNumber: Nil, nil

"""
outdims(m, isize; padbatch=false)
outputsize(m, inputsize::Tuple; padbatch=false)

Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results).
`isize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`).
If `m` is a `Tuple` or `Vector`, `outdims` treats `m` like a `Chain`.
Calculate the output size of model/function `m` given an input of size `inputsize` (w/o computing results).
`inputsize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`).
If `m` is a `Tuple` or `Vector`, `outputsize` treats `m` like a `Chain`.

*Note*: this method should work out of the box for custom layers.

# Examples
```jldoctest
julia> outdims(Dense(10, 4), (10,); padbatch=true)
julia> outputsize(Dense(10, 4), (10,); padbatch=true)
(4, 1)

julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));

julia> m(randn(Float32, 10, 10, 3, 64)) |> size
(6, 6, 32, 64)

julia> outdims(m, (10, 10, 3); padbatch=true)
julia> outputsize(m, (10, 10, 3); padbatch=true)
(6, 6, 32, 1)

julia> outdims(m, (10, 10, 3, 64))
julia> outputsize(m, (10, 10, 3, 64))
(6, 6, 32, 64)

julia> try outdims(m, (10, 10, 7, 64)) catch e println(e) end
julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
DimensionMismatch("Input channels must match! (7 vs. 3)")

julia> outdims([Dense(10, 4), Dense(4, 2)], (10, 1))
julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
(2, 1)

julia> using LinearAlgebra: norm

julia> f(x) = x ./ norm.(eachcol(x));

julia> outdims(f, (10, 1)) # manually specify batch size as 1
julia> outputsize(f, (10, 1)) # manually specify batch size as 1
(10, 1)

julia> outdims(f, (10,); padbatch=true) # no need to mention batch size
julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
(10, 1)
```
"""
function outdims(m, isize::Tuple; padbatch=false)
isize = padbatch ? (isize..., 1) : isize
function outputsize(m, inputsize::Tuple; padbatch=false)
inputsize = padbatch ? (inputsize..., 1) : inputsize

return size(m(fill(nil, isize)))
return size(m(fill(nil, inputsize)))
end

## make tuples and vectors be like Chains

outdims(m::Tuple, isize) = outdims(Chain(m...), isize)
outdims(m::AbstractVector, isize) = outdims(Chain(m...), isize)
outputsize(m::Tuple, inputsize) = outputsize(Chain(m...), inputsize)
outputsize(m::AbstractVector, inputsize) = outputsize(Chain(m...), inputsize)

## bypass statistics in normalization layers

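The "bypass statistics in normalization layers" section above is folded out of the diff, so that code is not visible here. One plausible shape of it, sketched purely as an assumption (not the PR's actual implementation), is to give normalisation layers a pass-through method for `Nil`-element arrays, so the dummy forward pass neither touches the running statistics nor changes the shape.

```julia
# Hedged sketch (assumption, not the diff's actual code). This would live inside
# Flux's outputsize.jl, where Nil and the layer types are already in scope.
# For shape inference the input can pass straight through, since these layers
# never change the array size.
for layer in (:Dropout, :AlphaDropout, :LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
    @eval (l::$layer)(x::AbstractArray{Nil}) = x
end
```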
134 changes: 0 additions & 134 deletions test/outdims.jl

This file was deleted.

134 changes: 134 additions & 0 deletions test/outputsize.jl
@@ -0,0 +1,134 @@
@testset "basic" begin
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
@test outputsize(m, (10, 10, 3, 1)) == (6, 6, 32, 1)

m = Dense(10, 5)
@test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1)
@test outputsize(m, (10,); padbatch=true) == (5, 1)

m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2))
@test outputsize(m, (10,); padbatch=true) == (2, 1)
@test outputsize(m, (10, 30)) == (2, 30)

m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2))
@test_throws DimensionMismatch outputsize(m, (10,))

m = Flux.Diagonal(10)
@test outputsize(m, (10, 1)) == (10, 1)

m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
@test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1)

m = flatten
@test outputsize(m, (5, 5, 3, 10)) == (75, 10)

m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10))
@test outputsize(m, (10, 10, 3, 50)) == (10, 50)
@test outputsize(m, (10, 10, 3, 2)) == (10, 2)

m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3))
@test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
end

@testset "activations" begin
@testset for f in [celu, elu, gelu, hardsigmoid, hardtanh,
leakyrelu, lisht, logcosh, logσ, mish,
relu, relu6, rrelu, selu, σ, softplus,
softshrink, softsign, swish, tanhshrink, trelu]
@test outputsize(Dense(10, 5, f), (10, 1)) == (5, 1)
end
end

@testset "conv" begin
m = Conv((3, 3), 3 => 16)
@test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1)
m = Conv((3, 3), 3 => 16; stride = 2)
@test outputsize(m, (5, 5, 3, 1)) == (2, 2, 16, 1)
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3)
@test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1)
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
@test outputsize(m, (5, 5, 3, 1)) == (4, 4, 16, 1)
@test_throws DimensionMismatch outputsize(m, (5, 5, 2))
@test outputsize(m, (5, 5, 3, 100)) == (4, 4, 16, 100)

m = ConvTranspose((3, 3), 3 => 16)
@test outputsize(m, (8, 8, 3, 1)) == (10, 10, 16, 1)
m = ConvTranspose((3, 3), 3 => 16; stride = 2)
@test outputsize(m, (2, 2, 3, 1)) == (5, 5, 16, 1)
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3)
@test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1)
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
@test outputsize(m, (4, 4, 3, 1)) == (5, 5, 16, 1)

m = DepthwiseConv((3, 3), 3 => 6)
@test outputsize(m, (10, 10, 3, 1)) == (8, 8, 6, 1)
m = DepthwiseConv((3, 3), 3 => 6; stride = 2)
@test outputsize(m, (5, 5, 3, 1)) == (2, 2, 6, 1)
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3)
@test outputsize(m, (5, 5, 3, 1)) == (5, 5, 6, 1)
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2)
@test outputsize(m, (5, 5, 3, 1)) == (4, 4, 6, 1)

m = CrossCor((3, 3), 3 => 16)
@test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1)
m = CrossCor((3, 3), 3 => 16; stride = 2)
@test outputsize(m, (5, 5, 3, 1)) == (2, 2, 16, 1)
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3)
@test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1)
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
@test outputsize(m, (5, 5, 3, 1)) == (4, 4, 16, 1)

m = AdaptiveMaxPool((2, 2))
@test outputsize(m, (10, 10, 3, 1)) == (2, 2, 3, 1)

m = AdaptiveMeanPool((2, 2))
@test outputsize(m, (10, 10, 3, 1)) == (2, 2, 3, 1)

m = GlobalMaxPool()
@test outputsize(m, (10, 10, 3, 1)) == (1, 1, 3, 1)

m = GlobalMeanPool()
@test outputsize(m, (10, 10, 3, 1)) == (1, 1, 3, 1)

m = MaxPool((2, 2))
@test outputsize(m, (10, 10, 3, 1)) == (5, 5, 3, 1)
m = MaxPool((2, 2); stride = 1)
@test outputsize(m, (5, 5, 4, 1)) == (4, 4, 4, 1)
m = MaxPool((2, 2); stride = 2, pad = 3)
@test outputsize(m, (5, 5, 2, 1)) == (5, 5, 2, 1)

m = MeanPool((2, 2))
@test outputsize(m, (10, 10, 3, 1)) == (5, 5, 3, 1)
m = MeanPool((2, 2); stride = 1)
@test outputsize(m, (5, 5, 4, 1)) == (4, 4, 4, 1)
m = MeanPool((2, 2); stride = 2, pad = 3)
@test outputsize(m, (5, 5, 2, 1)) == (5, 5, 2, 1)
end

@testset "normalisation" begin
m = Dropout(0.1)
@test outputsize(m, (10, 10)) == (10, 10)
@test outputsize(m, (10,); padbatch=true) == (10, 1)

m = AlphaDropout(0.1)
@test outputsize(m, (10, 10)) == (10, 10)
@test outputsize(m, (10,); padbatch=true) == (10, 1)

m = LayerNorm(32)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)

m = BatchNorm(3)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)

m = InstanceNorm(3)
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)

if VERSION >= v"1.1"
m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
end
end