forked from philipperemy/keras-tcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsequential.py
64 lines (52 loc) · 1.77 KB
/
sequential.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Trains a TCN on the IMDB sentiment classification task.
Output after 1 epoch on CPU: ~0.8611
Time per epoch on CPU (Core i7): ~64s.
Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
"""
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.datasets import imdb
from tensorflow.keras.layers import Dense, Dropout, Embedding
from tensorflow.keras.preprocessing import sequence
from tcn import TCN
# Vocabulary size handed to imdb.load_data(num_words=...): only the
# max_features most common words are kept.
max_features = 20000
# Sequence length handed to pad_sequences: every review is brought to
# exactly this many tokens.
maxlen = 100
batch_size = 32

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
# Apply the same padding to both splits, train first, then test.
x_train, x_test = (
    sequence.pad_sequences(split, maxlen=maxlen)
    for split in (x_train, x_test)
)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
y_train, y_test = np.array(y_train), np.array(y_test)

# Embedding -> TCN -> Dropout -> single sigmoid unit (binary classifier).
model = Sequential([
    Embedding(max_features, 128, input_shape=(maxlen,)),
    TCN(nb_filters=64, kernel_size=6, dilations=[1, 2, 4, 8, 16, 32, 64]),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),
])
model.summary()
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
class TestCallback(Callback):
    """Smoke-test callback that fails the run when validation accuracy is low.

    At the end of every epoch the logs are printed and the validation
    accuracy found in them is asserted to be greater than 0.78.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Print the epoch logs and check the validation accuracy.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Mapping of metric names to values; may be ``None``.

        Raises:
            KeyError: If neither 'val_accuracy' nor 'val_acc' is in ``logs``.
            AssertionError: If the validation accuracy is not above 0.78.
        """
        print(logs)
        # The signature permits logs=None; normalise it so the membership
        # test below cannot blow up with a TypeError on None.
        logs = logs or {}
        # Accept either spelling of the metric key (presumably it differs
        # between Keras versions -- confirm against the installed one).
        acc_key = 'val_accuracy' if 'val_accuracy' in logs else 'val_acc'
        if acc_key not in logs:
            raise KeyError(
                "no validation accuracy ('val_accuracy' or 'val_acc') in "
                "epoch logs; was validation data passed to fit()?"
            )
        assert logs[acc_key] > 0.78, (
            'validation accuracy %s is not above 0.78' % logs[acc_key]
        )
print('Train...')
# No `epochs` argument is passed, so Keras' default epoch count applies.
# The test split doubles as validation data, which is what gives
# TestCallback a validation accuracy to check after each epoch.
model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    batch_size=batch_size,
    callbacks=[TestCallback()],
)