1313# limitations under the License.
1414
1515# Imports
16- import os
17- from tqdm import tqdm
1816import collections
19- import openwakeword
20- import numpy as np
21- import scipy
17+ import os
2218import pickle
19+ from typing import List , Union
2320
21+ import numpy as np
22+ import scipy
2423from sklearn .linear_model import LogisticRegression
2524from sklearn .pipeline import make_pipeline
2625from sklearn .preprocessing import FunctionTransformer , StandardScaler
26+ from tqdm import tqdm
27+
28+ import openwakeword
2729
2830
2931# Define functions to prepare data for speaker dependent verifier model
@@ -112,8 +114,8 @@ def train_verifier_model(features: np.ndarray, labels: np.ndarray):
112114
113115
114116def train_custom_verifier (
115- positive_reference_clips : str ,
116- negative_reference_clips : str ,
117+ positive_reference_clips : List [ Union [ str , os . PathLike ]] ,
118+ negative_reference_clips : List [ Union [ str , os . PathLike ]] ,
117119 output_path : str ,
118120 model_name : str ,
119121 ** kwargs
@@ -123,11 +125,11 @@ def train_custom_verifier(
123125 from a single user.
124126
125127 Args:
126- positive_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
128+ positive_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16 kHz, 16-bit WAV files
127129 of the target wake word/phrase.
128- negative_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
130+ negative_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16 kHz, 16-bit WAV files
129131 of miscellaneous speech not containing the target wake word/phrase.
130- output_path (str): The location to save the trained verifier model (as a scikit-learn .joblib file)
132+ output_path (str): The location to save the trained verifier model (as a Python pickle file, ".pkl")
131133 model_name (str): The name or path of the trained openWakeWord model that the verifier model will be
132134 based on. If only a name, it must be one of the pre-trained models included in the
133135 openWakeWord release.
@@ -171,4 +173,5 @@ def train_custom_verifier(
171173
172174 # Save logistic regression model to specified output location
173175 print ("Done!" )
174- pickle .dump (lr_model , open (output_path , "wb" ))
176+ with open (output_path , "wb" ) as f :
177+ pickle .dump (lr_model , f )
0 commit comments