|
| 1 | +""" |
| 2 | +This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model. |
| 3 | +""" |
| 4 | + |
| 5 | +import torch |
| 6 | +import librosa |
| 7 | +import numpy as np |
| 8 | +import argparse |
| 9 | +from transformers import WavLMForSequenceClassification |
| 10 | + |
| 11 | + |
def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Split an audio signal into fixed-length windows for WavLM.

    Parameters
    ----------
    wav : str, path-like or array-like
        Path to an audio file (loaded through librosa with ``sr``), or the
        samples themselves. Array-like input is converted with NumPy and
        squeezed; it is not resampled.
    sr : int, optional
        Sample rate in Hz, by default 16_000.
    win_len : float, optional
        Window length in seconds, by default 15.0.
    win_stride : float, optional
        Window stride in seconds, by default 15.0.
    do_normalize : bool, optional
        Whether to normalize each window to zero mean and unit variance,
        by default False.

    Returns
    -------
    np.ndarray
        Batched input to WavLM, stacked as ``[N, T]``. For a signal longer
        than ``win_len`` seconds, ``T == int(win_len * sr)`` and the last
        window is zero-padded; a shorter signal is returned as a single
        unpadded window.

    Raises
    ------
    RuntimeError
        If array-like input cannot be converted to a NumPy array.
    """
    import os

    if isinstance(wav, (str, os.PathLike)):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            # squeeze() turns a single-sample input into a 0-d array, which
            # has no len(); atleast_1d restores a sized 1-D signal.
            signal = np.atleast_1d(np.array(wav).squeeze())
        except Exception as e:
            raise RuntimeError(
                "could not convert `wav` to a NumPy array"
            ) from e

    def _normalize(samples):
        # The small epsilon guards against division by zero on silence.
        return (samples - np.mean(samples)) / (np.std(samples) + 1e-7)

    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    num_samples = len(signal)

    if num_samples / sr <= win_len:
        # Short signal: a single window, intentionally left unpadded.
        window = _normalize(signal) if do_normalize else signal
        return np.stack([window])  # [1, T]

    batched_input = []
    for start in range(0, num_samples, stride):
        is_last = start + win_samples > num_samples
        chunked = signal[start : start + win_samples]
        if is_last:
            # Pad the last chunk so it has the same length as the others.
            chunked = np.pad(chunked, (0, win_samples - len(chunked)))
        if do_normalize:
            chunked = _normalize(chunked)
        batched_input.append(chunked)
        if is_last:
            # Any later start offsets would only yield shorter tails.
            break
    return np.stack(batched_input)  # [N, T]
| 65 | + |
| 66 | + |
def infer(model, inputs):
    """Run the model on a batch and return element-wise sigmoid scores.

    Parameters
    ----------
    model : callable
        Called as ``model(inputs)``; the result must expose a ``logits``
        attribute.
    inputs : torch.Tensor
        Batched input passed to ``model`` unchanged.

    Returns
    -------
    torch.Tensor
        ``sigmoid(logits)``, with the same shape as the logits and detached
        from the autograd graph.
    """
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(inputs)
        # as_tensor passes an existing tensor through without copying,
        # unlike the legacy torch.Tensor(...) constructor used previously.
        probs = torch.sigmoid(torch.as_tensor(output.logits))
    return probs
| 71 | + |
| 72 | + |
if __name__ == "__main__":
    # Command-line demo: load one audio file, split it into windows, run the
    # classifier and print the per-label scores of every window.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,  # without a file there is nothing to run inference on
        help="File to run inference",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="checkpoint file of model",
    )
    args = parser.parse_args()
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    # Model is trained on only 16kHz audio
    sample_rate = 16000
    audio, _ = librosa.core.load(args.audio_file, sr=sample_rate)
    input_np = feature_extract_simple(audio, sr=sample_rate)
    input_pt = torch.as_tensor(input_np, dtype=torch.float32)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    # One row per window, one column per label; derive the width from the
    # label list so the two cannot drift apart.
    probs = probs.reshape(-1, len(labels_name_list)).detach().tolist()
    print(f"Probabilities for {args.audio_file} is:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")
0 commit comments