forked from FluxML/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path1-model.jl
74 lines (53 loc) · 1.68 KB
/
1-model.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Based on https://arxiv.org/abs/1409.0473
using Flux: flip, crossentropy, reset!, throttle
include("0-data.jl")
Nin = length(alphabet)
Nh = 30 # size of hidden layer
# A recurrent model which takes a token and returns a context-dependent
# annotation.
forward = LSTM(Nin, Nh÷2)
backward = LSTM(Nin, Nh÷2)
encode(tokens) = vcat.(forward.(tokens), flip(backward, tokens))
alignnet = Dense(2Nh, 1)
align(s, t) = alignnet(vcat(t, s .* trues(1, size(t, 2))))
# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.
recur = LSTM(Nh+length(phones), Nh)
toalpha = Dense(Nh, length(phones))
function asoftmax(xs)
xs = [exp.(x) for x in xs]
s = sum(xs)
return [x ./ s for x in xs]
end
function decode1(tokens, phone)
weights = asoftmax([align(recur.state[2], t) for t in tokens])
context = sum(map((a, b) -> a .* b, weights, tokens))
y = recur(vcat(Float32.(phone), context))
return softmax(toalpha(y))
end
decode(tokens, phones) = [decode1(tokens, phone) for phone in phones]
# The full model
state = (forward, backward, alignnet, recur, toalpha)
function model(x, y)
ŷ = decode(encode(x), y)
reset!(state)
return ŷ
end
loss(x, yo, y) = sum(crossentropy.(model(x, yo), y))
evalcb = () -> @show loss(data[500]...)
opt = ADAM()
Flux.train!(loss, params(state), data, opt, cb = throttle(evalcb, 10))
# Prediction
using StatsBase: wsample
function predict(s)
ts = encode(tokenise(s, alphabet))
ps = Any[:start]
for i = 1:50
dist = decode1(ts, onehot(ps[end], phones))
next = wsample(phones, vec(Tracker.data(dist)))
next == :end && break
push!(ps, next)
end
return ps[2:end]
end
predict("PHYLOGENY")