Skip to content

Commit 9747085

Browse files
author
Rob
committed
FINAL
0 parents  commit 9747085

37 files changed

+707435
-0
lines changed

.gitignore

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
.ipynb_checkpoints/*
2+
.idea/*
3+
.DS_Store*
4+
Archive
5+
*.pyc
6+
.~lock.*
7+
lm/libri-timit-lm.arpa
8+
lm/libri-timit-lm.klm
9+
lm/timit-lm.arpa
10+
lm/timit-lm.klm
11+
*/.ipynb_checkpoints
12+
13+
kDS.sh
14+
kDS-LOCAL.sh
15+
16+
data/*
17+
!enron*
18+
!e00*

Android/README.md

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
## Tensorflow Guide
2+
3+
4+
1. Download ALL of tensorflow (you need to compile TF for your mobile)
5+
`git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git`
6+
7+
Important once downloaded change the branch to the one you are using for the model.
8+
e.g. `git checkout r1.1`
9+
10+
2. Download Android SDK (comes with studio) and NDK:
11+
Link
12+
Link
13+
14+
3. Install Bazel from here https://bazel.build/versions/master/docs/install.html
15+
16+
17+
4. Edit workspace file at the root of the tensorflow direction you grabbed from git
18+
19+
Uncomment the lines about SDK/NDK
20+
21+
android_sdk_repository
22+
android_ndk_repository
23+
24+
and fill in the correct paths to those libraries.
25+
26+
27+
5. Build .so files
28+
29+
`bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
30+
--crosstool_top=//external:android/crosstool \
31+
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
32+
--cpu=armeabi-v7a`
33+
34+
the file is bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so make a note of this
35+
36+
<!-- 6. Build java file -->
37+
38+
These need to be copied into the TENSORFLOW APP
39+
40+
41+
42+
43+
44+

Android/convert_keras_to_android.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# this file is based on https://github.com/amir-abdi/keras_to_tensorflow/
2+
3+
import os
4+
import os.path as osp
5+
6+
from keras import backend as K
7+
from keras.models import load_model
8+
9+
import tensorflow as tf
10+
from tensorflow.python.framework import graph_util
11+
from tensorflow.python.framework import graph_io
12+
13+
from utils import load_model_checkpoint, save_model
14+
15+
# SET PARAMS
16+
17+
input_fld = './checkpoints/trimmed/'
18+
model_file = 'TRIMMED_ds_model' #dont use extension e.g .json
19+
20+
21+
num_output = 1
22+
write_graph_def_ascii_flag = True
23+
prefix_output_node_names_of_final_network = 'output_node'
24+
output_graph_name = 'constant_graph_weights.pb'
25+
26+
## INIT
27+
28+
output_fld = "./Android/" + 'tensorflow_model/'
29+
if not os.path.isdir(output_fld):
30+
os.mkdir(output_fld)
31+
32+
33+
## LOAD KERAS MODEL AND RENAME OUTPUT
34+
35+
K.set_learning_phase(0)
36+
net_model = load_model_checkpoint(input_fld+model_file)
37+
38+
pred = [None]*num_output
39+
pred_node_names = [None]*num_output
40+
for i in range(num_output):
41+
pred_node_names[i] = prefix_output_node_names_of_final_network+str(i)
42+
pred[i] = tf.identity(net_model.output[i], name=pred_node_names[i])
43+
print('output nodes names are: ', pred_node_names)
44+
45+
46+
# WHY [optional] write graph definition in asci??
47+
48+
sess = K.get_session()
49+
if write_graph_def_ascii_flag:
50+
f = 'only_the_graph_def.pb.ascii'
51+
tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
52+
print('saved the graph definition in ascii format at: ', osp.join(output_fld, f))
53+
54+
# convert variables to constants and save
55+
56+
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
57+
graph_io.write_graph(constant_graph, output_fld, output_graph_name, as_text=False)
58+
print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
59+
60+
61+
##### safety check graph
62+
pbfile = osp.join(output_fld, output_graph_name)
63+
64+
g = tf.GraphDef()
65+
g.ParseFromString(open(pbfile, "rb").read())
66+
print([n for n in g.node if n.name.find("input") != -1]) # same for output or any other node you want to make sure is ok
67+
print([n for n in g.node if n.name.find("out") != -1]) # same for output or any other node you want to make sure is ok
68+
69+
#the_input
70+
#output_node0
71+
72+
73+
##weird ops might not be defualt
74+
ops = set([n.op for n in g.node])
75+
print(ops)
76+
77+
78+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
_
2+
a
3+
b
4+
c
5+
d
6+
e
7+
f
8+
g
9+
h
10+
i
11+
j
12+
k
13+
l
14+
m
15+
n
16+
o
17+
p
18+
q
19+
r
20+
s
21+
t
22+
u
23+
v
24+
w
25+
x
26+
y
27+
z
28+
'
Binary file not shown.

0 commit comments

Comments
 (0)