
Commit 8072cc9

Merge pull request #46 from PumasAI/densesubset

Docs and examples

2 parents 679f931 + 285988f

File tree

10 files changed: +180 −42 lines

LocalPreferences.toml

Lines changed: 0 additions & 2 deletions
This file was deleted.

Project.toml

Lines changed: 0 additions & 1 deletion
```diff
@@ -45,7 +45,6 @@ julia = "1.5"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
-CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
```

docs/make.jl

Lines changed: 6 additions & 1 deletion
```diff
@@ -13,7 +13,12 @@ makedocs(;
     canonical = "https://PumasAI.github.io/SimpleChains.jl",
     assets = String[],
   ),
-  pages = ["Home" => "index.md"],
+  pages = ["Home" => "index.md",
+    "Examples" => [
+      "examples/smallmlp.md",
+      "examples/mnist.md",
+    ]
+  ],
 )
 
 deploydocs(; repo = "github.com/PumasAI/SimpleChains.jl", devbranch = "main")
```

docs/src/examples/mnist.md

Lines changed: 97 additions & 0 deletions (new file)

# MNIST - Convolutions

First, we load the data using [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl):
```julia
using MLDatasets
xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32);
xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32);
size(xtrain3)
# (28, 28, 60000)
extrema(ytrain0) # digits, 0,...,9
# (0, 9)
```
The covariate arrays (`x`) carry the suffix `3` because they are three-dimensional: height × width × number of images.
The labels (`y`) are vectors indicating each image's digit.
```julia
xtrain4 = reshape(xtrain3, 28, 28, 1, :);
xtest4 = reshape(xtest3, 28, 28, 1, :);
ytrain1 = UInt32.(ytrain0 .+ 1);
ytest1 = UInt32.(ytest0 .+ 1);
```
SimpleChains' convolutional layers expect a channels-in dimension, so we reshape the images to be four-dimensional.
SimpleChains also currently defaults to 1-based indexing for its categories, so we shift all labels by 1.
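A quick sanity check of the reshaped data (the `UInt32` extrema print as hex literals):
```julia
size(xtrain4) # (28, 28, 1, 60000): height × width × channels × images
extrema(ytrain1) # (0x00000001, 0x0000000a): classes now run 1:10
```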
We now define our model, LeNet5:
```julia
using SimpleChains

lenet = SimpleChain(
  (static(28), static(28), static(1)),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
  SimpleChains.MaxPool(2, 2),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
  SimpleChains.MaxPool(2, 2),
  Flatten(3),
  TurboDense(SimpleChains.relu, 120),
  TurboDense(SimpleChains.relu, 84),
  TurboDense(identity, 10),
)

lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1));
```
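Each 5×5 convolution shrinks height and width by 4, and each `MaxPool(2, 2)` halves them, so the feature maps go 28 → 24 → 12 → 8 → 4; `Flatten(3)` then hands the dense layers a 16 × 4 × 4 = 256-element vector. You can confirm the final output shape with `SimpleChains.outputdim`, which this commit's test suite also uses:
```julia
SimpleChains.outputdim(lenet, size(xtrain4)) # (10, 60000): ten logits per image
```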
We define the inputs as statically sized `(28, 28, 1)` images.
Specifying the input sizes allows them to be checked.
We can make the sizes static either in the `SimpleChain` itself, as above, or by adding
static sizing to the images themselves using a package like [StrideArrays.jl](https://github.com/JuliaSIMD/StrideArrays.jl)
or [HybridArrays.jl](https://github.com/JuliaArrays/HybridArrays.jl). These packages are recommended
for mixing dynamic and static sizes; the batch size should probably be left dynamic,
as you're unlikely to want to specialize code generation on it: it is likely to vary,
so specializing would increase compile times while being unlikely to improve runtimes.
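As a minimal sketch of the StrideArrays.jl route (assuming its `StrideArray{T}(undef, dims)` constructor and `static` from Static.jl), the spatial dimensions can be made static while the batch dimension stays dynamic:
```julia
using StrideArrays, Static

batch = 512 # left dynamic: it varies between runs, so we don't specialize on it
ximg = StrideArray{Float32}(undef, (static(28), static(28), static(1), batch));
```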

In `SimpleChains`, the parameters are not part of the model; they live in a
separate vector that you can pass around to the optimizers of your choosing.
If you specified the input size, you can create a random initial parameter vector
corresponding to the model:
```julia
@time p = SimpleChains.init_params(lenet);
```
The convolutional layers are initialized with a Glorot (Xavier) uniform distribution,
while the dense layers are initialized with a Glorot (Xavier) normal distribution.
Biases are initialized to zero.
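For reference, the standard Glorot scale factors depend on a layer's fan-in and fan-out (a textbook sketch, not a SimpleChains API):
```julia
# Textbook Glorot scales: uniform draws from U(-b, b), normal uses std σ.
glorot_uniform_bound(fan_in, fan_out) = sqrt(6 / (fan_in + fan_out)) # b
glorot_normal_std(fan_in, fan_out) = sqrt(2 / (fan_in + fan_out))    # σ
```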
Because the number of parameters can be a function of the input size, the input size
must be provided if you didn't specify the input dimensions when constructing the chain. For example:
```julia
@time p = SimpleChains.init_params(lenet, size(xtrain4));
```

To allow training to use multiple threads, you can create a gradient matrix with
a number of rows equal to the length of the parameter vector `p`, and one column
per thread. For example:
```julia
estimated_num_cores = Sys.CPU_THREADS ÷ ((Sys.ARCH === :x86_64) + 1);
G = similar(p, length(p), min(Threads.nthreads(), estimated_num_cores));
```
Here we estimate the number of physical cores as half the number of threads
on `x86_64` systems, which is true for most, but not all, of them.
Otherwise, we assume it equals the number of threads. That can also be
wrong; recent POWER CPUs, for example, may have 4 or even 8 threads per core.
You may wish to change this, or use [Hwloc.jl](https://github.com/JuliaParallel/Hwloc.jl) for an accurate count.
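For example, a sketch using Hwloc.jl's `num_physical_cores` in place of the estimate above:
```julia
using Hwloc
# query the actual physical-core count instead of guessing from thread count
G = similar(p, length(p), min(Threads.nthreads(), Hwloc.num_physical_cores()));
```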

Now that this is all set up, we can train for `10` epochs using the `ADAM` optimizer
with a learning rate of `3e-4`, and then assess the accuracy and loss on both the training
and the test data. Note that `lenetloss` already carries the training targets `ytrain1`,
so only the test-set call passes its targets `ytest1` explicitly:
```julia
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
```
Training for an extra 10 epochs should be fast on most systems. Performance is currently known
to be poor on the M1 (PRs welcome; otherwise we'll look into this eventually), but should be
good to great on systems with AVX2/AVX512:
```julia
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
```

docs/src/examples/smallmlp.md

Lines changed: 3 additions & 0 deletions (new file)

# Small Multi-Layer Perceptron

docs/src/index.md

Lines changed: 4 additions & 0 deletions
````diff
@@ -12,3 +12,7 @@ Documentation for [SimpleChains](https://github.com/PumasAI/SimpleChains.jl).
 ```@autodocs
 Modules = [SimpleChains]
 ```
+
+```@contents
+Pages = ["examples/smallmlp.md", "examples/mnist.md"]
+```
````

src/optimize.jl

Lines changed: 1 addition & 10 deletions
```diff
@@ -105,25 +105,16 @@ function shuffle_chain_valgrad_thread!(
   tgt = target(loss)
   # @show size(tgt)
   tgtpb = preserve_buffer(tgt)
-  Xpb = preserve_buffer(Xp)
-  # Xsz = Base.front(size(Xp))
-  # eltx = eltype(Xp)
   eltgt = eltype(tgt)
-  # szeltx = sizeof(eltx)
   szeltgt = sizeof(eltgt)
 
-  # Xtmp = PtrArray(Ptr{eltx}(pm), (Xsz..., lastdim))
-  # Xlen = tsprod(Xsz)
-  # pXtmp = pointer(Xtmp)
-  # pm += align(sizeof(eltgt) * Xlen * lastdim)
   tgtsz = Base.front(size(tgt))
   tgttmp = PtrArray(Ptr{eltgt}(pm), (tgtsz..., lastdim))
   ptgttmp = pointer(tgttmp)
   tgtlen = tsprod(tgtsz)
   pm += align(szeltgt * tgtlen * lastdim)
-  # pX = pointer(Xp)
   ptgt = pointer(tgt)
-  GC.@preserve tgtpb Xpb begin
+  GC.@preserve tgtpb begin
     for i = fm1:l-1
       @inbounds j = perm[i] # `perm` and `j` are zero-based
       # @show i, j
```

test/Project.toml

Lines changed: 6 additions & 0 deletions (new file)

```toml
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
```

test/mnist.jl

Lines changed: 59 additions & 0 deletions (new file)

```julia
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
using SimpleChains, MLDatasets, Test

lenet = SimpleChain(
  (static(28), static(28), static(1)),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
  SimpleChains.MaxPool(2, 2),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
  SimpleChains.MaxPool(2, 2),
  Flatten(3),
  TurboDense(SimpleChains.relu, 120),
  TurboDense(SimpleChains.relu, 84),
  TurboDense(identity, 10),
)
# 3d and 0-indexed
xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32);
xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32);
xtrain4 = reshape(xtrain3, 28, 28, 1, :);
xtest4 = reshape(xtest3, 28, 28, 1, :);
ytrain1 = UInt32.(ytrain0 .+ 1);
ytest1 = UInt32.(ytest0 .+ 1);
lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1));

@test SimpleChains.outputdim(lenet, size(xtrain4)) == (10, length(ytrain1));
@test SimpleChains.outputdim(lenet, size(xtest4)) == (10, length(ytest1));

# initialize parameters
@time p = SimpleChains.init_params(lenet);

@testset "Cache Corrupting Results" begin
  g = similar(p)
  subset = 1:200
  x = xtrain4[:, :, :, subset]
  y = ytrain1[subset]
  # rebuild the loss against the subset's targets; this local shadows the
  # outer `lenetloss` inside the `@testset`
  lenetloss = SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y))
  lenetloss.memory .= 0x00
  valgrad!(g, lenetloss, x, p)
  g2 = similar(g)
  lenetloss.memory .= 0xff
  valgrad!(g2, lenetloss, x, p)
  @test g == g2
end

# initialize a gradient buffer matrix; the number of columns places an upper
# bound on the number of threads used.
G = similar(p, length(p), min(Threads.nthreads(), (Sys.CPU_THREADS ÷ ((Sys.ARCH === :x86_64) + 1))));
# train
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
# assess training and test accuracy and loss
a0, l0 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
a1, l1 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
# train for another 10 epochs, this time without additional memory allocations
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
# assess again; accuracy should have improved with the extra epochs
a2, l2 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
a3, l3 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
@test a2 > a0 > 0.96
@test a3 > a1 > 0.96
```

test/runtests.jl

Lines changed: 4 additions & 28 deletions
```diff
@@ -419,33 +419,9 @@ SquaredLoss"""
     @test Ac == g
   end
   @testset "LeNet" begin
-    N = 20
-    nclasses = 10
-    x = randn(28, 28, 1, N)
-    lenet = SimpleChain(
-      (static(28), static(28), static(1)),
-      SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
-      SimpleChains.MaxPool(2, 2),
-      SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
-      SimpleChains.MaxPool(2, 2),
-      Flatten(3),
-      TurboDense(SimpleChains.relu, 120),
-      TurboDense(SimpleChains.relu, 84),
-      TurboDense(identity, nclasses),
-    )
-    SimpleChains.outputdim(lenet, size(x))
-    # y =
-    # d = Simple
-    p = SimpleChains.init_params(lenet, size(x))
-    lenet(x, p)
-    g = similar(p)
-    y = rand(one(UInt32):UInt32(nclasses), N)
-    lenet.memory .= 0x00
-    valgrad!(g, SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y)), x, p)
-    g2 = similar(g)
-    lenet.memory .= 0xff
-    valgrad!(g2, SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y)), x, p)
-    @test g == g2
+    include("mnist.jl")
   end
 end
-Aqua.test_all(SimpleChains, ambiguities = false, project_toml_formatting = false) #TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
+# TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
+# For now, there are the tests at the start.
+Aqua.test_all(SimpleChains, ambiguities = false, project_toml_formatting = false)
```
