# export.py
"""Exports YAMNet as: TF2 SavedModel, TF-Lite model, TF-JS model.

The exported models all accept as input:
- a 1-d float32 Tensor of arbitrary length containing an audio waveform
  (assumed to be mono 16 kHz samples in the [-1, +1] range)
and return as output:
- a 2-d float32 Tensor of shape [num_frames, num_classes] containing
  predicted class scores for each frame of audio extracted from the input.
- a 2-d float32 Tensor of shape [num_frames, embedding_size] containing
  embeddings of each frame of audio.
- a 2-d float32 Tensor of shape [num_spectrogram_frames, num_mel_bins]
  containing the log mel spectrogram of the entire waveform.

The SavedModels will also contain (as an asset) a class map CSV file that maps
class indices to AudioSet class names and Freebase MIDs. The path to the class
map is returned by the 'class_map_path()' method of the restored model.

Requires pip-installing tensorflow_hub and tensorflowjs.

Usage:
  export.py <path/to/YAMNet/weights-hdf-file> <path/to/output/directory>
and the various exports will be created in subdirectories of the output
directory.

Assumes that it will be run in the yamnet source directory, from where it loads
the class map. Skips an export if the corresponding directory already exists.
"""
import os
import sys
import tempfile
import time
import numpy as np
import tensorflow as tf
assert tf.version.VERSION >= '2.0.0', (
'Need at least TF 2.0, you have TF v{}'.format(tf.version.VERSION))
import tensorflow_hub as tfhub
from tensorflowjs.converters import tf_saved_model_conversion_v2 as tfjs_saved_model_converter
import params as yamnet_params
import yamnet
def log(msg):
    """Prints `msg` as a timestamped banner and flushes stdout right away."""
    banner = '\n=====\n{} | {}\n=====\n'.format(time.asctime(), msg)
    print(banner, flush=True)
class YAMNet(tf.Module):
    """A TF2 Module wrapper around YAMNet.

    Builds the YAMNet frames model, loads its weights, and tracks the class map
    CSV as a SavedModel asset so that it is exported together with the model.
    """

    def __init__(self, weights_path, params):
        """Builds the wrapped model and loads its weights.

        Args:
          weights_path: Path handed to `load_weights` on the wrapped model.
          params: Parameters handed to `yamnet.yamnet_frames_model`.
        """
        super().__init__()
        self._yamnet = yamnet.yamnet_frames_model(params)
        self._yamnet.load_weights(weights_path)
        # Relative path: the class map is looked up in the current working
        # directory, so the script must be run from the yamnet source directory.
        self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv')

    @tf.function
    def class_map_path(self):
        """Returns the path of the class map CSV asset."""
        return self._class_map_asset.asset_path

    @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
    def __call__(self, waveform):
        """Runs the wrapped model on a 1-d float32 waveform of any length.

        Returns:
          Whatever the wrapped model returns; callers in this file unpack it as
          (predictions, embeddings, log_mel_spectrogram).
        """
        return self._yamnet(waveform)
