Skip to content

Commit 6122c88

Browse files
committed
add unittest, minors change
1 parent 2a67384 commit 6122c88

6 files changed

+200
-0
lines changed

src/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .dbpedia import dbpediaSNN, dbpediaIF
2+
from .utils import Relation, Entitie, tnorm_loss

test/__init__.py

Whitespace-only changes.

test/dbpedia_node_prior.json

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"z": 0, "a": 2, "c": 1, "e": 0, "d": 1, "b": 2, "f": 0, "g": 0, "h": 0, "i": 0, "x": 3, "y": 0, "w": 3}

test/dbpedia_problem.model

98.4 KB
Binary file not shown.

test/dbpedia_test.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import unittest
2+
import os
3+
import sys
4+
from keras.layers import Input
5+
from keras.models import load_model
6+
import json
7+
from itertools import cycle
8+
import numpy as np
9+
10+
try:
11+
MODULE = os.path.dirname(os.path.realpath(__file__))
12+
except:
13+
MODULE = ""
14+
15+
sys.path.insert(0, os.path.join(MODULE, '..'))
16+
17+
from src import dbpediaSNN, dbpediaIF, Relation, Entitie, tnorm_loss
18+
from src import dbpedia
19+
20+
sys.path.pop(0)
21+
22+
23+
class mock_pyMongoCursor(list):
24+
25+
def next(self):
26+
return self[0]
27+
28+
29+
class mock_pyMongoCollection:
30+
31+
def __init__(self, data):
32+
self.inst = data
33+
34+
def find(self, *args):
35+
return self.inst[:]
36+
37+
def aggregate(self, lst):
38+
n = 0
39+
for i in lst:
40+
if '$sample' in i:
41+
n = i['$sample']['size']
42+
break
43+
res = mock_pyMongoCursor()
44+
for i, j in enumerate(cycle(self.inst)):
45+
if i == n:
46+
break
47+
res.append(j)
48+
return res
49+
50+
51+
class Test_dbpedia_isa(unittest.TestCase):
52+
53+
def setUp(self):
54+
it = [('z', 'a'), ('c', 'a'), ('e', 'c'), ('d', 'b'), ('c', 'b'), ('f', 'c'),
55+
('g', 'd'), ('h', 'd'), ('i', 'd'), ('b', 'x'), ('y', 'x'), ('a', 'w')]
56+
ents = [i for i in 'abcdefghiwxyz']
57+
58+
inn = Input(shape=(10,), name='input')
59+
net = dbpediaSNN(ents, {}, it)
60+
model1 = net(inn, False, True)
61+
model2 = net(inn, True, True)
62+
63+
self.ndp = json.load(open(os.path.join(MODULE, 'dbpedia_node_prior.json'), 'r'))
64+
self.net_ndp = net.node_prior
65+
model11 = load_model(os.path.join(MODULE, 'dbpedia_problem.model'), custom_objects={
66+
'Relation': Relation, 'Entitie': Entitie,
67+
'tnorm_loss': tnorm_loss})
68+
model22 = load_model(os.path.join(MODULE, 'dbpedia_train.model'), custom_objects={
69+
'Relation': Relation, 'Entitie': Entitie,
70+
'tnorm_loss': tnorm_loss})
71+
72+
def layers(y):
73+
return set((map(lambda x: tuple(sorted(x.name.split('_'))), y.layers)))
74+
75+
self.l1 = layers(model1)
76+
self.l2 = layers(model2)
77+
self.l11 = layers(model1)
78+
self.l22 = layers(model2)
79+
80+
def test_layers_priotiry_for_training(self):
81+
self.assertEqual(self.ndp, self.net_ndp)
82+
83+
def test_training_model_layers(self):
84+
self.assertEqual(self.l2, self.l2)
85+
86+
def test_problem_model_layers(self):
87+
self.assertEqual(self.l1, self.l11)
88+
89+
90+
class Test_dbpedia_DataInterface(unittest.TestCase):
91+
92+
def setUp(self):
93+
global dbpedia
94+
self.old_db = dbpedia.db
95+
self.old_rels = dbpedia.relations
96+
dbpedia.db = {"Language": mock_pyMongoCollection([{'instance': 'aa'}, {'instance': 'bb'}]),
97+
'Continent': mock_pyMongoCollection([{'instance': 'cc', 'Language': 'aa', 'has_millonarie': 'ee'},
98+
{'instance': 'dd', 'Language': 'bb', 'has_millonarie': 'ff'}]),
99+
'Millonarie': mock_pyMongoCollection([{'instance': 'ee', 'Language': 'aa'},
100+
{'instance': 'ff', 'Language': 'bb'}]),
101+
'Relations': mock_pyMongoCollection(
102+
[{'e1': 'Language', 'e2': 'Continent', 'relFrom': 'Continent',
103+
'relation': 'speak_language'},
104+
{'e1': 'Millonarie', 'e2': 'Continent', 'relFrom': 'Continent',
105+
'relation': 'has_millonarie'}])}
106+
dbpedia.relations = dbpedia.db['Relations']
107+
self.dif = dbpediaIF(
108+
["Language", 'Continent', 'Millonarie'], ['speak_language', 'has_millonarie'])
109+
110+
def tearDown(self):
111+
global dbpedia
112+
dbpedia.db = self.old_db
113+
dbpedia.relations = self.old_rels
114+
115+
def test_sample_entitie_name(self):
116+
data, res = self.dif.sample_entitie_name("Language", 2)
117+
datag = np.array([[1, 0, 0, 0, 0, 0],
118+
[0, 1, 0, 0, 0, 0]])
119+
resg = np.array([[0., 1., 0., 0., 0.],
120+
[0., 1., 0., 0., 0.]])
121+
self.assertLessEqual(np.abs(data-datag).flatten().sum(), 5e-16)
122+
self.assertLessEqual(np.abs(res-resg).flatten().sum(), 5e-16)
123+
124+
data, res = self.dif.sample_entitie_name("Continent", 2)
125+
datag = np.array([[0, 0, 1, 0, 0, 0],
126+
[0, 0, 0, 1, 0, 0]])
127+
resg = np.array([[1., 0., 0., 0., 0.],
128+
[1., 0., 0., 0., 0.]])
129+
self.assertLessEqual(np.abs(data-datag).flatten().sum(), 5e-16)
130+
self.assertLessEqual(np.abs(res - resg).flatten().sum(), 5e-16)
131+
132+
def test_sample_entities_name(self):
133+
for i in range(4):
134+
with self.subTest(f'Entities random sampling, iteration {i}'):
135+
data, res = self.dif.sample_entities_name(['Language', 'Continent'], 2)
136+
lang = data[0][0] != 0 or data[0][1] != 0
137+
cont = data[0][2] != 0 or data[0][3] != 0
138+
self.assertTrue(cont ^ lang, 'Bad enncoding of entitie')
139+
self.assertTrue(cont == res[0][0] and lang == res[0][1], 'Bad enncoding of result')
140+
141+
lang = data[1][0] != 0 or data[1][1] != 0
142+
cont = data[1][2] != 0 or data[1][3] != 0
143+
self.assertTrue(cont ^ lang, 'Bad enncoding of entitie')
144+
self.assertTrue(cont == res[1][0] and lang ==
145+
res[1][1], 'Bad enncoding of result')
146+
147+
self.assertTrue(res[0][2] == 0 and res[1][2]
148+
== 0, 'Bad enncoding of result')
149+
self.assertTrue(res[0][3] == 0 and res[1][3]
150+
== 0, 'Bad enncoding of result')
151+
self.assertTrue(res[0][4] == 0 and res[1][4]
152+
== 0, 'Bad enncoding of result')
153+
154+
def test_sample_entities(self):
155+
for i in range(4):
156+
with self.subTest(f'Entities random sampling, iteration {i}'):
157+
data, res = self.dif.sample_entities(2)
158+
lang = data[0][0] != 0 or data[0][1] != 0
159+
cont = data[0][2] != 0 or data[0][3] != 0
160+
mill = data[0][4] != 0 or data[0][5] != 0
161+
self.assertTrue(cont ^ lang ^ mill, 'Bad enncoding of entitie')
162+
self.assertTrue(
163+
cont == res[0][0] and lang == res[0][1] and mill == res[0][2], 'Bad enncoding of result')
164+
165+
lang = data[1][0] != 0 or data[1][1] != 0
166+
cont = data[1][2] != 0 or data[1][3] != 0
167+
mill = data[1][4] != 0 or data[1][5] != 0
168+
self.assertTrue(cont ^ lang ^ mill, 'Bad enncoding of entitie')
169+
self.assertTrue(
170+
cont == res[1][0] and lang == res[1][1] and mill == res[1][2], 'Bad enncoding of result')
171+
172+
self.assertTrue(res[0][3] == 0 and res[1][3]
173+
== 0, 'Bad enncoding of result')
174+
self.assertTrue(res[0][4] == 0 and res[1][4]
175+
== 0, 'Bad enncoding of result')
176+
177+
def test_sample_relation_name(self):
178+
data, res = self.dif.sample_relation_name("speak_language", 2)
179+
datag = np.array([[1, 0, 1, 0, 0, 0],
180+
[0, 1, 0, 1, 0, 0]])
181+
resg = np.array([[1., 1., 0., 0., 1.],
182+
[1., 1., 0., 0., 1.]])
183+
self.assertLessEqual(np.abs(data-datag).flatten().sum(), 5e-16)
184+
self.assertLessEqual(np.abs(res-resg).flatten().sum(), 5e-16)
185+
186+
data, res = self.dif.sample_relation_name("has_millonarie", 2)
187+
datag = np.array([[1, 0, 1, 0, 0, 0],
188+
[0, 1, 0, 1, 0, 0]])
189+
resg = np.array([[1., 1., 0., 1., 0.],
190+
[1., 1., 0., 1., 0.]])
191+
self.assertLessEqual(np.abs(data-datag).flatten().sum(), 5e-16)
192+
self.assertLessEqual(np.abs(res - resg).flatten().sum(), 5e-16)
193+
194+
195+
196+
197+
if __name__ == '__main__':
198+
unittest.main()

test/dbpedia_train.model

123 KB
Binary file not shown.

0 commit comments

Comments
 (0)