feat(audio): integrate audio transformers
g-prz committed Dec 3, 2024
1 parent 36f1bf2 commit 3bedff8
Showing 11 changed files with 490 additions and 7 deletions.
91 changes: 91 additions & 0 deletions outlines/generate/api.py
@@ -621,3 +621,94 @@ def valid_types(prompts, media)
)

return prompts, media


class AudioSequenceGeneratorAdapter(SequenceGeneratorAdapter):
def __call__( # type: ignore
self,
prompts: Union[str, List[str]],
media: Union[str, Any],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
**model_specific_params,
):
"""
Generate text from a prompt or list of prompts.
        Media: decoded audio waveforms (numpy arrays), one list per prompt, passed to the AutoProcessor.
"""
prompts, media = self._validate_prompt_media_types(prompts, media)

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)

completions = self.model.generate(
prompts,
media,
generation_params,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)

return self._format(completions)

def stream( # type: ignore
self,
prompts: Union[str, List[str]],
media: List[Union[str, Any, List[Union[str, Any]]]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
**model_specific_params,
):
"""Return a text generator from a prompt or a list of prompts."""
prompts, media = self._validate_prompt_media_types(prompts, media)
generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
return self.model.stream(
prompts,
media,
generation_params,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)

@classmethod
def _validate_prompt_media_types(
cls,
prompts: Union[str, List[str]],
media: Union[str, Any, List[Union[str, Any]]],
) -> Union[Any, List[Any]]:
"""
        Prepare media as np.ndarray and ensure that for every prompt str there is one List[np.ndarray]
"""

def valid_types(prompts, media):
import numpy as np # type: ignore

if isinstance(prompts, list):
if not isinstance(media, list) or len(prompts) != len(media):
return False
for subprompt, submedia in zip(prompts, media):
if not isinstance(subprompt, str) or not all(
isinstance(m, np.ndarray) for m in submedia
):
return False
elif isinstance(prompts, str):
if not all(isinstance(m, np.ndarray) for m in media):
return False
return True

if not valid_types(prompts, media):
            raise TypeError(
                "Expected (prompts, media) to be of type "
                "(str, List[np.ndarray]), or (List[str], List[List[np.ndarray]]) "
                f"instead got prompts={prompts}, media={media}"
            )

return prompts, media
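
For reference, a minimal sketch (not part of this commit) of the `(prompts, media)` pairings the new validator accepts; the prompt text and the audio placeholder token are illustrative only:

```python
import numpy as np

from outlines.generate.api import AudioSequenceGeneratorAdapter

waveform = np.zeros(16000, dtype=np.float32)  # stand-in for one decoded audio clip

# A single prompt string paired with a list of audio arrays.
AudioSequenceGeneratorAdapter._validate_prompt_media_types(
    "<|AUDIO|> Describe this clip.", [waveform]
)

# A batch of prompts paired with one list of audio arrays per prompt.
AudioSequenceGeneratorAdapter._validate_prompt_media_types(
    ["<|AUDIO|> Clip one.", "<|AUDIO|> Clip two."], [[waveform], [waveform]]
)
```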
11 changes: 10 additions & 1 deletion outlines/generate/cfg.py
@@ -1,10 +1,11 @@
from functools import singledispatch

from outlines.generate.api import (
AudioSequenceGeneratorAdapter,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import LlamaCpp, OpenAI, TransformersVision
from outlines.models import LlamaCpp, OpenAI, TransformersAudio, TransformersVision
from outlines.samplers import Sampler, multinomial


@@ -33,6 +34,14 @@ def cfg(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@cfg.register(TransformersAudio)
def cfg_audio(model, cfg_str: str, sampler: Sampler = multinomial()):
from outlines.processors import CFGLogitsProcessor

logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
return AudioSequenceGeneratorAdapter(model, logits_processor, sampler)


@cfg.register(TransformersVision)
def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
from outlines.processors import CFGLogitsProcessor
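
A hedged usage sketch for the new `cfg_audio` path, assuming `model` is a `TransformersAudio` instance and `audio` is a decoded numpy waveform; `grammars.arithmetic` refers to one of the Lark grammars bundled with Outlines:

```python
from outlines import generate, grammars

# Constrain the generated answer to the bundled arithmetic grammar.
generator = generate.cfg(model, grammars.arithmetic)
expression = generator("<|AUDIO|> Transcribe the spoken arithmetic expression.", [audio])
```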
12 changes: 11 additions & 1 deletion outlines/generate/fsm.py
@@ -4,10 +4,11 @@

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import (
AudioSequenceGeneratorAdapter,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import TransformersVision
from outlines.models import TransformersAudio, TransformersVision
from outlines.samplers import Sampler, multinomial


@@ -22,6 +23,15 @@ def fsm(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@fsm.register(TransformersAudio)
def fsm_audio(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
from outlines.processors import GuideLogitsProcessor

guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide)
return AudioSequenceGeneratorAdapter(model, logits_processor, sampler)


@fsm.register(TransformersVision)
def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
from outlines.processors import GuideLogitsProcessor
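
A hedged sketch of the `fsm_audio` path, with the FSM built via `interegular`; `model` and `audio` are assumed as above:

```python
import interegular

from outlines import generate

# Constrain the output to a small set of labels encoded as an FSM.
label_fsm = interegular.parse_pattern(r"cat|dog|bird").to_fsm()
generator = generate.fsm(model, label_fsm)
label = generator("<|AUDIO|> Which animal do you hear?", [audio])
```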
15 changes: 14 additions & 1 deletion outlines/generate/regex.py
@@ -1,10 +1,11 @@
from functools import singledispatch

from outlines.generate.api import (
AudioSequenceGeneratorAdapter,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import OpenAI, TransformersVision
from outlines.models import OpenAI, TransformersAudio, TransformersVision
from outlines.samplers import Sampler, multinomial


@@ -35,6 +36,18 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersAudio)
def regex_audio(
model,
regex_str: str,
sampler: Sampler = multinomial(),
):
from outlines.processors import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return AudioSequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersVision)
def regex_vision(
model,
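
Likewise, a hedged sketch of the `regex_audio` path, again assuming a `TransformersAudio` `model` and a numpy `audio` waveform:

```python
from outlines import generate

# Constrain the answer to a yes/no pattern.
generator = generate.regex(model, r"yes|no")
answer = generator("<|AUDIO|> Does the clip contain speech?", [audio])
```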
8 changes: 7 additions & 1 deletion outlines/generate/text.py
@@ -1,10 +1,11 @@
from functools import singledispatch

from outlines.generate.api import (
AudioSequenceGeneratorAdapter,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import OpenAI, TransformersVision
from outlines.models import OpenAI, TransformersAudio, TransformersVision
from outlines.samplers import Sampler, multinomial


@@ -34,6 +35,11 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(TransformersAudio)
def text_audio(model, sampler: Sampler = multinomial()):
return AudioSequenceGeneratorAdapter(model, None, sampler)


@text.register(TransformersVision)
def text_vision(model, sampler: Sampler = multinomial()):
return VisionSequenceGeneratorAdapter(model, None, sampler)
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -13,6 +13,7 @@
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, TransformerTokenizer, mamba, transformers
from .transformers_audio import TransformersAudio, transformers_audio
from .transformers_vision import TransformersVision, transformers_vision
from .vllm import VLLM, vllm

136 changes: 136 additions & 0 deletions outlines/models/transformers_audio.py
@@ -0,0 +1,136 @@
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models import Transformers

if TYPE_CHECKING:
from outlines.processors import OutlinesLogitsProcessor


class TransformersAudio(Transformers):
def __init__(self, model, tokenizer, processor):
super().__init__(model, tokenizer)
self.processor = processor

def generate( # type: ignore
self,
prompts: Union[str, List[str]],
media: Union[List[Any], List[List[Any]]],
generation_parameters: GenerationParameters,
logits_processor: Optional["OutlinesLogitsProcessor"],
sampling_parameters: SamplingParameters,
) -> Union[str, List[str], List[List[str]]]:
"""Generate text using `transformers`.
Arguments
---------
prompts
A prompt or list of prompts.
media
A List[numpy.ndarray] or List[List[numpy.ndarray]]
generation_parameters
An instance of `GenerationParameters` that contains the prompt,
the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
logits_processor
The logits processor to use when generating text.
sampling_parameters
An instance of `SamplingParameters`, a dataclass that contains
the name of the sampler to use and related parameters as available
in Outlines.
Returns
-------
The generated text
"""
inputs = self.processor(
text=prompts, audios=media, padding=True, return_tensors="pt"
).to(self.model.device)

generation_kwargs = self._get_generation_kwargs(
prompts,
generation_parameters,
logits_processor,
sampling_parameters,
)
generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)

# if single str input and single sample per input, convert to a 1D output
if isinstance(prompts, str):
generated_ids = generated_ids.squeeze(0)

return self._decode_generation(generated_ids)

def stream( # type: ignore
self,
prompts: Union[str, List[str]],
media: Union[Any, List[Any]], # TODO: docstring
generation_parameters: GenerationParameters,
logits_processor: Optional["OutlinesLogitsProcessor"],
sampling_parameters: SamplingParameters,
) -> Iterator[Union[str, List[str]]]:
raise NotImplementedError


def transformers_audio(
model_name: str,
model_class,
device: Optional[str] = None,
model_kwargs: dict = {},
processor_kwargs: dict = {},
tokenizer_class=None,
processor_class=None,
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Parameters
----------
model_name
The name of the model as listed on Hugging Face's model page.
model_class
        The `PreTrainedModel` class from transformers to use in initializing the audio model from `model_name`.
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel
device
The device(s) on which the model should be loaded. This overrides
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
processor_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the processor.
Returns
-------
    A `TransformersAudio` model instance.
"""
if processor_class is None or tokenizer_class is None:
try:
from transformers import AutoProcessor, AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if processor_class is None:
processor_class = AutoProcessor

if device is not None:
model_kwargs["device_map"] = device

model = model_class.from_pretrained(model_name, **model_kwargs)

processor_kwargs.setdefault("padding_side", "left")
processor_kwargs.setdefault("pad_token", "[PAD]")
processor = processor_class.from_pretrained(model_name, **processor_kwargs)

if tokenizer_class is None:
if getattr(processor, "tokenizer", None):
tokenizer = processor.tokenizer
else:
tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs)
else:
tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs)

return TransformersAudio(model, tokenizer, processor)
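
An end-to-end usage sketch (not part of the commit), assuming a Qwen2-Audio checkpoint and the `librosa` dependency added to the test extras below; the audio file path and prompt format are illustrative:

```python
import librosa
import outlines
from transformers import Qwen2AudioForConditionalGeneration

# Load the audio-capable model, its processor and tokenizer in one call.
model = outlines.models.transformers_audio(
    "Qwen/Qwen2-Audio-7B-Instruct",
    model_class=Qwen2AudioForConditionalGeneration,
    device="cuda",
)

# Decode a local file to the 16 kHz numpy waveform the processor expects.
audio, _ = librosa.load("clip.wav", sr=16000)

generator = outlines.generate.text(model)
description = generator(
    "<|audio_bos|><|AUDIO|><|audio_eos|> Describe the audio.", [audio]
)
```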
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -71,7 +71,8 @@ test = [
"transformers",
"pillow",
"exllamav2",
"jax"
"jax",
"librosa",
]
serve = [
"vllm>=0.3.0",
@@ -147,6 +148,7 @@ module = [
"pycountry.*",
"airportsdata.*",
"outlines_core.*",
"librosa",
]
ignore_missing_imports = true
