Commit 6d96480

Optimized self-confirm and added another argument
1 parent 35d9fe3 commit 6d96480

3 files changed: 109 additions, 26 deletions

openwakeword/model.py

Lines changed: 57 additions & 24 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # Imports
+from concurrent.futures import ThreadPoolExecutor
 import numpy as np
 import openwakeword
 from openwakeword.utils import AudioFeatures, re_arg
@@ -42,6 +43,7 @@ def __init__(
             enable_speex_noise_suppression: bool = False,
             vad_threshold: float = 0,
             self_confirm: bool = False,
+            self_confirm_ncpus: int = 1,
             custom_verifier_models: dict = {},
             custom_verifier_threshold: float = 0.1,
             inference_framework: str = "tflite",
@@ -71,6 +73,7 @@ def __init__(
                 augmentation that can significantly reduce false detections, but also significantly increases
                 the computational cost of running the model when used. See the `self_confirm` method for more
                 details on how to leverage this functionality.
+            self_confirm_ncpus (int): The number of CPU cores to use when running the self-confirmation model.
             custom_verifier_models (dict): A dictionary of paths to custom verifier models, where
                 the keys are the model names (corresponding to the openwakeword.MODELS
                 attribute) and the values are the filepaths of the
@@ -222,8 +225,19 @@ def onnx_predict(onnx_model, x):
                 class_mapping_dicts=class_mapping_dicts,
                 self_confirm=False,
                 inference_framework=inference_framework,
-                **kwargs
+                ncpu=self_confirm_ncpus
             )
+            self.confirmation_results = None
+            self.self_confirm_ncpus = self_confirm_ncpus
+
+            # Create thread pool for self_confirm calling
+            self.confirmation_executor = ThreadPoolExecutor(max_workers=1)
+
+            # Force thread pool initialization by submitting a dummy task
+            # This avoids the first-call overhead later
+            def _noop():
+                pass
+            self.confirmation_executor.submit(_noop).result()
 
         # Create AudioFeatures object
         self.preprocessor = AudioFeatures(inference_framework=inference_framework, **kwargs)
@@ -244,6 +258,7 @@ def reset(self):
        when called too frequently."""
        self.prediction_buffer = defaultdict(partial(deque, maxlen=30))
        self.preprocessor.reset()
+       self.confirmation_results = None
 
    def predict(self, x: np.ndarray, patience: dict = {},
                threshold: dict = {}, debounce_time: float = 0.0, timing: bool = False):
@@ -401,7 +416,7 @@ def predict(self, x: np.ndarray, patience: dict = {},
         else:
             return predictions
 
-    def self_confirm(self, last_n_seconds: float = 1.5):
+    def self_confirm(self, last_n_seconds: float = 1.5, background=False):
         """
         Use the confirmation model to confirm the predictions from the main model. This is a form of
         test-time augmentation that can significantly reduce false detections, but significantly increases
@@ -420,42 +435,60 @@ def self_confirm(self, last_n_seconds: float = 1.5):
             last_n_seconds (float): The number of seconds of audio to use for confirmation.
                 The default (1.5) should be sufficient for most use cases, but increase if your
                 target wake-word/phrase is long, or decrease if short.
+            background (bool): Whether to run the confirmation model in a background thread. If True, the results of
+                the function will be returned asynchronously and stored in the
+                `self.confirmation_results` attribute. Until the results are available, this attribute
+                will be None.
         Returns:
             dict: A dictionary of scores between 0 and 1 for each model, representing the maximum
                 score from the confirmation model over the last `last_n_seconds` seconds of audio.
+                If background=True, returns None and stores results in self.confirmation_results when ready.
         """
         # Check for self-confirm functionality
         if self.self_confirm_enabled is False:
             raise ValueError("The self-confirm functionality is not enabled for this model instance!")
 
         # Check for at least two cores
         cpu_count = os.cpu_count()
-        if cpu_count is None or cpu_count < 2:
+        if (cpu_count is None or cpu_count < 2) and background is True:
             raise ValueError("The self-confirm functionality requires at least two CPU cores, as it uses threading.")
 
-        # Get the last n seconds of audio from the audio buffer of the main model, and get the features
-        # with the self-confirmation model preprocessor
-        n_samples = int(last_n_seconds*16000)
-        if len(self.preprocessor.raw_data_buffer) < n_samples:
-            raise ValueError("Not enough audio data has been processed to use the self-confirm functionality!")
-        audio_data = np.array(self.preprocessor.raw_data_buffer)[-n_samples:]
-
-        # Reset the self-confirmation model, if it has been used before
-        if self.confirmation_model.preprocessor.accumulated_samples == 0:
-            self.confirmation_model.reset()
-
-        # Run model to get predictions
-        step_size = 1280
-        predictions = []
-        for i in range(0, audio_data.shape[0]-step_size, step_size):
-            predictions.append(self.confirmation_model.predict(audio_data[i:i+step_size]))
+        # Define the function to run predictions
+        def _run_confirmation_predictions():
+            # Get the last n seconds of audio from the audio buffer of the main model, and get the features
+            # with the self-confirmation model preprocessor
+            n_samples = int(last_n_seconds*16000)
+            if len(self.preprocessor.raw_data_buffer) < n_samples:
+                raise ValueError("Not enough audio data has been processed to use the self-confirm functionality!")
+            audio_data = np.fromiter(self.preprocessor.raw_data_buffer, dtype=np.int16)[-n_samples:]
 
-        predictions_dict = {}
-        for mdl in predictions[0].keys():
-            predictions_per_model = [p[mdl] for p in predictions]
-            predictions_dict[mdl] = np.max(predictions_per_model)
+            # Reset the self-confirmation model, if it has been used before
+            if self.confirmation_model.preprocessor.accumulated_samples == 0:
+                self.confirmation_model.reset()
 
-        return predictions_dict
+            # Run model to get predictions
+            step_size = 1280
+            predictions = []
+            for i in range(0, audio_data.shape[0]-step_size, step_size):
+                predictions.append(self.confirmation_model.predict(audio_data[i:i+step_size]))
+
+            predictions_dict = {}
+            for mdl in predictions[0].keys():
+                predictions_per_model = [p[mdl] for p in predictions]
+                predictions_dict[mdl] = np.max(predictions_per_model)
+
+            # Store results asynchronously
+            self.confirmation_results = predictions_dict
+
+        # Run in background thread if requested
+        if background:
+            self.confirmation_results = None
+            self.confirmation_executor.submit(_run_confirmation_predictions)
+            return None
+        else:
+            # Run synchronously
+            _run_confirmation_predictions()
+            return self.confirmation_results
 
     def predict_clip(self, clip: Union[str, np.ndarray], padding: int = 1, chunk_size=1280, **kwargs):
         """Predict on an full audio clip, simulating streaming prediction.

openwakeword/utils.py

Lines changed: 6 additions & 2 deletions
@@ -166,7 +166,9 @@ def tflite_embedding_predict(x):
         self.melspectrogram_max_len = 10*97  # 97 is the number of frames in 1 second of 16hz audio
         self.accumulated_samples = 0  # the samples added to the buffer since the audio preprocessor was last called
         self.raw_data_remainder = np.empty(0)
-        self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16))
+        # self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16))
+        self.feature_buffer = np.load(os.path.join(pathlib.Path(__file__).parent.resolve(),
+                                                   "resources", "models", "feature_buffer_reset_data.npy"))
         self.feature_buffer_max_len = 120  # ~10 seconds of feature buffer history
 
     def reset(self):
