Add EmbeddingBag
#2031
@@ -692,3 +692,84 @@ end
function Base.show(io::IO, m::Embedding)
  print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end

"""
    EmbeddingBag(in => out, reduction=mean; init=randn32)

A lookup table that stores embeddings of dimension `out` for a vocabulary of size
`in`. Similar to [`Embedding`](@ref), but can take multiple inputs in a "bag". The
embeddings of these are then reduced to a single embedding based on `reduction`.
Typically, `reduction` is `mean`, `sum`, or `maximum`.

This layer is often used to store word embeddings and retrieve them using indices.
The inputs can take several forms:
- A scalar: a single bag with a single item
- A vector: a single bag with multiple items
- A matrix: multiple bags with multiple items (each column is a bag)
- A vector of vectors: multiple bags with multiple items (each vector is a bag)
- An input vector and offset vector: explained below

The `inputs`/`offsets` input type is similar to PyTorch's implementation. `inputs` should be
a vector of class indices and `offsets` should be a vector representing offsets from the
first index of `inputs`. The first element of `offsets` must be `0`, and `offsets` should
be monotonically increasing, but the second condition is not checked.

For example, the `inputs`/`offsets` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]` / `[0, 4, 5, 7]`
is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`.

# Examples
```jldoctest
julia> vocab_size, embed_size = 1000, 4;

julia> model = Flux.EmbeddingBag(vocab_size => embed_size)
EmbeddingBag(1000 => 4) # 4_000 parameters

julia> bags = [[1, 200, 25, 789], [2, 5, 10, 999]];

julia> bags_mtx = [1 2; 200 5; 25 10; 789 999];

julia> model(bags) |> summary
"4×2 Matrix{Float32}"

julia> model(bags) ≈ model(bags_mtx)
true
```
"""

struct EmbeddingBag{F, W}
  weight::W
  reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight) = EmbeddingBag(weight, mean)
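For reference, a quick sketch of how the two constructors above could be used (illustrative only, not part of the diff; the sizes, the `sum` reduction, and the `glorot_uniform` initializer are arbitrary choices):

```julia
using Flux, Statistics  # Statistics provides `mean`, the default reduction

# in => out pair with a non-default reduction and initializer
bag_sum = Flux.EmbeddingBag(26 => 8, sum; init = Flux.glorot_uniform)

# wrap an existing (out × in) weight matrix; the reduction defaults to `mean`
W = rand(Float32, 8, 26)
bag_mean = Flux.EmbeddingBag(W)
```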

function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector)
  offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0."))
  out = zeros(eltype(m.weight), size(m.weight, 1), length(offsets))
  start = firstindex(inputs)
  for i in eachindex(offsets[1:end-1])
    out[:, i] = m(inputs[start:offsets[i+1]])
    start = offsets[i+1]+1
  end
  out[:, end] = m(inputs[offsets[end]+1:end])
  out
end
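A small sanity check of the method above (a sketch reusing the docstring's own example, not part of the diff): the `inputs`/`offsets` form should agree with the equivalent vector-of-vectors call.

```julia
using Flux

model = Flux.EmbeddingBag(1000 => 4)        # same sizes as the docstring example

inputs  = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
offsets = [0, 4, 5, 7]                      # bags are 1:4, 5, 6:7, 8:10
bags    = [[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]

model(inputs, offsets) ≈ model(bags)        # expected to return true (both are 4×4)
```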

(m::EmbeddingBag)(idx::Integer) = m.weight[:, idx]
(m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2))
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))

Review thread on these two methods:

Comment: After reading the PyTorch docstring, it seems the main advantage of this layer is memory efficiency. So, shouldn't these be `mapreduce` calls?

Reply: Unfortunately, that turns out to be much slower:

julia> (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))

julia> (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))

julia> test(m::EmbeddingBag, bags::AbstractVector{<:AbstractVector}) = mapreduce(m, hcat, bags)

julia> test(m::EmbeddingBag, bags::AbstractMatrix) = mapreduce(m, hcat, eachcol(bags))

julia> e = Flux.EmbeddingBag(100 => 64)

julia> bags = [[rand(1:100) for _ in 1:3] for _ in 1:1000]

julia> @btime e(bags);
  709.630 μs (14004 allocations: 2.16 MiB)

julia> @btime test(e, bags);
  14.700 ms (15935 allocations: 124.18 MiB)

Comment: If this is the hurdle, then …

Comment: The really big memory cost is going to be the gradient of
https://github.com/FluxML/NNlib.jl/blob/6f74fad0a2a24e3594fc5229cc515fa25e80f877/src/gather.jl#L80
One could write a more efficient combined rule for this. Or add some thunks to the one in NNlib & wait for AD to learn to exploit them.

Comment: This can be done after this PR, right?

Reply: Yes. I just mean these concerns will dwarf the `reduce` vs `mapreduce` difference.
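To make the gather-gradient point above concrete, here is a rough sketch (not from the PR; it reuses the sizes of the benchmark above, and the loss is arbitrary) of the call whose backward pass is the real memory concern:

```julia
using Flux

e = Flux.EmbeddingBag(100 => 64)
bags = [rand(1:100, 3) for _ in 1:1000]

# The forward pass is cheap; most of the memory cost is expected in the
# backward pass, where the rule for NNlib.gather is expected to materialise
# a full-size gradient of the weight matrix for each bag.
loss(m) = sum(abs2, m(bags))
grads = gradient(loss, e)
```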

function (m::EmbeddingBag)(x::OneHotVector{T,L}) where {T,L}
  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
  return m(onecold(x))
end
function (m::EmbeddingBag)(x::OneHotMatrix{T,L}) where {T,L}
  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
  return m(LinearAlgebra.Transpose(onecold(x)))
end
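For intuition (an illustrative sketch, not part of the diff; the sizes and indices are arbitrary), these one-hot methods fall back to the integer paths above, so each one-hot column behaves like a bag containing a single index:

```julia
using Flux

model = Flux.EmbeddingBag(1000 => 4)

# a OneHotVector behaves like a single integer index
model(Flux.onehot(5, 1:1000)) ≈ model(5)                            # expected: true

# a OneHotMatrix behaves like one single-item bag per column
model(Flux.onehotbatch([2, 7], 1:1000)) ≈ model.weight[:, [2, 7]]   # expected: true
```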

function Base.show(io::IO, m::EmbeddingBag)
  print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
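Finally, a hypothetical end-to-end usage sketch (not from the PR; the layer sizes and classifier head are made up) showing that the new layer composes with the rest of Flux much like `Embedding` does:

```julia
using Flux

model = Chain(Flux.EmbeddingBag(1000 => 16), Dense(16 => 2), softmax)

bags = [[1, 200, 25], [2, 5]]   # two bags of token indices
model(bags)                     # 2×2 matrix: one probability column per bag
```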