
Support for auto-language-detection Whisper inference on HPU #1049

Open
Spycsh opened this issue Jun 7, 2024 · 1 comment


Spycsh commented Jun 7, 2024

Feature request

Current Whisper inference works well when the language is specified. However, it does not support passing language=None, which should detect the language automatically. Instead, a RuntimeError is raised:

Traceback (most recent call last):
  File "/home/optimum-habana/examples/speech-recognition/asr1.py", line 93, in <module>
    text = asr.audio2text("sample.wav")
  File "/home/optimum-habana/examples/speech-recognition/asr1.py", line 62, in audio2text
    predicted_ids = self.model.generate(inputs, language=self.language)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py", line 540, in generate
    init_tokens = self._retrieve_init_tokens(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py", line 1177, in _retrieve_init_tokens
    if torch.unique(lang_ids).shape[0] > 1:
  File "/usr/local/lib/python3.10/dist-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 991, in _return_output
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
  File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 905, in _unique_impl
    output, inverse_indices, counts = torch._unique2(
RuntimeError: Argument passed to at() was not in the map.
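
For context, the failing and the working call differ only in the language argument passed to generate() (distilled here using the same model and inputs as in the reproduction script below):

predicted_ids = model.generate(inputs, language="english")  # specified language: works on HPU
predicted_ids = model.generate(inputs, language=None)       # auto-detection: raises the RuntimeError above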

Motivation

Here is the code to reproduce the error. If you specify the language by changing AudioSpeechRecognition(language=None, device="hpu") to AudioSpeechRecognition(language="english", device="hpu"), it works fine.

import os
import time

import numpy as np
from datasets import Audio, Dataset
from pydub import AudioSegment
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor

class AudioSpeechRecognition:
    """Convert audio to text."""

    def __init__(self, model_name_or_path="openai/whisper-small", language=None, device="cpu"):
        if device=="hpu":
            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
            adapt_transformers_to_gaudi()
        self.device = device
        asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
        print("Downloading model: {}".format(asr_model_name_or_path))
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_model_name_or_path).to(self.device)
        self.processor = WhisperProcessor.from_pretrained(asr_model_name_or_path)
        self.model.eval()
        self.language = language

    def _audiosegment_to_librosawav(self, audiosegment):
        # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
        # This way is faster than librosa.load or HuggingFace Dataset wrapper
        channel_sounds = audiosegment.split_to_mono()[:1]  # only select the first channel
        samples = [s.get_array_of_samples() for s in channel_sounds]

        fp_arr = np.array(samples).T.astype(np.float32)
        fp_arr /= np.iinfo(samples[0].typecode).max
        fp_arr = fp_arr.reshape(-1)

        return fp_arr

    def audio2text(self, audio_path):
        """Convert audio to text.

        audio_path: the path to the input audio, e.g. ~/xxx.mp3
        """
        start = time.time()

        try:
            waveform = AudioSegment.from_file(audio_path).set_frame_rate(16000)
            waveform = self._audiosegment_to_librosawav(waveform)
        except Exception as e:
            print(f"[ASR] audiosegment to librosa wave fail: {e}")
            audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
            waveform = audio_dataset[0]["audio"]["array"]

        inputs = self.processor.feature_extractor(
            waveform, return_tensors="pt", sampling_rate=16_000
        ).input_features.to(self.device)

        predicted_ids = self.model.generate(inputs, language=self.language)

        result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]

        print(f"generated text in {time.time() - start} seconds, and the result is: {result}")
        return result

if __name__ == "__main__":
    asr = AudioSpeechRecognition(language=None, device="hpu")
    import urllib.request

    urllib.request.urlretrieve(
        "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
        "sample.wav",
    )
    urllib.request.urlretrieve(
        "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample_2.wav",
        "sample2.wav",
    )
    urllib.request.urlretrieve(
        "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/welcome.wav",
        "welcome.wav",
    )
    urllib.request.urlretrieve(
        "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/tacotron2_ljspeech_waveflow_samples_0.2/sentence_1.wav",
        "s1.wav",
    )
    urllib.request.urlretrieve(
        "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
        "labxiaoxin.wav",
    )
    text = asr.audio2text("sample.wav")
    text = asr.audio2text("sample2.wav")
    text = asr.audio2text("sample.wav")
    text = asr.audio2text("welcome.wav")
    text = asr.audio2text("s1.wav")
    text = asr.audio2text("labxiaoxin.wav")
    text = asr.audio2text("s1.wav")
    text = asr.audio2text("labxiaoxin.wav")
    os.remove("sample.wav")
    print(text)

Your contribution

Please let me know if you have any plan to support this feature! If you do not, I can help make a PR. My approach would probably be to insert an explicit HPU synchronization torch.hpu.synchronize() before the torch.unique call in /usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py, which I have tested to work, but I'm not sure whether it is the proper way.
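
For reference, here is a rough sketch of where that synchronization could go, based on the traceback above; it only illustrates the suggested workaround around the torch.unique call in _retrieve_init_tokens and is not the actual upstream code:

# Illustrative sketch only: inside transformers/models/whisper/generation_whisper.py,
# _retrieve_init_tokens(), just before the torch.unique call shown in the traceback.
if hasattr(torch, "hpu"):      # only when the Habana PyTorch backend is loaded
    torch.hpu.synchronize()    # flush pending HPU ops before torch.unique reads lang_ids
if torch.unique(lang_ids).shape[0] > 1:
    ...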


regisss commented Jul 10, 2024

Do you encounter the same issue with Transformers on GPU?
