Skip to content

Commit

Permalink
Add support for calling mistral for titling (no public API yet) (#1168)
Browse files Browse the repository at this point in the history
Add support for calling mistral for titling, but without adding public
API yet.

- Add `generate_title_mistral` and `generate_category_mistral` titling
functions.
- change the internal `cluster_impl` so we can override the cluster
titling and category titling functions to use those
- At a later PR when we are happy with quality, we can switch to those
by default.
  • Loading branch information
dsmilkov authored Feb 8, 2024
1 parent 537a7ca commit 6cb4f72
Show file tree
Hide file tree
Showing 8 changed files with 472 additions and 264 deletions.
402 changes: 402 additions & 0 deletions lilac/data/cluster_titling.py

Large diffs are not rendered by default.

261 changes: 17 additions & 244 deletions lilac/data/clustering.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
"""Clustering utilities."""
import functools
import gc
import itertools
import random
from typing import Any, Callable, Iterator, Optional, Union, cast
from typing import Callable, Iterator, Optional, Union, cast

import instructor
import modal
import numpy as np
from instructor.exceptions import IncompleteOutputException
from joblib import Parallel, delayed
from pydantic import (
BaseModel,
)
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from tqdm import tqdm

from ..batch_utils import compress_docs, flatten_path_iter, group_by_sorted_key_iter
from ..batch_utils import compress_docs, flatten_path_iter
from ..dataset_format import DatasetFormatInputSelector
from ..embeddings.jina import JinaV2Small
from ..schema import (
Expand All @@ -36,22 +27,17 @@
)
from ..tasks import TaskId, TaskInfo, get_task_manager
from ..utils import DebugTimer, chunks, log
from .cluster_titling import (
compute_titles,
generate_category_openai,
generate_title_openai,
)
from .dataset import Dataset
from .dataset_utils import (
get_callable_name,
sparse_to_dense_compute,
)

_SHORTEN_LEN = 400
_TOP_K_CENTRAL_DOCS = 7
_TOP_K_CENTRAL_TITLES = 20
_NUM_THREADS = 32
_NUM_RETRIES = 16
# OpenAI rate limits you on `max_tokens` so we ideally want to guess the right value. If ChatGPT
# fails to generate a title within the `max_tokens` limit, we will retry with a higher value.
_INITIAL_MAX_TOKENS = 50
_FINAL_MAX_TOKENS = 200

CLUSTER_ID = 'cluster_id'
CLUSTER_MEMBERSHIP_PROB = 'cluster_membership_prob'
CLUSTER_TITLE = 'cluster_title'
Expand All @@ -70,237 +56,22 @@
BATCH_SOFT_CLUSTER_NOISE = 1024


@functools.cache
def _openai_client() -> Any:
"""Get an OpenAI client."""
try:
import openai

except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
'Please install it with `pip install openai`.'
)

# OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10
# mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set
# `max_retries` to 0 to disable internal retries so we handle retries ourselves.
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0))


def _snippet_to_prefix_and_suffix(text: str) -> str:
text = text.strip()
if len(text) <= _SHORTEN_LEN:
return text
prefix_len = _SHORTEN_LEN // 2
return text[:prefix_len] + '\n...\n' + text[-prefix_len:]


class Title(BaseModel):
"""A 4-5 word title for the group of related snippets."""

title: str


def summarize_request(ranked_docs: list[tuple[str, float]]) -> str:
"""Summarize a group of requests in a title of at most 5 words."""
# Get the top 5 documents.
docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_DOCS]]
texts = [f'BEGIN_SNIPPET\n{_snippet_to_prefix_and_suffix(doc)}\nEND_SNIPPET' for doc in docs]
input = '\n'.join(texts)
try:
import openai

except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
'Please install it with `pip install openai`.'
)

@retry(
retry=retry_if_exception_type(
(
openai.RateLimitError,
openai.APITimeoutError,
openai.APIConnectionError,
openai.ConflictError,
openai.InternalServerError,
)
),
wait=wait_random_exponential(multiplier=0.5, max=60),
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
max_tokens = _INITIAL_MAX_TOKENS
while max_tokens <= _FINAL_MAX_TOKENS:
try:
title = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
response_model=Title,
temperature=0.0,
max_tokens=max_tokens,
messages=[
{
'role': 'system',
'content': (
'You are a world-class short title generator. Ignore the related snippets below '
'and generate a short title to describe their common theme. Some examples: '
'"YA book reviews", "Questions about South East Asia", "Translating English to '
'Polish", "Writing product descriptions", etc. Use descriptive words. If the '
"snippet's language is different than English, mention it in the title, e.g. "
'"Cooking questions in Spanish". Avoid vague words like "various", "assortment", '
'"comments", "discussion", etc.'
),
},
{'role': 'user', 'content': input},
],
)
return title.title
except IncompleteOutputException:
max_tokens += _INITIAL_MAX_TOKENS
log(f'Retrying with max_tokens={max_tokens}')
log(f'Could not generate a short title for input:\n{input}')
# We return a string instead of None, since None is emitted when the text column is sparse.
return 'FAILED_TO_TITLE'

return request_with_retries()


class Category(BaseModel):
"""A short category title."""