def check_model(model_fn, class_map_path, params):
    """Applies yamnet_test's sanity checks to an instance of YAMNet.

    Feeds three synthetic 3-second clips (silence, uniform white noise and a
    440 Hz sine wave) through `model_fn` and checks that the expected class is
    among the top-scoring predictions for each clip.

    Args:
      model_fn: Callable that takes a 1-d float32 waveform and returns a
        3-tuple (predictions, embeddings, log_mel_spectrogram).
      class_map_path: Class map location, passed to `yamnet.class_names`.
      params: Object providing `sample_rate`.

    Raises:
      AssertionError: If an expected class is missing from the top predictions.
    """
    # NOTE(review): indexed below with an index array, so this is presumably a
    # NumPy array of class names -- confirm against yamnet.class_names.
    yamnet_classes = yamnet.class_names(class_map_path)

    def clip_test(waveform, expected_class_name, top_n=10):
        """Asserts that `expected_class_name` is among the clip's top-N classes."""
        # Only the scores are checked; unpacking all three values still verifies
        # that the model returns exactly three outputs.
        predictions, _, _ = model_fn(waveform)
        # Average the per-frame scores into one score vector for the whole clip.
        clip_predictions = np.mean(predictions, axis=0)
        top_n_indices = np.argsort(clip_predictions)[-top_n:]
        top_n_scores = clip_predictions[top_n_indices]
        top_n_class_names = yamnet_classes[top_n_indices]
        top_n_predictions = list(zip(top_n_class_names, top_n_scores))
        assert expected_class_name in top_n_class_names, (
            'Did not find expected class {} in top {} predictions: {}'.format(
                expected_class_name, top_n, top_n_predictions))

    clip_test(
        waveform=np.zeros((int(3 * params.sample_rate),), dtype=np.float32),
        expected_class_name='Silence')
    np.random.seed(51773)  # Ensure repeatability.
    clip_test(
        waveform=np.random.uniform(-1.0, +1.0,
                                   (int(3 * params.sample_rate),)).astype(np.float32),
        expected_class_name='White noise')
    clip_test(
        waveform=np.sin(2 * np.pi * 440 *
                        np.arange(0, 3, 1 / params.sample_rate), dtype=np.float32),
        expected_class_name='Sine wave')
def make_tf2_export(weights_path, export_dir):
    """Exports YAMNet as a TF2 SavedModel and sanity-checks the result.

    The export is checked three times: on the freshly built module, on the
    SavedModel reloaded through TF-Hub in TF2, and on the SavedModel reloaded
    through TF-Hub inside a TF1 graph and session. Nothing is done (apart from
    logging) when `export_dir` already exists.

    Args:
      weights_path: Path to the YAMNet weights file.
      export_dir: Directory in which to write the SavedModel.
    """
    if os.path.exists(export_dir):
        skip_message = 'TF2 export already exists in {}, skipping TF2 export'.format(
            export_dir)
        log(skip_message)
        return

    # Wrap YAMNet in a TF2 Module and verify it before exporting anything.
    log('Building and checking TF2 Module ...')
    params = yamnet_params.Params()
    wrapped = YAMNet(weights_path, params)
    check_model(wrapped, wrapped.class_map_path(), params)
    log('Done')

    # Write the SavedModel.
    log('Making TF2 SavedModel export ...')
    tf.saved_model.save(wrapped, export_dir)
    log('Done')

    # Reload through TF-Hub in TF2 mode and verify.
    log('Checking TF2 SavedModel export in TF2 ...')
    tf2_model = tfhub.load(export_dir)
    check_model(tf2_model, tf2_model.class_map_path(), params)
    log('Done')

    # Reload through TF-Hub inside a TF1 graph and session, and verify.
    log('Checking TF2 SavedModel export in TF1 ...')
    with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
        tf1_model = tfhub.load(export_dir)
        sess.run(tf.compat.v1.global_variables_initializer())
        check_model(
            lambda waveform: sess.run(tf1_model(waveform)),
            tf1_model.class_map_path().eval(),
            params)
    log('Done')
