diff --git a/LocalPreferences.toml b/LocalPreferences.toml
deleted file mode 100644
index 59b2758..0000000
--- a/LocalPreferences.toml
+++ /dev/null
@@ -1,2 +0,0 @@
-[CPUSummary]
-hwloc = true
diff --git a/Project.toml b/Project.toml
index a3e5e1b..ee9a77f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -45,7 +45,6 @@ julia = "1.5"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
-CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/docs/make.jl b/docs/make.jl
index 4aa9371..2d8c920 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -13,7 +13,12 @@ makedocs(;
     canonical = "https://PumasAI.github.io/SimpleChains.jl",
     assets = String[],
   ),
-  pages = ["Home" => "index.md"],
+  pages = ["Home" => "index.md",
+    "Examples" => [
+      "examples/smallmlp.md",
+      "examples/mnist.md",
+    ]
+  ],
 )
 
 deploydocs(; repo = "github.com/PumasAI/SimpleChains.jl", devbranch = "main")
diff --git a/docs/src/examples/mnist.md b/docs/src/examples/mnist.md
new file mode 100644
index 0000000..8a565a5
--- /dev/null
+++ b/docs/src/examples/mnist.md
@@ -0,0 +1,97 @@
+# MNIST - Convolutions
+
+First, we load the data using [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl):
+```julia
+using MLDatasets
+xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32);
+xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32);
+size(xtrain3)
+# (28, 28, 60000)
+extrema(ytrain0) # digits, 0,...,9
+# (0, 9)
+```
+The covariate data (`x`) are suffixed with `3` because they are three-dimensional arrays: height x width x number of images.
+The labels (`y`) are vectors indicating which digit each image shows.
+```julia
+xtrain4 = reshape(xtrain3, 28, 28, 1, :);
+xtest4 = reshape(xtest3, 28, 28, 1, :);
+ytrain1 = UInt32.(ytrain0 .+ 1);
+ytest1 = UInt32.(ytest0 .+ 1);
+```
+SimpleChains' convolutional layers expect a channels-in dimension, so we reshape the images to be four-dimensional.
+SimpleChains also currently defaults to 1-based indexing for its categories, so we shift all categories by 1.
+
+We now define our model, LeNet5:
+```julia
+using SimpleChains
+
+lenet = SimpleChain(
+  (static(28), static(28), static(1)),
+  SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
+  SimpleChains.MaxPool(2, 2),
+  SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
+  SimpleChains.MaxPool(2, 2),
+  Flatten(3),
+  TurboDense(SimpleChains.relu, 120),
+  TurboDense(SimpleChains.relu, 84),
+  TurboDense(identity, 10),
+)
+
+lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1));
+```
+We define the inputs as statically sized `(28, 28, 1)` images.
+Specifying the input sizes allows them to be checked.
+We can make the sizes static either in the `SimpleChain` itself, as above, or by adding
+static sizing to the images themselves using a package like [StrideArrays.jl](https://github.com/JuliaSIMD/StrideArrays.jl)
+or [HybridArrays.jl](https://github.com/JuliaArrays/HybridArrays.jl). These packages are recommended
+for mixing dynamic and static sizes; the batch size should probably be left dynamic,
+as you're unlikely to want to specialize code generation on it: it is likely to vary,
+so specializing on it would increase compile times while being unlikely to improve runtimes.
+
+In `SimpleChains`, the parameters are not part of the model, but live as a
+separate vector that you can pass around to optimizers of your choosing.
+If you specified the input size, you can create a random initial parameter vector
+corresponding to the model:
+```julia
+@time p = SimpleChains.init_params(lenet);
+```
+The convolutional layers are initialized with a Glorot (Xavier) uniform distribution,
+while the dense layers are initialized with a Glorot (Xavier) normal distribution.
+Biases are initialized to zero.
+Because the number of parameters can be a function of the input size, the input size
+must be provided if you didn't specify it when constructing the chain. For example:
+```julia
+@time p = SimpleChains.init_params(lenet, size(xtrain4));
+```
+
+To allow training to use multiple threads, you can create a gradient matrix with a
+number of rows equal to the length of the parameter vector `p`, and one column per
+thread. For example:
+```julia
+estimated_num_cores = (Sys.CPU_THREADS ÷ ((Sys.ARCH === :x86_64) + 1));
+G = similar(p, length(p), min(Threads.nthreads(), estimated_num_cores));
+```
+Here, we estimate that the number of physical cores is half the number of threads
+on an `x86_64` system, which is true for most -- but not all -- of them.
+Otherwise, we assume it is equal to the number of threads. This too may be wrong;
+recent POWER CPUs, for example, may have 4 or even 8 threads per core.
+You may wish to change this, or use [Hwloc.jl](https://github.com/JuliaParallel/Hwloc.jl) for an accurate count.
+
+Now that this is all set up, we can train for `10` epochs using the `ADAM` optimizer
+with a learning rate of `3e-4`, and then assess the accuracy and loss on both the
+training and test data:
+```julia
+@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
+SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
+SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
+```
+Training for an additional 10 epochs should be fast on most systems. Performance is
+currently known to be poor on the Apple M1 (PRs welcome, otherwise we'll look into
+this eventually), but should be good on systems with AVX2 or AVX512:
+```julia
+@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
+SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
+SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
+```
+
diff --git a/docs/src/examples/smallmlp.md b/docs/src/examples/smallmlp.md
new file mode 100644
index 0000000..6ee40d2
--- /dev/null
+++ b/docs/src/examples/smallmlp.md
@@ -0,0 +1,3 @@
+# Small Multi-Layer Perceptron
+
+
diff --git a/docs/src/index.md b/docs/src/index.md
index 1f01c46..b224874 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -12,3 +12,7 @@ Documentation for [SimpleChains](https://github.com/PumasAI/SimpleChains.jl).
 ```@autodocs
 Modules = [SimpleChains]
 ```
+
+```@contents
+Pages = ["examples/smallmlp.md", "examples/mnist.md"]
+```
diff --git a/src/optimize.jl b/src/optimize.jl
index 9af9f74..0a77cdf 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -105,25 +105,16 @@ function shuffle_chain_valgrad_thread!(
   tgt = target(loss)
   # @show size(tgt)
   tgtpb = preserve_buffer(tgt)
-  Xpb = preserve_buffer(Xp)
-  # Xsz = Base.front(size(Xp))
-  # eltx = eltype(Xp)
   eltgt = eltype(tgt)
-  # szeltx = sizeof(eltx)
   szeltgt = sizeof(eltgt)
-  # Xtmp = PtrArray(Ptr{eltx}(pm), (Xsz..., lastdim))
-  # Xlen = tsprod(Xsz)
-  # pXtmp = pointer(Xtmp)
-  # pm += align(sizeof(eltgt) * Xlen * lastdim)
   tgtsz = Base.front(size(tgt))
   tgttmp = PtrArray(Ptr{eltgt}(pm), (tgtsz..., lastdim))
   ptgttmp = pointer(tgttmp)
   tgtlen = tsprod(tgtsz)
   pm += align(szeltgt * tgtlen * lastdim)
-  # pX = pointer(Xp)
   ptgt = pointer(tgt)
-  GC.@preserve tgtpb Xpb begin
+  GC.@preserve tgtpb begin
     for i = fm1:l-1
       @inbounds j = perm[i] # `perm` and `j` are zero-based
       # @show i, j
diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 0000000..1cd3689
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,6 @@
+[deps]
+Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/mnist.jl b/test/mnist.jl
new file mode 100644
index 0000000..3826d9d
--- /dev/null
+++ b/test/mnist.jl
@@ -0,0 +1,59 @@
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
+using SimpleChains, MLDatasets, Test
+
+lenet = SimpleChain(
+  (static(28), static(28), static(1)),
+  SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
+  SimpleChains.MaxPool(2, 2),
+  SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
+  SimpleChains.MaxPool(2, 2),
+  Flatten(3),
+  TurboDense(SimpleChains.relu, 120),
+  TurboDense(SimpleChains.relu, 84),
+  TurboDense(identity, 10),
+)
+# 3d and 0-indexed
+xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32);
+xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32);
+xtrain4 = reshape(xtrain3, 28, 28, 1, :);
+xtest4 = reshape(xtest3, 28, 28, 1, :);
+ytrain1 = UInt32.(ytrain0 .+ 1);
+ytest1 = UInt32.(ytest0 .+ 1);
+lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1));
+
+@test SimpleChains.outputdim(lenet, size(xtrain4)) == (10, length(ytrain1));
+@test SimpleChains.outputdim(lenet, size(xtest4)) == (10, length(ytest1));
+
+# initialize parameters
+@time p = SimpleChains.init_params(lenet);
+
+@testset "Cache Corrupting Results" begin
+  g = similar(p)
+  subset = 1:200
+  x = xtrain4[:, :, :, subset]
+  y = ytrain1[subset]
+  lenetloss = SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y))
+  lenetloss.memory .= 0x00
+  valgrad!(g, lenetloss, x, p)
+  g2 = similar(g)
+  lenetloss.memory .= 0xff
+  valgrad!(g2, lenetloss, x, p)
+  @test g == g2
+end
+
+# initialize a gradient buffer matrix; number of columns places an upper bound
+# on the number of threads used.
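+# Assumed heuristic (not a measurement): treat x86_64 as having two hardware threads
+# per physical core and everything else as one, so the column count roughly matches
+# the physical core count; Hwloc.jl could be used to query the real topology instead.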
+G = similar(p, length(p), min(Threads.nthreads(), (Sys.CPU_THREADS ÷ ((Sys.ARCH === :x86_64) + 1))));
+# train
+@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
+# assess training and test loss
+a0, l0 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
+a1, l1 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
+# train without additional memory allocations
+@time SimpleChains.train_batched!(G, p, lenetloss, xtrain4, SimpleChains.ADAM(3e-4), 10);
+# assess training and test loss
+a2, l2 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p)
+a3, l3 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p)
+@test a2 > a0 > 0.96
+@test a3 > a1 > 0.96
+
diff --git a/test/runtests.jl b/test/runtests.jl
index d9d2988..2aae6d0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -419,33 +419,9 @@ SquaredLoss"""
     @test Ac == g
   end
   @testset "LeNet" begin
-    N = 20
-    nclasses = 10
-    x = randn(28, 28, 1, N)
-    lenet = SimpleChain(
-      (static(28), static(28), static(1)),
-      SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
-      SimpleChains.MaxPool(2, 2),
-      SimpleChains.Conv(SimpleChains.relu, (5, 5), 16),
-      SimpleChains.MaxPool(2, 2),
-      Flatten(3),
-      TurboDense(SimpleChains.relu, 120),
-      TurboDense(SimpleChains.relu, 84),
-      TurboDense(identity, nclasses),
-    )
-    SimpleChains.outputdim(lenet, size(x))
-    # y =
-    # d = Simple
-    p = SimpleChains.init_params(lenet, size(x))
-    lenet(x, p)
-    g = similar(p)
-    y = rand(one(UInt32):UInt32(nclasses), N)
-    lenet.memory .= 0x00
-    valgrad!(g, SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y)), x, p)
-    g2 = similar(g)
-    lenet.memory .= 0xff
-    valgrad!(g2, SimpleChains.add_loss(lenet, SimpleChains.LogitCrossEntropyLoss(y)), x, p)
-    @test g == g2
+    include("mnist.jl")
   end
 end
-Aqua.test_all(SimpleChains, ambiguities = false, project_toml_formatting = false) #TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
+# TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
+# For now, there are the tests at the start.
+Aqua.test_all(SimpleChains, ambiguities = false, project_toml_formatting = false)