@@ -39,7 +39,7 @@ def __init__(
3939 class_mapping_dicts : List [dict ] = [],
4040 enable_speex_noise_suppression : bool = False ,
4141 vad_threshold : float = 0 ,
42- custom_verifier_models : Union [ bool , dict ] = False ,
42+ custom_verifier_models : dict = {} ,
4343 custom_verifier_threshold : float = 0.1 ,
4444 ** kwargs
4545 ):
@@ -112,6 +112,14 @@ def __init__(
112112 if custom_verifier_models .get (mdl_name , False ):
113113 self .custom_verifier_models [mdl_name ] = pickle .load (open (custom_verifier_models [mdl_name ], 'rb' ))
114114
115+ if len (self .custom_verifier_models .keys ()) < len (custom_verifier_models .keys ()):
116+ raise ValueError (
117+ "Custom verifier models were provided, but some were not matched with a base model!"
118+ " Make sure that the keys provided in the `custom_verifier_models` dictionary argument"
119+ " exactly match that of the `.models` attribute of an instantiated openWakeWord Model object"
120+ " that has the same base models but doesn't have custom verifier models."
121+ )
122+
115123 # Create buffer to store frame predictions
116124 self .prediction_buffer : DefaultDict [str , deque ] = defaultdict (partial (deque , maxlen = 30 ))
117125
@@ -208,10 +216,11 @@ def predict(self, x: np.ndarray, patience: dict = {}, threshold: dict = {}, timi
208216 for cls in predictions .keys ():
209217 if predictions [cls ] >= self .custom_verifier_threshold :
210218 parent_model = self .get_parent_model_from_label (cls )
211- verifier_prediction = self .custom_verifier_models [parent_model ].predict_proba (
212- self .preprocessor .get_features (self .model_inputs [mdl ])
213- )[0 ][- 1 ]
214- predictions [cls ] = verifier_prediction
219+ if self .custom_verifier_models .get (parent_model , False ):
220+ verifier_prediction = self .custom_verifier_models [parent_model ].predict_proba (
221+ self .preprocessor .get_features (self .model_inputs [mdl ])
222+ )[0 ][- 1 ]
223+ predictions [cls ] = verifier_prediction
215224
216225 # Update prediction buffer, and zero predictions for first 5 frames during model initialization
217226 for cls in predictions .keys ():
0 commit comments