diff --git a/Project.toml b/Project.toml
index d515aaa..08c26cc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SimpleChains"
 uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 authors = ["Chris Elrod and contributors"]
-version = "0.3.4"
+version = "0.4.0"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/examples/mnist_lenet.jl b/examples/mnist_lenet.jl
index 8ef8ba3..2031a86 100644
--- a/examples/mnist_lenet.jl
+++ b/examples/mnist_lenet.jl
@@ -48,7 +48,7 @@ G = SimpleChains.alloc_threaded_grad(lenetloss);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
 
-# SimpleChains.init_params!(lenet, p);
+# SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
@@ -56,11 +56,11 @@ SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
 
 # lenet.memory .= 0;
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
 
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
@@ -75,11 +75,11 @@ lenetloss.memory .= 0x00;
 @time valgrad!(g1, lenetloss, xtrain, p)
 g0 == g1
 lenet.memory .= 0;
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
diff --git a/src/penalty.jl b/src/penalty.jl
index fd0755e..16e8fc8 100644
--- a/src/penalty.jl
+++ b/src/penalty.jl
@@ -47,8 +47,8 @@ end
 function init_params(Λ::AbstractPenalty, ::Type{T}; rng::AbstractRNG=local_rng()) where {T}
   init_params(getchain(Λ), nothing, T; rng)
 end
-function init_params!(Λ::AbstractPenalty, x, id = nothing; rng::AbstractRNG=local_rng())
-  init_params!(getchain(Λ), x, id; rng)
+function init_params!(x, Λ::AbstractPenalty, id = nothing; rng::AbstractRNG=local_rng())
+  init_params!(x, getchain(Λ), id; rng)
 end
 
 target(c::AbstractPenalty) = target(getchain(c))
diff --git a/src/simple_chain.jl b/src/simple_chain.jl
index 6a7103b..0e04117 100644
--- a/src/simple_chain.jl
+++ b/src/simple_chain.jl
@@ -1,4 +1,5 @@
+
 struct InputDimUnknown end
 const InputDim = Union{InputDimUnknown,Tuple{Vararg{Integer}}}
 
 
@@ -387,8 +388,8 @@ end
 
 Randomly initializes parameter vector `p` with input dim `id`. Input dim does not need to be specified if these were provided to the chain object itself. See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
 """
-function init_params!(
-  chn::SimpleChain, x::AbstractVector, id = nothing; rng::AbstractRNG = local_rng()
+@inline function init_params!(
+  x::AbstractVector, chn::SimpleChain, id = nothing; rng::AbstractRNG = local_rng()
 )
   GC.@preserve x _init_params!(chn.layers, pointer(x), chain_input_dims(chn, id), rng)
   return x
@@ -398,14 +399,14 @@ function _init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG)
   _init_params!(Base.tail(layers), p, od, rng)
 end
 _init_params!(::Tuple{}, p::Ptr, _, ::AbstractRNG) = nothing
-function init_params(
+@inline function init_params(
   Λ::SimpleChain,
   id::Union{Nothing,InputDim} = nothing,
   ::Type{T} = Float32;
   rng::AbstractRNG=local_rng()
 ) where {T}
   _id = chain_input_dims(Λ, id)
-  init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id); rng)
+  init_params!(StrideArray{T}(undef, numparam(Λ, id)), Λ, chain_input_dims(Λ, _id); rng)
 end
 
 """
diff --git a/test/runtests.jl b/test/runtests.jl
index e854893..31288ad 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -235,12 +235,8 @@ InteractiveUtils.versioninfo(verbose=true)
   @show isapprox(g, gfdd, rtol = 1e-8)
 end # let g=g, sc=sc, x=x, p=p
 
-  @test iszero(
-    countallocations!(g, FrontLastPenalty(sc, L2Penalty(2.3), NoPenalty()), x, p),
-  )
-  @test iszero(
-    countallocations!(g, FrontLastPenalty(scd, L2Penalty(2.3), L1Penalty(0.45)), x, p),
-  )
+  @test countallocations!(g, FrontLastPenalty(sc, L2Penalty(2.3), NoPenalty()), x, p) == 0
+  @test countallocations!(g, FrontLastPenalty(scd, L2Penalty(2.3), L1Penalty(0.45)), x, p) == 0
 
   # @test iszero(@allocated(valgrad!(g, sc, x, p)))
   td = TurboDense{true}(tanh, static(8))