Skip to content

Commit 270f4ba

Browse files
committed
Adjusted self-confirm functionality to always run in the background (required to work properly)
1 parent efd6fbb commit 270f4ba

File tree

2 files changed

+53
-65
lines changed

2 files changed

+53
-65
lines changed

openwakeword/model.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def predict(self, x: np.ndarray, patience: dict = {},
416416
else:
417417
return predictions
418418

419-
def self_confirm(self, last_n_seconds: float = 1.5, background=False):
419+
def self_confirm(self, last_n_seconds: float = 1.5, delay_time: float = 0.250):
420420
"""
421421
Use the confirmation model to confirm the predictions from the main model. This is a form of
422422
test-time augmentation that can significantly reduce false detections, but significantly increases
@@ -431,30 +431,38 @@ def self_confirm(self, last_n_seconds: float = 1.5, background=False):
431431
You are encouraged to experiment with the `last_n_seconds` argument to find the best balance
432432
between true-positive and false-positive detections for your use case.
433433
434+
This runs as a background task to avoid blocking the main model from processing audio, so the results
435+
of the confirmation model are stored in the `confirmation_results` class attribute once available.
436+
This is a dictionary with the same format as the output of the `predict` method, containing the
437+
maximum score from the confirmation model over the last `last_n_seconds` seconds of audio, giving a
438+
"confirmation" score for each model, indicating if a detection in the `last_n_seconds` seconds of audio
439+
was likely valid or not.
440+
434441
Args:
435442
last_n_seconds (float): The number of seconds of audio to use for confirmation.
436443
The default (1.5) should be sufficient for most use cases, but increase if your
437444
target wake-word/phrase is long, or decrease if short.
438-
background (bool): Whether to run the confirmation model in a background thread. If True, the results of
439-
the function will be returned asynchronously and stored in the
440-
`self.confirmation_results` attribute. Until the results are available, this attribute
441-
will be None.
445+
delay_time (float): The time (in seconds) to wait before running the confirmation model. This allows the
446+
main model to process enough audio after a detection to ensure that the confirmation
447+
model has enough audio context.
442448
Returns:
443-
dict: A dictionary of scores between 0 and 1 for each model, representing the maximum
444-
score from the confirmation model over the last `last_n_seconds` seconds of audio.
445-
If background=True, returns None and stores results in self.confirmation_results when ready.
449+
concurrent.futures.Future: A Future object representing the background thread task that runs the confirmation model.
446450
"""
447451
# Check for self-confirm functionality
448452
if self.self_confirm_enabled is False:
449453
raise ValueError("The self-confirm functionality is not enabled for this model instance!")
450454

451455
# Check for at least two cores
452456
cpu_count = os.cpu_count()
453-
if (cpu_count is None or cpu_count < 2) and background is True:
457+
if (cpu_count is None or cpu_count < 2):
454458
raise ValueError("The self-confirm functionality requires at least two CPU cores, as it uses threading.")
455459

456460
# Define the function to run predictions
457461
def _run_confirmation_predictions():
462+
# Wait to allow main model to process audio
463+
if delay_time > 0:
464+
time.sleep(delay_time)
465+
458466
# Get the last n seconds of audio from the audio buffer of the main model, and get the features
459467
# with the self-confirmation model preprocessor
460468
n_samples = int(last_n_seconds*16000)
@@ -480,15 +488,11 @@ def _run_confirmation_predictions():
480488
# Store results asynchronously
481489
self.confirmation_results = predictions_dict
482490

483-
# Run in background thread if requested
484-
if background:
485-
self.confirmation_results = None
486-
self.confirmation_executor.submit(_run_confirmation_predictions)
487-
return None
488-
else:
489-
# Run synchronously
490-
_run_confirmation_predictions()
491-
return self.confirmation_results
491+
# Submit confirmation prediction task to thread pool
492+
self.confirmation_results = None # reset previous results
493+
future = self.confirmation_executor.submit(_run_confirmation_predictions)
494+
495+
return future
492496

493497
def predict_clip(self, clip: Union[str, np.ndarray], padding: int = 1, chunk_size=1280, **kwargs):
494498
"""Predict on a full audio clip, simulating streaming prediction.

tests/test_self_confirm.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,14 @@ def test_self_confirm_basic_functionality(self):
4242
owwModel.predict(random_audio)
4343

4444
# Run the self-confirm function
45-
predictions_dict = owwModel.self_confirm(last_n_seconds=1.5)
45+
owwModel.self_confirm(last_n_seconds=1.5)
46+
47+
# Poll for results with a timeout (max 10 seconds)
48+
max_wait_time = 10
49+
start_time = time.time()
50+
while owwModel.confirmation_results is None and (time.time() - start_time) < max_wait_time:
51+
time.sleep(0.1)
52+
predictions_dict = owwModel.confirmation_results
4653

4754
# Verify predictions_dict is properly formed
4855
assert isinstance(predictions_dict, dict), "predictions_dict should be a dictionary"
@@ -76,7 +83,14 @@ def test_self_confirm_with_multiple_models(self):
7683
owwModel.predict(random_audio)
7784

7885
# Run self-confirm
79-
predictions_dict = owwModel.self_confirm(last_n_seconds=1.5)
86+
owwModel.self_confirm(last_n_seconds=1.5)
87+
88+
# Poll for results with a timeout (max 10 seconds)
89+
max_wait_time = 10
90+
start_time = time.time()
91+
while owwModel.confirmation_results is None and (time.time() - start_time) < max_wait_time:
92+
time.sleep(0.1)
93+
predictions_dict = owwModel.confirmation_results
8094

8195
# Verify all models have predictions
8296
assert len(predictions_dict) >= 2, "predictions_dict should have at least 2 models"
@@ -120,7 +134,8 @@ def test_self_confirm_insufficient_audio_data(self):
120134

121135
# Attempting to call self_confirm should raise ValueError
122136
with pytest.raises(ValueError, match="Not enough audio data"):
123-
owwModel.self_confirm(last_n_seconds=1.5)
137+
future = owwModel.self_confirm(last_n_seconds=1.5)
138+
future.result()
124139

125140
def test_self_confirm_with_tflite_models(self):
126141
"""Test self_confirm with tflite inference framework"""
@@ -139,7 +154,14 @@ def test_self_confirm_with_tflite_models(self):
139154
owwModel.predict(random_audio)
140155

141156
# Run self-confirm
142-
predictions_dict = owwModel.self_confirm(last_n_seconds=1.5)
157+
owwModel.self_confirm(last_n_seconds=1.5)
158+
159+
# Poll for results with a timeout (max 10 seconds)
160+
max_wait_time = 10
161+
start_time = time.time()
162+
while owwModel.confirmation_results is None and (time.time() - start_time) < max_wait_time:
163+
time.sleep(0.1)
164+
predictions_dict = owwModel.confirmation_results
143165

144166
# Verify predictions_dict is properly formed
145167
assert isinstance(predictions_dict, dict)
@@ -163,57 +185,19 @@ def test_self_confirm_multiclass_model(self):
163185
owwModel.predict(random_audio)
164186

165187
# Run self-confirm
166-
predictions_dict = owwModel.self_confirm(last_n_seconds=1.5)
167-
168-
# Verify predictions_dict is properly formed
169-
assert isinstance(predictions_dict, dict)
170-
assert len(predictions_dict) > 0, "predictions_dict should not be empty"
171-
172-
for model_name, score in predictions_dict.items():
173-
assert isinstance(score, (float, np.floating)), f"Score for {model_name} should be a float"
174-
assert 0 <= score <= 1, f"Score for {model_name} should be between 0 and 1, got {score}"
175-
176-
def test_self_confirm_background_true(self):
177-
"""Test self_confirm with background=True returns None and populates confirmation_results"""
178-
owwModel = openwakeword.Model(
179-
wakeword_models=[os.path.join("openwakeword", "resources", "models", "alexa_v0.1.onnx")],
180-
inference_framework="onnx",
181-
self_confirm=True
182-
)
188+
owwModel.self_confirm(last_n_seconds=1.5)
183189

184-
# Feed in ~10 seconds of random data to fill the audio buffer
185-
chunk_size = 1280
186-
n_samples = 160000
187-
188-
for i in range(0, n_samples, chunk_size):
189-
random_audio = np.random.randint(-1000, 1000, chunk_size).astype(np.int16)
190-
owwModel.predict(random_audio)
191-
192-
# Run self-confirm in background mode
193-
result = owwModel.self_confirm(last_n_seconds=1.5, background=True)
194-
195-
# When background=True, should return None immediately
196-
assert result is None, "self_confirm with background=True should return None"
197-
198-
# confirmation_results should eventually be populated
199190
# Poll for results with a timeout (max 10 seconds)
200191
max_wait_time = 10
201192
start_time = time.time()
202193
while owwModel.confirmation_results is None and (time.time() - start_time) < max_wait_time:
203194
time.sleep(0.1)
204-
205-
# Verify that confirmation_results has been populated
206-
assert owwModel.confirmation_results is not None, "confirmation_results should be populated after background execution"
207-
208-
# Verify confirmation_results is properly formed
209195
predictions_dict = owwModel.confirmation_results
210-
assert isinstance(predictions_dict, dict), "confirmation_results should be a dictionary"
211196

212-
expected_models = list(owwModel.models.keys())
213-
assert len(predictions_dict) == len(expected_models), f"confirmation_results should have {len(expected_models)} key(s)"
197+
# Verify predictions_dict is properly formed
198+
assert isinstance(predictions_dict, dict)
199+
assert len(predictions_dict) > 0, "predictions_dict should not be empty"
214200

215-
for model_name in expected_models:
216-
assert model_name in predictions_dict, f"confirmation_results should contain key '{model_name}'"
217-
score = predictions_dict[model_name]
201+
for model_name, score in predictions_dict.items():
218202
assert isinstance(score, (float, np.floating)), f"Score for {model_name} should be a float"
219203
assert 0 <= score <= 1, f"Score for {model_name} should be between 0 and 1, got {score}"

0 commit comments

Comments
 (0)