Skip to content

Commit

Permalink
Merge pull request #85 from PumasAI/biasandweightviews
Browse files Browse the repository at this point in the history
Add bias and weight views
  • Loading branch information
korsbo authored Jun 9, 2022
2 parents 94e94e5 + 94ff7f2 commit d1ffecc
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
21 changes: 0 additions & 21 deletions src/simple_chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,6 @@ struct SimpleChain{N,I<:InputDim,L<:Tuple{Vararg{Any,N}}}
memory::Vector{UInt8}
end

#=
input_dims(_) = nothing
function _check_input_dims(x, i)
d = input_dims(x)
d === nothing || d == i || throw(ArgumentError("Input size of one layer did not match the next."))
end
function _input_dims(t::Tuple{L,Vararg}) where {L}
l = first(t)
d = input_dims(l)
d === nothing ? _input_dims(Base.tail(t)) : d
end
_verify_chain(::Tuple{}, _) = nothing
function _verify_chain(layers::Tuple{L,Vararg}, inputdim = _input_dims(layers)) where {L}
l = first(layers)
_check_input_dims(l, inputdim)
d = output_size(Val(Float32), l, (inputdim,))[2][1]
_verify_chain(Base.tail(layers), d)
end
=#
chain_input_dims(c::SimpleChain) = c.inputdim

SimpleChain(input_dim::Integer, l::Vararg) = SimpleChain((input_dim,), l, UInt8[])
Expand Down
51 changes: 46 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,56 @@ Returns a tuple of the parameters of the SimpleChain `sc`, as a view of the para
"""
function params(sc::SimpleChain, p::AbstractVector, inputdim = nothing)
@unpack layers = sc
A = _params(layers, pointer(p), chain_input_dims(sc, inputdim))
A = _walk_chain(Val{:param}(), layers, pointer(p), chain_input_dims(sc, inputdim))
_add_memory(A, p)
end
_params(::Tuple{}, _, __) = ()
function _params(layers, p, inputdim)
A, p, outputdim = _getparams(first(layers), p, inputdim)
B = _params(Base.tail(layers), p, outputdim)
"""
weights(sc::SimpleChain, p::AbstractVector, inputdim = nothing)
Returns a tuple of the weights (parameters other than biases) of the SimpleChain `sc`, as a view of the parameter vector `p`.
"""
function weights(sc::SimpleChain, p::AbstractVector, inputdim = nothing)
@unpack layers = sc
A = _walk_chain(Val{:weight}(), layers, pointer(p), chain_input_dims(sc, inputdim))
_add_memory(A, p)
end
"""
biases(sc::SimpleChain, p::AbstractVector, inputdim = nothing)
Returns a tuple of the biases of the SimpleChain `sc`, as a view of the parameter vector `p`.
"""
function biases(sc::SimpleChain, p::AbstractVector, inputdim = nothing)
@unpack layers = sc
A = _walk_chain(Val{:bias}(), layers, pointer(p), chain_input_dims(sc, inputdim))
_add_memory(A, p)
end

# definitions that happen to be right in most cases to save up
# from implementing too much
_get(::Val{:param}, _, x) = x
_get(::Val{:weight}, _, ::Nothing) = nothing
_get(::Val{:weight}, _, x) = x
_get(::Val{:weight}, _, x::Tuple{A,B}) where {A,B} = first(x)
_get(::Val{:bias}, _, ::Nothing) = nothing
_get(::Val{:bias}, _, x) = nothing
_get(::Val{:bias}, _, x::Tuple{A,B}) where {A,B} = last(x)
@inline function _getparams(f::F, layer, p, inputdim) where {F}
A, p, outputdim = _getparams(layer, p, inputdim)
_get(f, layer, A), p, outputdim
end
# TODO: support nesting simple chains; below definition should enable recursive params
#=
function _getparams(f::F, layer::Union{AbstractPenalty,SimpleChain}, p, inputdim) where {F}
_walk_chain(f, layer, p, inputdim)
end
=#
_walk_chain(___, ::Tuple{}, _, __) = ()
function _walk_chain(f::F, layers, p, inputdim) where {F}
A, p, outputdim = _getparams(f, first(layers), p, inputdim)
B = _walk_chain(f, Base.tail(layers), p, outputdim)
(A, B...)
end

_add_memory(A::PtrArray, p) = StrideArray(A, p)
_add_memory(::Tuple{}, _) = ()
function _add_memory(t::Tuple, p)
Expand Down
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,17 @@ InteractiveUtils.versioninfo(verbose=true)
),
)
p = SimpleChains.init_params(sc);
n0, (W1, b1), n2, W3 = SimpleChains.params(sc,p)
n0, (W1, b1), n2, W3 = SimpleChains.params(sc,p);
@test n0 === n2 === nothing
@test W1 == reshape(view(p,1:24*8),(8,24))
@test b1 == view(p,24*8+1:25*8)
@test W3 == reshape(@view(p[25*8+1:end]),(2,8))
n01, W11, n21, W31 = SimpleChains.weights(sc,p);
n02, b12, n22, n3 = SimpleChains.biases(sc,p);
@test n01 === n21 === n02 === n22 === n3
@test W11 === W1
@test W31 === W3
@test b12 === b1
end
end
# TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
Expand Down

2 comments on commit d1ffecc

@korsbo
Copy link
Member Author

@korsbo korsbo commented on d1ffecc Jun 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62015

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.9 -m "<description of version>" d1ffecc3e501f9dd71cbca7c256166f78b9a5e13
git push origin v0.2.9

Please sign in to comment.