Skip to content

Commit e1647dc

Browse files
committed
Refactor of SNN
1 parent 67905f3 commit e1647dc

14 files changed

+755
-25
lines changed

docker-compose.yml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ services:
55
image: matcomuh/ml:cpu
66
volumes:
77
- ./:/ml
8+
- ./snn:/usr/lib/python3/dist-packages/snn
89
ports:
910
- '8888:8888'
1011
container_name: ml

notebooks/sample-snn.ipynb

+439
Large diffs are not rendered by default.

notebooks/test-snn.ipynb

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"Using TensorFlow backend.\n"
13+
]
14+
}
15+
],
16+
"source": [
17+
"from keras.layers import Input\n",
18+
"from keras.models import Model"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": 2,
24+
"metadata": {},
25+
"outputs": [],
26+
"source": [
27+
"from snn.base_model import SNN\n",
28+
"from snn import Entity, Relation"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 3,
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"Thing = Entity(\"Thing\")\n",
38+
"Person = Entity(\"Person\", Thing)\n",
39+
"Object = Entity(\"Object\", Thing)\n",
40+
"\n",
41+
"hold = Relation(\"hold\", Person, Object)\n",
42+
"look = Relation(\"look\", Person, Object)"
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": 4,
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"snn = SNN(entities=[Thing, Person, Object], relations=[hold, look])"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": 5,
57+
"metadata": {},
58+
"outputs": [],
59+
"source": [
60+
"x = Input(shape=(100,))\n",
61+
"y = snn(x, compiled=True)"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 7,
67+
"metadata": {},
68+
"outputs": [
69+
{
70+
"name": "stdout",
71+
"output_type": "stream",
72+
"text": [
73+
"__________________________________________________________________________________________________\n",
74+
"Layer (type) Output Shape Param # Connected to \n",
75+
"==================================================================================================\n",
76+
"input_1 (InputLayer) (None, 100) 0 \n",
77+
"__________________________________________________________________________________________________\n",
78+
"Object (Entitie) (None, 32) 4288 input_1[0][0] \n",
79+
"__________________________________________________________________________________________________\n",
80+
"Person (Entitie) (None, 32) 4288 input_1[0][0] \n",
81+
"__________________________________________________________________________________________________\n",
82+
"max_Person_Object (Maximum) (None, 32) 0 Person[0][0] \n",
83+
" Object[0][0] \n",
84+
"__________________________________________________________________________________________________\n",
85+
"Person_Object_isa_Thing (Dense) (None, 32) 1056 max_Person_Object[0][0] \n",
86+
"__________________________________________________________________________________________________\n",
87+
"Thing (Entitie) (None, 32) 2112 Person_Object_isa_Thing[0][0] \n",
88+
"__________________________________________________________________________________________________\n",
89+
"hold (Relation) (None, 32) 2112 Person[0][0] \n",
90+
" Object[0][0] \n",
91+
"__________________________________________________________________________________________________\n",
92+
"look (Relation) (None, 32) 2112 Person[0][0] \n",
93+
" Object[0][0] \n",
94+
"__________________________________________________________________________________________________\n",
95+
"out_embeding (Concatenate) (None, 160) 0 Object[0][0] \n",
96+
" Person[0][0] \n",
97+
" Thing[0][0] \n",
98+
" hold[0][0] \n",
99+
" look[0][0] \n",
100+
"==================================================================================================\n",
101+
"Total params: 15,968\n",
102+
"Trainable params: 15,968\n",
103+
"Non-trainable params: 0\n",
104+
"__________________________________________________________________________________________________\n"
105+
]
106+
}
107+
],
108+
"source": [
109+
"y.summary()"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": 9,
115+
"metadata": {},
116+
"outputs": [
117+
{
118+
"data": {
119+
"text/plain": [
120+
"(None, 160)"
121+
]
122+
},
123+
"execution_count": 9,
124+
"metadata": {},
125+
"output_type": "execute_result"
126+
}
127+
],
128+
"source": [
129+
"y.output_shape"
130+
]
131+
}
132+
],
133+
"metadata": {
134+
"kernelspec": {
135+
"display_name": "Python 3",
136+
"language": "python",
137+
"name": "python3"
138+
},
139+
"language_info": {
140+
"codemirror_mode": {
141+
"name": "ipython",
142+
"version": 3
143+
},
144+
"file_extension": ".py",
145+
"mimetype": "text/x-python",
146+
"name": "python",
147+
"nbconvert_exporter": "python",
148+
"pygments_lexer": "ipython3",
149+
"version": "3.5.2"
150+
}
151+
},
152+
"nbformat": 4,
153+
"nbformat_minor": 2
154+
}
File renamed without changes.

snn/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .base import SNN, Entity, Relation

snn/base.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from keras.models import Model
2+
from keras.layers import Input, Dense, maximum, concatenate
3+
from keras.utils.vis_utils import model_to_dot
4+
5+
from .utils import EntityLayer, RelationLayer
6+
7+
8+
class SNN:
9+
def __init__(self, entities, relations, entity_shape=32, relation_shape=64):
10+
self.entities = entities
11+
self.relations = relations
12+
self.entity_shape = entity_shape
13+
self.relation_shape = relation_shape
14+
15+
def __call__(self, x):
16+
self.input_ = x
17+
self.entity_capsules_ = self._build_entities(x)
18+
self.relation_capsules_ = self._build_relations()
19+
self.outputs_, self.indicators_, self.representations_ = self._build_outputs()
20+
return self.representations_
21+
22+
def build(self):
23+
return Model(inputs=self.input_, outputs=self.indicators_)
24+
25+
def _build_entities(self, x):
26+
entity_capsules = {}
27+
28+
for e in toposort(self.entities):
29+
entity_capsules[e] = self._build_entity_capsule(x, e, entity_capsules)
30+
31+
return entity_capsules
32+
33+
def _build_entity_capsule(self, x, entity, capsules):
34+
children = [capsules[c] for c in entity.children]
35+
36+
if children:
37+
inputs = maximum(children, name="Max-%s" % entity.name)
38+
else:
39+
inputs = x
40+
41+
return EntityLayer(self.entity_shape, name=entity.name)(inputs)
42+
43+
def _build_relations(self):
44+
relation_capsules = {}
45+
46+
for r in self.relations:
47+
relation_capsules[r] = self._build_relation_capsule(r)
48+
49+
return relation_capsules
50+
51+
def _build_relation_capsule(self, relation):
52+
src = self.entity_capsules_[relation.src]
53+
dst = self.entity_capsules_[relation.dst]
54+
55+
return RelationLayer(self.relation_shape, name=relation.label)([src, dst])
56+
57+
def _build_outputs(self):
58+
outputs = {}
59+
outputs_indicators = []
60+
outputs_concat = []
61+
62+
for e in self.entities:
63+
outputs[e] = Dense(units=1, activation='sigmoid', name="Indicator-%s" % e.name)(self.entity_capsules_[e])
64+
outputs_indicators.append(outputs[e])
65+
outputs_concat.append(self.entity_capsules_[e])
66+
67+
for r in self.relations:
68+
outputs[r] = Dense(units=1, activation='sigmoid', name="Indicator-%s" % r.label)(self.relation_capsules_[r])
69+
outputs_indicators.append(outputs[r])
70+
outputs_concat.append(self.relation_capsules_[r])
71+
72+
indicators = concatenate(outputs_indicators, name="Indicators")
73+
concat = concatenate(outputs_concat, name="Representations")
74+
75+
return outputs, indicators, concat
76+
77+
78+
def toposort(entities):
79+
visited = set()
80+
path = []
81+
82+
def visit(e):
83+
if e in visited:
84+
return
85+
86+
for c in e.children:
87+
visit(c)
88+
89+
visited.add(e)
90+
path.append(e)
91+
92+
for e in entities:
93+
visit(e)
94+
95+
return path
96+
97+
98+
class Entity:
99+
def __init__(self, name, *parents):
100+
self.name = name
101+
self.parents = parents
102+
self.children = []
103+
104+
for p in self.parents:
105+
p.children.append(self)
106+
107+
def __repr__(self):
108+
return "<%s>" % self.name
109+
110+
class Relation:
111+
def __init__(self, label, src:Entity, dst:Entity):
112+
self.label = label
113+
self.src = src
114+
self.dst = dst
115+
116+
def __repr__(self):
117+
return "%s(%s , %s)" % (self.label, self.src, self.dst)

src/base_model.py renamed to snn/base_model.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,21 @@
55
# from .utils import tnorm, tnorm_output_shape, isLayer, isModel
66
# except ImportError:
77
# from utils import tnorm, tnorm_output_shape, isLayer, isModel
8-
from utils import tnorm, tnorm_output_shape, isLayer, isModel
8+
from .utils import tnorm, tnorm_output_shape, isLayer, isModel, tnorm_loss, bin_acc
9+
from .utils import Entitie as EntityLayer
10+
from .utils import Relation as RelationLayer
11+
from .base import Entity, Relation
912

13+
from keras.models import Model
1014
from keras.layers import Lambda, Maximum, Dense, Concatenate
1115

1216

13-
class SNN(metaclass=abc.ABCMeta):
17+
Entities = List[Entity]
18+
Relations = List[Relation]
19+
1420

15-
def __init__(self, entities: List[str],
16-
relations: Dict[str, List[Tuple[str]]], isar: List[Tuple[str]]):
21+
class SNN(metaclass=abc.ABCMeta):
22+
def __init__(self, entities: Entities, relations:Relations):
1723
""" redes semanticas
1824
entities: lista de strings
1925
relations: diccionario de: llave relacion, valor lista de tuplas de
@@ -24,6 +30,10 @@ def __init__(self, entities: List[str],
2430
isar: lista de tuplas de pares de entidades relacionadas por relaciones
2531
parecidas a "is a" de la forma (hijo, padre)
2632
"""
33+
isar = [(e.name, p.name) for e in entities for p in e.parents]
34+
entities = [e.name for e in entities]
35+
relations = {r.label: [(r.src.name, r.dst.name)] for r in relations}
36+
2737
eet = set(entities)
2838
# check for duplicate entities
2939
assert len(eet) == len(entities), "Exist duplicated entities"
@@ -154,7 +164,7 @@ def _build_model(self, inputt):
154164
rels[rel+str(n)] = self.relation_capsule(
155165
rel+'_'+str(n), [ents[e1], ents[e2]])
156166

157-
def __call__(self, inputt, train=False, compilee=False):
167+
def __call__(self, inputt, train=False, compiled=False):
158168
self._build_model(inputt)
159169
ents = self.ents
160170
rels = self.rels
@@ -167,12 +177,12 @@ def __call__(self, inputt, train=False, compilee=False):
167177
out.append(Lambda(tnorm,
168178
output_shape=tnorm_output_shape, name=i+'-out')(rels[i]))
169179
outt = Concatenate(name='out_embeding')(out)
170-
if compilee:
180+
if compiled:
171181
return self.wrap_compile(inputt, outt)
172182
return outt
173183
outt = Concatenate(name='out_embeding')([ents[i] for i in self.entities] +
174184
[rels[j] for j in sorted(rels.keys())])
175-
if compilee:
185+
if compiled:
176186
return self.wrap_compile(inputt, outt)
177187
return outt
178188

@@ -183,33 +193,32 @@ def pretrain(self):
183193
def wrap_compile(self, inn, outt):
184194
return self._compile(inn, outt)
185195

186-
@abc.abstractmethod
187196
def _compile(self, inn, outt):
188-
raise NotImplementedError()
197+
model = Model(inputs=inn, outputs=outt)
198+
model.compile(optimizer='RMSprop',
199+
loss=tnorm_loss, metrics=[bin_acc])
200+
return model
189201

190-
@abc.abstractmethod
191202
def entitie_capsule(self, name: str, inputt):
192203
"""
193204
name: name of the entitie
194205
imput: is a keras symbolic tensor or input layer or a class that
195206
implement keras.Layer interface
196207
"""
197-
raise NotImplementedError()
208+
return EntityLayer(32, name=name)(inputt)
198209

199-
@abc.abstractmethod
200210
def relation_capsule(self, name: str, inputs: List):
201211
"""
202212
name: name of the entitie
203213
imputs: is a list of keras symbolic tensors or input layers or a
204214
classes that implement keras.Layer interface
205215
"""
206-
raise NotImplementedError()
216+
return RelationLayer(32, name=name)(inputs)
207217

208-
@abc.abstractmethod
209218
def isa_capsule(self, name: str, inputt):
210219
"""
211220
name: name of the entitie
212221
imput: is a keras symbolic tensor or input layer or a class that
213222
implement keras.Layer interface
214223
"""
215-
raise NotImplementedError()
224+
return Dense(32, name=name)(inputt)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)