def make_tflite_export(weights_path, export_dir):
    """Exports YAMNet as a TF-Lite-compatible SavedModel and a .tflite model.

    Builds YAMNet with TF-Lite-compatible parameters, saves it as a SavedModel
    under `<export_dir>/saved_model`, converts that SavedModel to
    `<export_dir>/yamnet.tflite`, and sanity-checks every stage. The export is
    skipped when `export_dir` already exists.

    Args:
      weights_path: Path to the YAMNet weights file.
      export_dir: Directory in which to write the TF-Lite export.

    Returns:
      Path of the TF-Lite-compatible SavedModel directory. This is returned
      even when the export is skipped, so that a later TF-JS export can still
      be built from a SavedModel produced by an earlier run.
    """
    saved_model_dir = os.path.join(export_dir, 'saved_model')
    if os.path.exists(export_dir):
        log('TF-Lite export already exists in {}, skipping TF-Lite export'.format(
            export_dir))
        # Returning None here would be passed on by the caller as the
        # SavedModel directory for the TF-JS conversion.
        return saved_model_dir

    # Create a TF-Lite compatible Module wrapper around YAMNet.
    log('Building and checking TF-Lite Module ...')
    params = yamnet_params.Params(tflite_compatible=True)
    # Named so that it does not shadow the imported `yamnet` module.
    yamnet_module = YAMNet(weights_path, params)
    check_model(yamnet_module, yamnet_module.class_map_path(), params)
    log('Done')

    # Make TF-Lite SavedModel export.
    log('Making TF-Lite SavedModel export ...')
    os.makedirs(saved_model_dir)
    tf.saved_model.save(yamnet_module, saved_model_dir)
    log('Done')

    # Check that the export can be loaded and works.
    log('Checking TF-Lite SavedModel export in TF2 ...')
    model = tf.saved_model.load(saved_model_dir)
    check_model(model, model.class_map_path(), params)
    log('Done')

    # Make a TF-Lite model from the SavedModel.
    log('Making TF-Lite model ...')
    tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = tflite_converter.convert()
    tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)
    log('Done')

    # Check the TF-Lite export.
    log('Checking TF-Lite model ...')
    interpreter = tf.lite.Interpreter(tflite_model_path)
    audio_input_index = interpreter.get_input_details()[0]['index']
    output_details = interpreter.get_output_details()
    scores_output_index = output_details[0]['index']
    embeddings_output_index = output_details[1]['index']
    spectrogram_output_index = output_details[2]['index']

    def run_model(waveform):
        """Runs the TF-Lite interpreter on one waveform."""
        # The input tensor is resized to this waveform's length, and tensors
        # are re-allocated, before every invocation.
        interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
        interpreter.allocate_tensors()
        interpreter.set_tensor(audio_input_index, waveform)
        interpreter.invoke()
        return (interpreter.get_tensor(scores_output_index),
                interpreter.get_tensor(embeddings_output_index),
                interpreter.get_tensor(spectrogram_output_index))

    # This check reads the class map from the working directory, not from the
    # exported model.
    check_model(run_model, 'yamnet_class_map.csv', params)
    log('Done')

    return saved_model_dir
def make_tfjs_export(tflite_saved_model_dir, export_dir):
    """Converts the TF-Lite-compatible SavedModel into a TF-JS model.

    Only logs a message when `export_dir` already exists.

    Args:
      tflite_saved_model_dir: Directory of the SavedModel to convert.
      export_dir: Directory in which to write the TF-JS model.
    """
    if not os.path.exists(export_dir):
        # Build the TF-JS model from the TF-Lite SavedModel export.
        log('Making TF-JS model ...')
        os.makedirs(export_dir)
        tfjs_saved_model_converter.convert_tf_saved_model(
            tflite_saved_model_dir, export_dir)
        log('Done')
    else:
        skip_message = 'TF-JS export already exists in {}, skipping TF-JS export'.format(
            export_dir)
        log(skip_message)
def main(args):
    """Runs the TF2, TF-Lite and TF-JS exports in turn.

    Args:
      args: Command-line arguments without the program name. The first is the
        path to the YAMNet weights file and the second is the output directory;
        any further arguments are ignored.

    Raises:
      SystemExit: With a usage message if fewer than two arguments are given.
    """
    if len(args) < 2:
        # Fail with a readable usage message instead of a bare IndexError.
        sys.exit('Usage: export.py <path/to/YAMNet/weights-hdf-file> '
                 '<path/to/output/directory>')
    weights_path = args[0]
    output_dir = args[1]

    tf2_export_dir = os.path.join(output_dir, 'tf2')
    make_tf2_export(weights_path, tf2_export_dir)

    tflite_export_dir = os.path.join(output_dir, 'tflite')
    tflite_saved_model_dir = make_tflite_export(weights_path, tflite_export_dir)

    # The TF-JS model is converted from the TF-Lite-compatible SavedModel.
    tfjs_export_dir = os.path.join(output_dir, 'tfjs')
    make_tfjs_export(tflite_saved_model_dir, tfjs_export_dir)


if __name__ == '__main__':
    main(sys.argv[1:])