forked from FluxML/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.jl
32 lines (24 loc) · 919 Bytes
/
data.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
using Flux
using Flux: onehot
using Flux.Data.Sentiment
using Flux.Data: Tree, leaves
traintrees = Sentiment.train()
# Split every (label, phrase) training tree into two trees of identical shape:
# one keeping only element 1 of each node value (the label) and one keeping
# only element 2 (the phrase).
labels = [map(node -> node[1], tree) for tree in traintrees]
phrases = [map(node -> node[2], tree) for tree in traintrees]
# All tokens in the training set, gathered from the leaves of every phrase tree.
# `reduce(vcat, ...)` concatenates the per-tree leaf vectors without splatting
# one argument per training tree into a single varargs `vcat` call.
tokens = reduce(vcat, map(leaves, phrases))
# Count how many times each token appears.
freqs = Dict{String,Int}()
for t in tokens
    freqs[t] = get(freqs, t, 0) + 1
end
# Tokens that occur exactly once in the training set are swapped for the
# "UNK" marker; values absent from `freqs` (count 0) pass through unchanged.
# This roughly cuts our "alphabet" of tokens in half.
phrases = let mask_singleton = tok -> (get(freqs, tok, 0) == 1 ? "UNK" : tok)
    [map(mask_singleton, tree) for tree in phrases]
end
# Our alphabet of tokens: every distinct leaf token after the UNK substitution.
# `reduce(vcat, ...)` avoids splatting one argument per tree into `vcat`.
alphabet = unique(reduce(vcat, map(leaves, phrases)))
# One-hot-encode our training data with respect to the alphabet.
# Node values that are `nothing` are kept as `nothing`; every other token is
# encoded against `alphabet`. `===` is the idiomatic identity test for the
# `nothing` singleton (no generic `==` dispatch).
phrases_e = map.(t -> t === nothing ? t : onehot(t, alphabet), phrases)
# Labels are one-hot-encoded over the label set 0:4.
labels_e = map.(t -> onehot(t, 0:4), labels)
# Pair each node's encoded phrase with its encoded label, tree by tree.
train = map.(tuple, phrases_e, labels_e)