diff --git a/Project.toml b/Project.toml
index d515aaa7..08c26cc9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SimpleChains"
 uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 authors = ["Chris Elrod and contributors"]
-version = "0.3.4"
+version = "0.4.0"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/examples/mnist_lenet.jl b/examples/mnist_lenet.jl
index 8ef8ba3b..2031a867 100644
--- a/examples/mnist_lenet.jl
+++ b/examples/mnist_lenet.jl
@@ -48,7 +48,7 @@ G = SimpleChains.alloc_threaded_grad(lenetloss);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
 
-# SimpleChains.init_params!(lenet, p);
+# SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
@@ -56,11 +56,11 @@ SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
 
 # lenet.memory .= 0;
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
 
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
@@ -75,11 +75,11 @@ lenetloss.memory .= 0x00;
 @time valgrad!(g1, lenetloss, xtrain, p)
 g0 == g1
 lenet.memory .= 0;
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
-SimpleChains.init_params!(lenet, p);
+SimpleChains.init_params!(p, lenet);
 @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
 SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
 SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
diff --git a/src/penalty.jl b/src/penalty.jl
index fd0755ef..16e8fc8f 100644
--- a/src/penalty.jl
+++ b/src/penalty.jl
@@ -47,8 +47,8 @@ end
 function init_params(Λ::AbstractPenalty, ::Type{T}; rng::AbstractRNG=local_rng()) where {T}
   init_params(getchain(Λ), nothing, T; rng)
 end
-function init_params!(Λ::AbstractPenalty, x, id = nothing; rng::AbstractRNG=local_rng())
-  init_params!(getchain(Λ), x, id; rng)
+function init_params!(x, Λ::AbstractPenalty, id = nothing; rng::AbstractRNG=local_rng())
+  init_params!(x, getchain(Λ), id; rng)
 end
 
 target(c::AbstractPenalty) = target(getchain(c))
diff --git a/src/simple_chain.jl b/src/simple_chain.jl
index 6a7103bd..0e04117c 100644
--- a/src/simple_chain.jl
+++ b/src/simple_chain.jl
@@ -1,4 +1,5 @@
+
 struct InputDimUnknown end
 const InputDim = Union{InputDimUnknown,Tuple{Vararg{Integer}}}
 
 
@@ -387,8 +388,8 @@
 
 Randomly initializes parameter vector `p` with input dim `id`. Input dim does not need to be specified if these were provided to the chain object itself. See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
 """
-function init_params!(
-  chn::SimpleChain, x::AbstractVector, id = nothing; rng::AbstractRNG = local_rng()
+@inline function init_params!(
+  x::AbstractVector, chn::SimpleChain, id = nothing; rng::AbstractRNG = local_rng()
 )
   GC.@preserve x _init_params!(chn.layers, pointer(x), chain_input_dims(chn, id), rng)
   return x
@@ -398,14 +399,14 @@ function _init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG)
   _init_params!(Base.tail(layers), p, od, rng)
 end
 _init_params!(::Tuple{}, p::Ptr, _, ::AbstractRNG) = nothing
-function init_params(
+@inline function init_params(
   Λ::SimpleChain,
   id::Union{Nothing,InputDim} = nothing,
   ::Type{T} = Float32;
   rng::AbstractRNG=local_rng()
 ) where {T}
   _id = chain_input_dims(Λ, id)
-  init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id); rng)
+  init_params!(StrideArray{T}(undef, numparam(Λ, id)), Λ, chain_input_dims(Λ, _id); rng)
 end
 
 """
diff --git a/test/runtests.jl b/test/runtests.jl
index e8548930..31288ad8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -235,12 +235,8 @@ InteractiveUtils.versioninfo(verbose=true)
     @show isapprox(g, gfdd, rtol = 1e-8)
   end # let g=g, sc=sc, x=x, p=p
-  @test iszero(
-    countallocations!(g, FrontLastPenalty(sc, L2Penalty(2.3), NoPenalty()), x, p),
-  )
-  @test iszero(
-    countallocations!(g, FrontLastPenalty(scd, L2Penalty(2.3), L1Penalty(0.45)), x, p),
-  )
+  @test countallocations!(g, FrontLastPenalty(sc, L2Penalty(2.3), NoPenalty()), x, p) == 0
+  @test countallocations!(g, FrontLastPenalty(scd, L2Penalty(2.3), L1Penalty(0.45)), x, p) == 0
   # @test iszero(@allocated(valgrad!(g, sc, x, p)))
   td = TurboDense{true}(tanh, static(8))