@@ -175,7 +177,9 @@ def reset(self):
         self.melspectrogram_buffer = np.ones((76, 32))
         self.accumulated_samples = 0
         self.raw_data_remainder = np.empty(0)
-        self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16))
+        # self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16))
+        self.feature_buffer = np.load(os.path.join(pathlib.Path(__file__).parent.resolve(),
+                                                   "resources", "models", "feature_buffer_reset_data.npy"))
 
     def _get_melspectrogram(self, x: Union[np.ndarray, List], melspec_transform: Callable = lambda x: x/10 + 2):
         """

tests/test_self_confirm.py

Lines changed: 46 additions & 0 deletions
@@ -18,6 +18,7 @@
 import os
 import numpy as np
 import pytest
+import time
 
 
 # Tests
@@ -171,3 +172,48 @@ def test_self_confirm_multiclass_model(self):
         for model_name, score in predictions_dict.items():
             assert isinstance(score, (float, np.floating)), f"Score for {model_name} should be a float"
             assert 0 <= score <= 1, f"Score for {model_name} should be between 0 and 1, got {score}"
+
+    def test_self_confirm_background_true(self):
+        """Test self_confirm with background=True returns None and populates confirmation_results"""
+        owwModel = openwakeword.Model(
+            wakeword_models=[os.path.join("openwakeword", "resources", "models", "alexa_v0.1.onnx")],
+            inference_framework="onnx",
+            self_confirm=True
+        )
+
+        # Feed in ~10 seconds of random data to fill the audio buffer
+        chunk_size = 1280
+        n_samples = 160000
+
+        for i in range(0, n_samples, chunk_size):
+            random_audio = np.random.randint(-1000, 1000, chunk_size).astype(np.int16)
+            owwModel.predict(random_audio)
+
+        # Run self-confirm in background mode
+        result = owwModel.self_confirm(last_n_seconds=1.5, background=True)
+
+        # When background=True, should return None immediately
+        assert result is None, "self_confirm with background=True should return None"
+
+        # confirmation_results should eventually be populated
+        # Poll for results with a timeout (max 10 seconds)
+        max_wait_time = 10
+        start_time = time.time()
+        while owwModel.confirmation_results is None and (time.time() - start_time) < max_wait_time:
+            time.sleep(0.1)
+
+        # Verify that confirmation_results has been populated
+        assert owwModel.confirmation_results is not None, "confirmation_results should be populated after background execution"
+
+        # Verify confirmation_results is properly formed
+        predictions_dict = owwModel.confirmation_results
+        assert isinstance(predictions_dict, dict), "confirmation_results should be a dictionary"
+
+        expected_models = list(owwModel.models.keys())
+        assert len(predictions_dict) == len(expected_models), f"confirmation_results should have {len(expected_models)} key(s)"
+
+        for model_name in expected_models:
+            assert model_name in predictions_dict, f"confirmation_results should contain key '{model_name}'"
+            score = predictions_dict[model_name]
+            assert isinstance(score, (float, np.floating)), f"Score for {model_name} should be a float"
+            assert 0 <= score <= 1, f"Score for {model_name} should be between 0 and 1, got {score}"
