Merge pull request #46 from PumasAI/densesubset
Docs and examples
Showing 10 changed files with 180 additions and 42 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# MNIST - Convolutions | ||
|
||
First, we load the data using [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl): | ||
```julia | ||
using MLDatasets | ||
xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32); | ||
xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32); | ||
size(xtrain3) | ||
# (28, 28, 60000) | ||
extrema(ytrain0) # digits, 0,...,9 | ||
# (0, 9) | ||
``` | ||
The covariate data (`x`) carry the suffix `3` because they are three-dimensional arrays of size height x width x number of images. | ||
The labels (`y`) are vectors indicating the digit shown in each image. | ||
```julia | ||
xtrain4 = reshape(xtrain3, 28, 28, 1, :); | ||
xtest4 = reshape(xtest3, 28, 28, 1, :); | ||
ytrain1 = UInt32.(ytrain0 .+ 1); | ||
ytest1 = UInt32.(ytest0 .+ 1); | ||
``` | ||
SimpleChains' convolutional layers expect a channels-in dimension, so we reshape the images to be four-dimensional. | ||
SimpleChains also currently defaults to 1-based indexing for its categories, so we shift all categories by 1. | ||
|
||
We now define our model, LeNet5: | ||
```julia | ||
using SimpleChains | ||
|
||
lenet = SimpleChain( | ||
(static(28), static(28), static(1)), | ||
SimpleChains.Conv(SimpleChains.relu, (5, 5), 6), | ||
SimpleChains.MaxPool(2, 2), | ||
SimpleChains.Conv(SimpleChains.relu, (5, 5), 16), | ||
SimpleChains.MaxPool(2, 2), | ||
Flatten(3), | ||
TurboDense(SimpleChains.relu, 120), | ||
TurboDense(SimpleChains.relu, 84), | ||
TurboDense(identity, 10), | ||
) | ||
|
||
lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1)); | ||
``` | ||
We define the inputs as being statically sized `(28,28,1)` images. | ||
Specifying the input sizes allows these to be checked. | ||
We can make them static either in our simple chain, as above, or by adding | ||
static sizing to the images themselves using a package like [StrideArrays.jl](https://github.com/JuliaSIMD/StrideArrays.jl) | ||
or [HybridArrays.jl](https://github.com/JuliaArrays/HybridArrays.jl). These packages are recommended | ||
because they let you mix dynamic and static sizes; the batch size should probably | ||
be left dynamic, as you're unlikely to want to specialize code generation on it: | ||
it is likely to vary, so specializing would increase compile times while being unlikely to | ||
improve runtimes. | ||
|
||
In `SimpleChains`, the parameters are not a part of the model, but live as a | ||
separate vector that you can pass around to optimizers of your choosing. | ||
If you specified the input size, you can create a random initial parameter vector | ||
corresponding to the model: | ||
```julia | ||
@time p = SimpleChains.init_params(lenet); | ||
``` | ||
The convolutional layers are initialized with a Glorot (Xavier) uniform distribution, | ||
while the dense layers are initialized with a Glorot (Xavier) normal distribution. | ||
Biases are initialized to zero. | ||
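
As a rough illustration of the two initializers named above, here is a minimal sketch in plain Julia. It is not SimpleChains' implementation: the helper names `glorot_uniform` and `glorot_normal`, the `fan_in`/`fan_out` conventions, and the example shapes are assumptions made for this example only.

```julia
using Random

# Glorot (Xavier) uniform: entries drawn from U(-a, a), a = sqrt(6 / (fan_in + fan_out)).
# Hypothetical helper, not part of SimpleChains.
function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...;
                        fan_in::Integer, fan_out::Integer) where {T}
    a = sqrt(T(6) / T(fan_in + fan_out))
    return (rand(rng, T, dims...) .* 2 .- 1) .* a
end

# Glorot (Xavier) normal: entries drawn from N(0, s^2), s = sqrt(2 / (fan_in + fan_out)).
# Hypothetical helper, not part of SimpleChains.
function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
                       fan_in::Integer, fan_out::Integer) where {T}
    s = sqrt(T(2) / T(fan_in + fan_out))
    return randn(rng, T, dims...) .* s
end

rng = Random.MersenneTwister(0)

# Assumed shapes: a 5x5 kernel with 1 input and 6 output channels,
# and a dense layer mapping 256 inputs to 120 outputs.
K = glorot_uniform(rng, Float32, 5, 5, 1, 6; fan_in = 5 * 5 * 1, fan_out = 5 * 5 * 6)
W = glorot_normal(rng, Float32, 120, 256; fan_in = 256, fan_out = 120)
b = zeros(Float32, 120)  # biases start at zero
```

The fan-in and fan-out used for the kernel follow one common convention (kernel area times channel count); the library may count them differently.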
Because the number of parameters can be a function of the input size, the input dimensions must | ||
be provided if you didn't specify them in the model. For example: | ||
```julia | ||
@time p = SimpleChains.init_params(lenet, size(xtrain4)); | ||
``` | ||
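
To see why the parameter count depends on the input size, the following sketch counts the parameters of the LeNet5 layout above by hand. It assumes convolutions without padding and with stride 1, one bias per output channel or output unit, and 2x2 pooling that halves each spatial dimension; the helper names are hypothetical, and the total is only what these assumptions imply, not a value read from SimpleChains.

```julia
# Spatial output size of a convolution, assuming no padding and stride 1.
conv_out(n, k) = n - k + 1
# One k1 x k2 kernel per (input channel, output channel) pair, plus one bias per output channel.
conv_params(k1, k2, cin, cout) = k1 * k2 * cin * cout + cout
# Dense layer: weight matrix plus one bias per output.
dense_params(nin, nout) = nin * nout + nout

function lenet_param_count(h, w, c)
    n = conv_params(5, 5, c, 6)
    h, w = conv_out(h, 5), conv_out(w, 5)
    h, w = h ÷ 2, w ÷ 2                  # 2x2 max pooling
    n += conv_params(5, 5, 6, 16)
    h, w = conv_out(h, 5), conv_out(w, 5)
    h, w = h ÷ 2, w ÷ 2                  # 2x2 max pooling
    flat = h * w * 16                    # length after flattening
    n += dense_params(flat, 120) + dense_params(120, 84) + dense_params(84, 10)
    return n
end

lenet_param_count(28, 28, 1)  # 44426 under the assumptions above
lenet_param_count(32, 32, 1)  # 61706: a larger input enlarges the first dense layer
```

Only the first dense layer changes between the two calls, because its input length is the flattened output of the last pooling layer.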
|
||
To allow training to use multiple threads, you can create a gradient matrix, with | ||
a number of rows equal to the length of the parameter vector `p`, and one column | ||
per thread. For example: | ||
```julia | ||
estimated_num_cores = (Sys.CPU_THREADS ÷ ((Sys.ARCH === :x86_64) + 1)); | ||
G = similar(p, length(p), min(Threads.nthreads(), estimated_num_cores)); | ||
``` | ||
Here, we're estimating that the number of physical cores is half the number of threads | ||
on an `x86_64` system, which is true for most, but not all, of them. | ||
Otherwise, we're assuming it is equal to the number of threads. This is of course also | ||
likely to be wrong, e.g. recent Power CPUs may have 4 or even 8 threads per core. | ||
You may wish to change this, or use [Hwloc.jl](https://github.com/JuliaParallel/Hwloc.jl) for an accurate number. | ||
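
The heuristic above can be wrapped in a small helper so it is easy to swap out. This is only a sketch: `estimate_physical_cores` and `gradient_buffer` are hypothetical names, and the `max(1, ...)` guard is an addition for machines that report a single hardware thread, where integer division by 2 would otherwise give zero columns.

```julia
# Heuristic from the text: assume 2 hardware threads per core on x86_64, 1 elsewhere.
# The max(1, ...) guard is an assumption added for single-thread machines.
estimate_physical_cores(cpu_threads = Sys.CPU_THREADS, arch = Sys.ARCH) =
    max(1, cpu_threads ÷ ((arch === :x86_64) + 1))

# Hypothetical helper: one gradient column per thread we expect to use.
function gradient_buffer(p::AbstractVector; nthreads = Threads.nthreads())
    ncols = min(nthreads, estimate_physical_cores())
    return similar(p, length(p), ncols)
end

p_demo = zeros(Float32, 8)        # stand-in for a real parameter vector
G_demo = gradient_buffer(p_demo)  # uninitialized 8 x ncols matrix
size(G_demo, 1) == length(p_demo)
```

Passing `cpu_threads` and `arch` explicitly, for example `estimate_physical_cores(16, :x86_64)`, makes it easy to check the heuristic against a machine whose topology you know.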
Now that this is all said and done, we can train for `10` epochs using the `ADAM` optimizer | ||
with a learning rate of `3e-4`, and then assess the accuracy and loss of both the training | ||
and test data: | ||
```julia | ||
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10); | ||
SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p) | ||
SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p) | ||
``` | ||
Training for an extra 10 epochs should be fast on most systems. Performance is currently known | ||
to be poor on the M1 (PRs welcome, otherwise we'll look into this eventually), but should be | ||
good/great on systems with AVX2/AVX512: | ||
```julia | ||
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10); | ||
SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p) | ||
SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p) | ||
``` | ||
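
For intuition about what an accuracy and a logit cross-entropy loss measure, here is a minimal sketch in plain Julia. It is not how `SimpleChains.accuracy_and_loss` is implemented; the function names and the toy logits are assumptions for illustration. It also shows why 1-based labels are convenient: the label indexes directly into the vector of logits.

```julia
# Numerically stable log-sum-exp of a vector of logits.
function logsumexp(z)
    m = maximum(z)
    return m + log(sum(exp.(z .- m)))
end

# logits: classes x samples matrix; labels: 1-based class indices.
# Hypothetical helper returning (accuracy, mean cross-entropy).
function accuracy_and_mean_loss(logits::AbstractMatrix, labels::AbstractVector{<:Integer})
    correct = 0
    loss = 0.0
    for (j, y) in enumerate(labels)
        z = view(logits, :, j)
        correct += argmax(z) == y       # predicted class is the largest logit
        loss += logsumexp(z) - z[y]     # -log softmax(z)[y]
    end
    n = length(labels)
    return correct / n, loss / n
end

logits = Float32[2 0; 0 1; 0 3]  # 3 classes, 2 samples (one column per sample)
labels = UInt32[1, 2]            # first sample is classified correctly, second is not
acc, loss = accuracy_and_mean_loss(logits, labels)
```

In this toy case `acc` is `0.5`, since only the first column has its largest logit at the labelled class.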
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Small Multi-Layer Perceptron | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[deps] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" | ||
using SimpleChains, MLDatasets, Test | ||
|
||
lenet = SimpleChain( | ||
(static(28), static(28), static(1)), | ||
SimpleChains.Conv(SimpleChains.relu, (5, 5), 6), | ||
SimpleChains.MaxPool(2, 2), | ||
SimpleChains.Conv(SimpleChains.relu, (5, 5), 16), | ||
SimpleChains.MaxPool(2, 2), | ||
Flatten(3), | ||
TurboDense(SimpleChains.relu, 120), | ||
TurboDense(SimpleChains.relu, 84), | ||
TurboDense(identity, 10), | ||
) | ||
# 3d and 0-indexed | ||
xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32); | ||
xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32); | ||
xtrain4 = reshape(xtrain3, 28, 28, 1, :); | ||
xtest4 = reshape(xtest3, 28, 28, 1, :); | ||
ytrain1 = UInt32.(ytrain0 .+ 1); | ||
ytest1 = UInt32.(ytest0 .+ 1); | ||
lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1)); | ||
|
||
@test SimpleChains.outputdim(lenet, size(xtrain4)) == (10,length(ytrain1)); | ||
@test SimpleChains.outputdim(lenet, size(xtest4)) == (10,length(ytest1)); | ||
|
||
# initialize parameters | ||
@time p = SimpleChains.init_params(lenet); | ||
|
||
@testset "Cache Corrupting Results" begin | ||
g = similar(p) | ||
subset = 1:200 | ||
x = xtrain4[:,:,:,subset] | ||
y = ytrain1[subset] | ||
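# use a loss whose targets match the 200-image subset | ||
subsetloss = SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y)) | ||
subsetloss.memory .= 0x00 | ||
valgrad!(g, subsetloss, x, p) | ||
g2 = similar(g) | ||
# refill the cache with different bytes; the gradient should not change | ||
subsetloss.memory .= 0xff | ||
valgrad!(g2, subsetloss, x, p) | ||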
@test g == g2 | ||
end | ||
|
||
# initialize a gradient buffer matrix; number of columns places an upper bound | ||
# on the number of threads used. | ||
G = similar(p, length(p), min(Threads.nthreads(), (Sys.CPU_THREADS ÷ ((Sys.ARCH === :x86_64) + 1)))); | ||
# train | ||
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10); | ||
# assess training and test loss | ||
a0, l0 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p) | ||
a1, l1 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p) | ||
# train without additional memory allocations | ||
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10); | ||
# assess training and test loss | ||
a2, l2 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p) | ||
a3, l3 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p) | ||
@test a2 > a0 > 0.96 | ||
@test a3 > a1 > 0.96 | ||
|