diff --git a/examples/custom_model.yml b/examples/custom_model.yml
index a41d327..3e7dae1 100644
--- a/examples/custom_model.yml
+++ b/examples/custom_model.yml
@@ -32,6 +32,20 @@ piper_sample_generator_path: "./piper-sample-generator"
 # Sub-directories will be automatically created for train and test clips for both positive and negative examples
 output_dir: "./my_custom_model"
 
+# The path to model used by piper_sample_generator
+generator_model: "./piper-sample-generator/models/en_US-libritts_r-medium.pt"
+
+# Min phoneme count for piper_sample_generator
+min_phoneme_count: null
+
+# Noise setting for training samples
+noise_scales_train:
+- 0.98
+
+# Noise setting for testing samples
+noise_scales_test:
+- 1.0
+
 # The directories containing Room Impulse Response recordings
 rir_paths:
 - "./mit_rirs"
diff --git a/openwakeword/train.py b/openwakeword/train.py
index f564254..8c291e8 100755
--- a/openwakeword/train.py
+++ b/openwakeword/train.py
@@ -669,7 +669,9 @@ def convert_onnx_to_tflite(onnx_model_path, output_path):
             generate_samples(
                 text=config["target_phrase"], max_samples=config["n_samples"]-n_current_samples, batch_size=config["tts_batch_size"],
-                noise_scales=[0.98], noise_scale_ws=[0.98], length_scales=[0.75, 1.0, 1.25],
+                model=config["generator_model"],
+                min_phoneme_count=config["min_phoneme_count"],
+                noise_scales=config["noise_scales_train"], noise_scale_ws=config["noise_scales_train"], length_scales=[0.75, 1.0, 1.25],
                 output_dir=positive_train_output_dir, auto_reduce_batch_size=True,
                 file_names=[uuid.uuid4().hex + ".wav" for i in range(config["n_samples"])]
             )
@@ -685,7 +687,9 @@ def convert_onnx_to_tflite(onnx_model_path, output_path):
         if n_current_samples <= 0.95*config["n_samples_val"]:
             generate_samples(text=config["target_phrase"], max_samples=config["n_samples_val"]-n_current_samples,
                              batch_size=config["tts_batch_size"],
-                             noise_scales=[1.0], noise_scale_ws=[1.0], length_scales=[0.75, 1.0, 1.25],
+                             model=config["generator_model"],
+                             min_phoneme_count=config["min_phoneme_count"],
+                             noise_scales=config["noise_scales_test"], noise_scale_ws=config["noise_scales_test"], length_scales=[0.75, 1.0, 1.25],
                              output_dir=positive_test_output_dir, auto_reduce_batch_size=True)
             torch.cuda.empty_cache()
         else:
@@ -706,7 +710,11 @@ def convert_onnx_to_tflite(onnx_model_path, output_path):
                                                              include_input_words=0.2))
             generate_samples(text=adversarial_texts, max_samples=config["n_samples"]-n_current_samples,
                              batch_size=config["tts_batch_size"]//7,
-                             noise_scales=[0.98], noise_scale_ws=[0.98], length_scales=[0.75, 1.0, 1.25],
+                             model=config["generator_model"],
+                             min_phoneme_count=config["min_phoneme_count"],
+                             noise_scales=config["noise_scales_train"],
+                             noise_scale_ws=config["noise_scales_train"],
+                             length_scales=[0.75, 1.0, 1.25],
                              output_dir=negative_train_output_dir, auto_reduce_batch_size=True,
                              file_names=[uuid.uuid4().hex + ".wav" for i in range(config["n_samples"])]
                              )
@@ -729,7 +737,11 @@ def convert_onnx_to_tflite(onnx_model_path, output_path):
                                                              include_input_words=0.2))
             generate_samples(text=adversarial_texts, max_samples=config["n_samples_val"]-n_current_samples,
                              batch_size=config["tts_batch_size"]//7,
-                             noise_scales=[1.0], noise_scale_ws=[1.0], length_scales=[0.75, 1.0, 1.25],
+                             model=config["generator_model"],
+                             min_phoneme_count=config["min_phoneme_count"],
+                             noise_scales=config["noise_scales_test"],
+                             noise_scale_ws=config["noise_scales_test"],
+                             length_scales=[0.75, 1.0, 1.25],
                              output_dir=negative_test_output_dir, auto_reduce_batch_size=True)
             torch.cuda.empty_cache()
         else: