Skip to content

Commit 8d1ecf5

Browse files
authored
Merge pull request #116 from PumasAI/initparamsbangcorrectargorder
fix arg order of `init_params!`
2 parents 906b70b + aec45d7 commit 8d1ecf5

File tree

5 files changed

+15
-18
lines changed

5 files changed

+15
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleChains"
22
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
33
authors = ["Chris Elrod <[email protected]> and contributors"]
4-
version = "0.3.4"
4+
version = "0.4.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

examples/mnist_lenet.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,19 @@ G = SimpleChains.alloc_threaded_grad(lenetloss);
4848
SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
4949
SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
5050

51-
# SimpleChains.init_params!(lenet, p);
51+
# SimpleChains.init_params!(p, lenet);
5252
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
5353
SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
5454
SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
5555

5656

5757

5858
# lenet.memory .= 0;
59-
SimpleChains.init_params!(lenet, p);
59+
SimpleChains.init_params!(p, lenet);
6060
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
6161
SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
6262
SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
63-
SimpleChains.init_params!(lenet, p);
63+
SimpleChains.init_params!(p, lenet);
6464
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
6565
SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
6666
SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
@@ -75,11 +75,11 @@ lenetloss.memory .= 0x00;
7575
@time valgrad!(g1, lenetloss, xtrain, p)
7676
g0 == g1
7777
lenet.memory .= 0;
78-
SimpleChains.init_params!(lenet, p);
78+
SimpleChains.init_params!(p, lenet);
7979
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
8080
SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
8181
SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)
82-
SimpleChains.init_params!(lenet, p);
82+
SimpleChains.init_params!(p, lenet);
8383
@time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10);
8484
SimpleChains.accuracy_and_loss(lenetloss, xtrain, p),
8585
SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p)

src/penalty.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ end
4747
function init_params(Λ::AbstractPenalty, ::Type{T}; rng::AbstractRNG=local_rng()) where {T}
4848
init_params(getchain(Λ), nothing, T; rng)
4949
end
50-
function init_params!(Λ::AbstractPenalty, x, id = nothing; rng::AbstractRNG=local_rng())
51-
init_params!(getchain(Λ), x, id; rng)
50+
function init_params!(x, Λ::AbstractPenalty, id = nothing; rng::AbstractRNG=local_rng())
51+
init_params!(x, getchain(Λ), id; rng)
5252
end
5353

5454
target(c::AbstractPenalty) = target(getchain(c))

src/simple_chain.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
23
struct InputDimUnknown end
34
const InputDim = Union{InputDimUnknown,Tuple{Vararg{Integer}}}
45

@@ -387,8 +388,8 @@ end
387388
Randomly initializes parameter vector `p` with input dim `id`. Input dim does not need to be specified if these were provided to the chain object itself.
388389
See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
389390
"""
390-
function init_params!(
391-
chn::SimpleChain, x::AbstractVector, id = nothing; rng::AbstractRNG = local_rng()
391+
@inline function init_params!(
392+
x::AbstractVector, chn::SimpleChain, id = nothing; rng::AbstractRNG = local_rng()
392393
)
393394
GC.@preserve x _init_params!(chn.layers, pointer(x), chain_input_dims(chn, id), rng)
394395
return x
@@ -398,14 +399,14 @@ function _init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG)
398399
_init_params!(Base.tail(layers), p, od, rng)
399400
end
400401
_init_params!(::Tuple{}, p::Ptr, _, ::AbstractRNG) = nothing
401-
function init_params(
402+
@inline function init_params(
402403
Λ::SimpleChain,
403404
id::Union{Nothing,InputDim} = nothing,
404405
::Type{T} = Float32;
405406
rng::AbstractRNG=local_rng()
406407
) where {T}
407408
_id = chain_input_dims(Λ, id)
408-
init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id); rng)
409+
init_params!(StrideArray{T}(undef, numparam(Λ, id)), Λ, chain_input_dims(Λ, _id); rng)
409410
end
410411

411412
"""

test/runtests.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,8 @@ InteractiveUtils.versioninfo(verbose=true)
235235
@show isapprox(g, gfdd, rtol = 1e-8)
236236
end
237237
# let g=g, sc=sc, x=x, p=p
238-
@test iszero(
239-
countallocations!(g, FrontLastPenalty(sc, L2Penalty(2.3), NoPenalty()), x, p),
240-
)
241-
@test iszero(
242-
countallocations!(g, FrontLastPenalty(scd, L2Penalty(2.3), L1Penalty(0.45)), x, p),
243-
)
238+
@test countallocations!(g, FrontLastPenalty(sc, L2Penalty(2.3), NoPenalty()), x, p) == 0
239+
@test countallocations!(g, FrontLastPenalty(scd, L2Penalty(2.3), L1Penalty(0.45)), x, p) == 0
244240
# @test iszero(@allocated(valgrad!(g, sc, x, p)))
245241

246242
td = TurboDense{true}(tanh, static(8))

0 commit comments

Comments
 (0)