-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerator.py
77 lines (63 loc) · 1.95 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
import random
import pickle
# Number of samples in each dataset split.
num_train = 60000
num_val = 10000
num_test = 10000
# Number of (letter, digit) key/value pairs in every input sequence.
step_num = 4
# One-hot vocabulary size: 26 letters + 10 digits + 1 catch-all slot.
elem_num = 26 + 10 + 1
# Input buffers: step_num pairs (2 rows each) plus 3 trailing rows per sample.
x_train, x_val, x_test = (
    np.zeros([count, step_num * 2 + 3, elem_num], dtype=np.float32)
    for count in (num_train, num_val, num_test)
)
# Target buffers: one one-hot vector per sample.
y_train, y_val, y_test = (
    np.zeros([count, elem_num], dtype=np.float32)
    for count in (num_train, num_val, num_test)
)
def get_one_hot(c):
    """Return a length-``elem_num`` one-hot vector encoding character *c*.

    Slots 0-25 hold 'a'-'z', slots 26-35 hold '0'-'9', and the last slot
    is a catch-all for any other character (e.g. '?').
    """
    code = ord(c)
    if ord('a') <= code <= ord('z'):
        index = code - ord('a')
    elif ord('0') <= code <= ord('9'):
        index = 26 + (code - ord('0'))
    else:
        # Every non-letter, non-digit symbol shares the final slot.
        index = -1
    vec = np.zeros([elem_num])
    vec[index] = 1
    return vec
def generate_one():
    """Generate one associative-retrieval example.

    The input sequence is ``step_num`` (letter, digit) pairs with distinct
    letters, followed by ``'?'``, ``'?'`` and a query letter chosen from
    those pairs. The target is the digit that was paired with the query
    letter.

    Returns:
        tuple: ``(x, y)``.
            ``x`` is a ``(step_num * 2 + 3, elem_num)`` array with one
            one-hot row per symbol.
            ``y`` is the one-hot vector (length ``elem_num``) of the
            answer digit.

    Raises:
        ValueError: if ``step_num`` is larger than 26. Keys must be distinct
            lowercase letters, so the redraw loop below could never finish.
    """
    if step_num > 26:
        raise ValueError('step_num must be <= 26 so every key can be a '
                         'distinct letter')
    x = np.zeros([step_num * 2 + 3, elem_num])
    pairs = {}  # letter index (0-25) -> digit (0-9)
    for i in range(step_num):
        key = random.randint(0, 25)
        # Redraw until the letter is unused, so every key is unique.
        # (`in` replaces dict.has_key, which no longer exists in Python 3.)
        while key in pairs:
            key = random.randint(0, 25)
        value = random.randint(0, 9)
        pairs[key] = value
        x[i * 2] = get_one_hot(chr(key + ord('a')))
        x[i * 2 + 1] = get_one_hot(chr(value + ord('0')))
    # list() is required on Python 3, where dict views are not indexable
    # and random.choice(d.keys()) raises TypeError.
    query = random.choice(list(pairs))
    query_char = chr(query + ord('a'))
    answer_char = chr(pairs[query] + ord('0'))
    x[step_num * 2] = get_one_hot('?')
    x[step_num * 2 + 1] = get_one_hot('?')
    x[step_num * 2 + 2] = get_one_hot(query_char)
    return x, get_one_hot(answer_char)
if __name__ == '__main__':
    # Fill the splits in a fixed order (train, test, val) so the sequence
    # of random draws per split matches the original script.
    splits = (
        (x_train, y_train, num_train),
        (x_test, y_test, num_test),
        (x_val, y_val, num_val),
    )
    for inputs, targets, count in splits:
        for idx in range(count):
            inputs[idx], targets[idx] = generate_one()
    dataset = {
        'x_train': x_train,
        'x_test': x_test,
        'x_val': x_val,
        'y_train': y_train,
        'y_test': y_test,
        'y_val': y_val
    }
    # Protocol 2 keeps the pickle loadable from Python 2 as well as Python 3.
    with open('associative-retrieval.pkl', 'wb') as out:
        pickle.dump(dataset, out, protocol=2)