Skip to content

Commit 8322a96

Browse files
committed
Fixed bug with custom verifier model loading/prediction and incremented versioning accordingly
1 parent a0311f2 commit 8322a96

File tree

4 files changed

+25
-8
lines changed

4 files changed

+25
-8
lines changed

openwakeword/model.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
class_mapping_dicts: List[dict] = [],
4040
enable_speex_noise_suppression: bool = False,
4141
vad_threshold: float = 0,
42-
custom_verifier_models: Union[bool, dict] = False,
42+
custom_verifier_models: dict = {},
4343
custom_verifier_threshold: float = 0.1,
4444
**kwargs
4545
):
@@ -112,6 +112,14 @@ def __init__(
112112
if custom_verifier_models.get(mdl_name, False):
113113
self.custom_verifier_models[mdl_name] = pickle.load(open(custom_verifier_models[mdl_name], 'rb'))
114114

115+
if len(self.custom_verifier_models.keys()) < len(custom_verifier_models.keys()):
116+
raise ValueError(
117+
"Custom verifier models were provided, but some were not matched with a base model!"
118+
" Make sure that the keys provided in the `custom_verifier_models` dictionary argument"
119+
" exactly match that of the `.models` attribute of an instantiated openWakeWord Model object"
120+
" that has the same base models but doesn't have custom verifier models."
121+
)
122+
115123
# Create buffer to store frame predictions
116124
self.prediction_buffer: DefaultDict[str, deque] = defaultdict(partial(deque, maxlen=30))
117125

@@ -208,10 +216,11 @@ def predict(self, x: np.ndarray, patience: dict = {}, threshold: dict = {}, timi
208216
for cls in predictions.keys():
209217
if predictions[cls] >= self.custom_verifier_threshold:
210218
parent_model = self.get_parent_model_from_label(cls)
211-
verifier_prediction = self.custom_verifier_models[parent_model].predict_proba(
212-
self.preprocessor.get_features(self.model_inputs[mdl])
213-
)[0][-1]
214-
predictions[cls] = verifier_prediction
219+
if self.custom_verifier_models.get(parent_model, False):
220+
verifier_prediction = self.custom_verifier_models[parent_model].predict_proba(
221+
self.preprocessor.get_features(self.model_inputs[mdl])
222+
)[0][-1]
223+
predictions[cls] = verifier_prediction
215224

216225
# Update prediction buffer, and zero predictions for first 5 frames during model initialization
217226
for cls in predictions.keys():

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ testpaths = [
1212

1313
[project]
1414
name = "openwakeword"
15-
version = "0.3.0"
15+
version = "0.3.1"
1616
authors = [
1717
{ name="David Scripka", email="[email protected]" },
1818
]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def build_additional_requires():
2626

2727
setuptools.setup(
2828
name="openwakeword",
29-
version="0.3.0",
29+
version="0.3.1",
3030
install_requires=['onnxruntime>=1.10.0,<2', 'tqdm>=4.0,<5.0', 'scipy>=1.3,<2', 'scikit-learn>=1,<2'],
3131
extras_require={
3232
'test': [

tests/test_custom_verifier_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,15 @@ def test_train_verifier_model(self):
7272
model_name=os.path.join("openwakeword", "resources", "models", "hey_mycroft_v0.1.onnx")
7373
)
7474

75-
# Load model with verifier model
75+
with pytest.raises(ValueError):
76+
# Load model with verifier model incorrectly to catch ValueError
77+
owwModel = openwakeword.Model(
78+
wakeword_model_paths=[os.path.join("openwakeword", "resources", "models", "hey_mycroft_v0.1.onnx")],
79+
custom_verifier_models={"bad_key": os.path.join(tmp_dir, "verifier_model.pkl")},
80+
custom_verifier_threshold=0.3,
81+
)
82+
83+
# Load model with verifier model incorrectly to catch ValueError
7684
owwModel = openwakeword.Model(
7785
wakeword_model_paths=[os.path.join("openwakeword", "resources", "models", "hey_mycroft_v0.1.onnx")],
7886
custom_verifier_models={"hey_mycroft_v0.1": os.path.join(tmp_dir, "verifier_model.pkl")},

0 commit comments

Comments
 (0)