category: str


def generate_category(ranked_docs: list[tuple[str, float]]) -> str:
"""Summarize a list of titles in a category."""
# Get the top 5 documents.
docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_TITLES]]
input = '\n'.join(docs)
try:
import openai

except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
'Please install it with `pip install openai`.'
)

@retry(
retry=retry_if_exception_type(
(
openai.RateLimitError,
openai.APITimeoutError,
openai.APIConnectionError,
openai.ConflictError,
openai.InternalServerError,
)
),
wait=wait_random_exponential(multiplier=0.5, max=60),
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
max_tokens = _INITIAL_MAX_TOKENS
while max_tokens <= _FINAL_MAX_TOKENS:
try:
category = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
response_model=Category,
temperature=0.0,
max_tokens=max_tokens,
messages=[
{
'role': 'system',
'content': (
'You are a world-class category labeler. Generate a short category name for the '
'provided titles. For example, given two titles "translating english to polish" '
'and "translating korean to english", generate "Translation".'
),
},
{'role': 'user', 'content': input},
],
)
return category.category
except IncompleteOutputException:
max_tokens += _INITIAL_MAX_TOKENS
log(f'Retrying with max_tokens={max_tokens}')
log(f'Could not generate a short category for input:\n{input}')
return 'FAILED_TO_GENERATE'

return request_with_retries()


def _compute_titles(
items: Iterator[Item],
text_column: str,
cluster_id_column: str,
membership_column: str,
topic_fn: TopicFn,
task_info: Optional[TaskInfo] = None,
) -> Iterator[str]:
def _compute_title(
sorted_docs: list[tuple[str, float]], group_size: int
) -> Optional[tuple[int, Optional[str]]]:
if not sorted_docs:
return group_size, None
return group_size, topic_fn(sorted_docs)

def _delayed_compute_all_titles() -> Iterator:
for group in group_by_sorted_key_iter(items, lambda x: x[cluster_id_column]):
sorted_docs: list[tuple[str, float]] = []

for item in group:
if not item:
continue

cluster_id = item.get(cluster_id_column, -1)
if cluster_id < 0:
continue

text = item.get(text_column)
if not text:
continue

membership_prob = item.get(membership_column, 0)
if membership_prob == 0:
continue

sorted_docs.append((text, membership_prob))

# Remove any duplicate texts in the group.
sorted_docs = list(set(sorted_docs))

# Shuffle the group to avoid biasing the topic function.
random.shuffle(sorted_docs)

# Sort the group by membership probability after shuffling so that we still choose high
# membership scores but they are still shuffled when the values are equal.
sorted_docs.sort(key=lambda text_score: text_score[1], reverse=True)

yield delayed(_compute_title)(sorted_docs, len(group))

parallel = Parallel(n_jobs=_NUM_THREADS, backend='threading', return_as='generator')
if task_info:
task_info.total_progress = 0
for group_size, title in parallel(_delayed_compute_all_titles()):
if task_info:
task_info.total_progress += group_size
for _ in range(group_size):
yield title


def cluster_impl(
dataset: Dataset,
input_fn_or_path: Union[Path, Callable[[Item], str], DatasetFormatInputSelector],
output_path: Optional[Path] = None,
min_cluster_size: int = MIN_CLUSTER_SIZE,
topic_fn: TopicFn = summarize_request,
topic_fn: Optional[TopicFn] = None,
category_fn: Optional[TopicFn] = None,
overwrite: bool = False,
use_garden: bool = False,
task_id: Optional[TaskId] = None,
recompute_titles: bool = False,
batch_size_titling: Optional[int] = None,
) -> None:
"""Compute clusters for a field of the dataset."""
topic_fn = topic_fn or generate_title_openai
category_fn = category_fn or generate_category_openai
task_manager = get_task_manager()
task_info: Optional[TaskInfo] = None
if task_id:
Expand Down Expand Up @@ -416,12 +187,13 @@ def cluster_documents(items: Iterator[Item]) -> Iterator[Item]:

def title_clusters(items: Iterator[Item]) -> Iterator[Item]:
items, items2 = itertools.tee(items)
titles = _compute_titles(
titles = compute_titles(
items,
text_column=TEXT_COLUMN,
cluster_id_column=CLUSTER_ID,
membership_column=CLUSTER_MEMBERSHIP_PROB,
topic_fn=topic_fn,
batch_size=batch_size_titling,
task_info=task_info,
)
for item, title in zip(items2, titles):
Expand Down Expand Up @@ -471,12 +243,13 @@ def cluster_titles(items: Iterator[Item]) -> Iterator[Item]:

def title_categories(items: Iterator[Item]) -> Iterator[Item]:
items, items2 = itertools.tee(items)
titles = _compute_titles(
titles = compute_titles(
items,
text_column=CLUSTER_TITLE,
cluster_id_column=CATEGORY_ID,
membership_column=CATEGORY_MEMBERSHIP_PROB,
topic_fn=generate_category,
topic_fn=category_fn,
batch_size=batch_size_titling,
task_info=task_info,
)
for item, title in zip(items2, titles):
Expand Down
Loading

0 comments on commit 6cb4f72

Please sign in to comment.