diff --git a/lilac/data/cluster_titling.py b/lilac/data/cluster_titling.py new file mode 100644 index 00000000..47c22b32 --- /dev/null +++ b/lilac/data/cluster_titling.py @@ -0,0 +1,402 @@ +"""Functions for generating titles and categories for clusters of documents.""" + +import functools +import random +from typing import Any, Iterator, Optional, Sequence, cast + +import instructor +import modal +from instructor.exceptions import IncompleteOutputException +from joblib import Parallel, delayed +from pydantic import ( + BaseModel, +) +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential + +from ..batch_utils import group_by_sorted_key_iter +from ..schema import ( + Item, +) +from ..signal import ( + TopicFn, + TopicFnBatched, + TopicFnNoBatch, +) +from ..tasks import TaskInfo +from ..utils import chunks, log + +_TOP_K_CENTRAL_DOCS = 7 +_TOP_K_CENTRAL_TITLES = 20 +_NUM_THREADS = 32 +_NUM_RETRIES = 16 +# OpenAI rate limits you on `max_tokens` so we ideally want to guess the right value. If ChatGPT +# fails to generate a title within the `max_tokens` limit, we will retry with a higher value. +_OPENAI_INITIAL_MAX_TOKENS = 50 +_OPENAI_FINAL_MAX_TOKENS = 200 + +TITLE_SYSTEM_PROMPT = ( + 'You are a world-class short title generator. Ignore the related snippets below ' + 'and generate a short title (5 words maximum) to describe their common theme. Some examples: ' + '"YA book reviews", "Questions about South East Asia", "Translating English to ' + 'Polish", "Writing product descriptions", etc. If the ' + "snippet's language is different than English, mention it in the title, e.g. " + '"Recipes in Spanish". Avoid vague words like "various", "assortment", ' + '"comments", "discussion", "requests", etc.' +) +EXAMPLE_MATH_SNIPPETS = [ + ( + 'Explain each computation step in the evaluation of 90504690 / 37364. Exclude words; show only ' + 'the math.' + ), + 'What does 3-9030914617332 yield? 
Only respond with math and no words.', + 'Provide a step-by-step calculation for 224 * 276429. Exclude words; show only the math.', +] + +EXAMPLE_SNIPPETS = '\n'.join( + [f'BEGIN_SNIPPET\n{doc}\nEND_SNIPPET' for doc in EXAMPLE_MATH_SNIPPETS] +) +EXAMPLE_TITLE = 'Mathematical Calculations' + +CATEGORY_SYSTEM_PROMPT = ( + 'You are a world-class category generator. Generate a short category name (one or two words ' + 'long) for the provided titles. Do not use parentheses and do not generate alternative names.' +) +CATEGORY_EXAMPLE_TITLES = '\n'.join( + [ + 'Graph Theory and Tree Counting Problems', + 'Mathematical Problem Solving and Calculations', + 'Pizza Mathematics and Optimization', + 'Mathematical Equations and Operations', + ] +) +EXAMPLE_CATEGORY = 'Category: Mathematics' + +_SHORTEN_LEN = 400 + + +def get_titling_snippet(text: str) -> str: + """Shorten the text to a snippet for titling.""" + text = text.strip() + if len(text) <= _SHORTEN_LEN: + return text + prefix_len = _SHORTEN_LEN // 2 + return text[:prefix_len] + '\n...\n' + text[-prefix_len:] + + +class ChatMessage(BaseModel): + """Message in a conversation.""" + + role: str + content: str + + +class SamplingParams(BaseModel): + """Sampling parameters for the mistral model.""" + + temperature: float = 0.0 + top_p: float = 1.0 + max_tokens: int = 50 + stop: Optional[str] = None + spaces_between_special_tokens: bool = False + + +class MistralInstructRequest(BaseModel): + """Request to generate completions for a batch of chats.""" + + chats: list[list[ChatMessage]] + sampling_params: SamplingParams = SamplingParams() + + +class MistralInstructResponse(BaseModel): + """Response from the Mistral model.""" + + outputs: list[str] + + +def generate_category_mistral(batch_titles: list[list[tuple[str, float]]]) -> list[str]: + """Summarize a group of titles into a category.""" + remote_fn = modal.Function.lookup('mistral-7b', 'Instruct.generate').remote + request = MistralInstructRequest(chats=[], 
sampling_params=SamplingParams(stop='\n')) + for ranked_titles in batch_titles: + # Get the top `_TOP_K_CENTRAL_DOCS` titles. + titles = [title for title, _ in ranked_titles[:_TOP_K_CENTRAL_DOCS]] + snippets = '\n'.join(titles) + messages: list[ChatMessage] = [ + ChatMessage(role='system', content=CATEGORY_SYSTEM_PROMPT), + ChatMessage(role='user', content=CATEGORY_EXAMPLE_TITLES), + ChatMessage(role='assistant', content=EXAMPLE_CATEGORY), + ChatMessage(role='user', content=snippets), + ] + request.chats.append(messages) + + category_prefix = 'category: ' + + # TODO(smilkov): Add retry logic. + def request_with_retries() -> list[str]: + response_dict = remote_fn(request.model_dump()) + response = MistralInstructResponse.model_validate(response_dict) + result: list[str] = [] + for title in response.outputs: + title = title.strip() + if title.lower().startswith(category_prefix): + title = title[len(category_prefix) :] + result.append(title) + return result + + return request_with_retries() + + +def generate_title_mistral(batch_docs: list[list[tuple[str, float]]]) -> list[str]: + """Summarize a group of requests in a title of at most 5 words.""" + remote_fn = modal.Function.lookup('mistral-7b', 'Instruct.generate').remote + request = MistralInstructRequest(chats=[], sampling_params=SamplingParams(stop='\n')) + for ranked_docs in batch_docs: + # Get the top `_TOP_K_CENTRAL_DOCS` documents. + docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_DOCS]] + snippets = '\n'.join( + [f'BEGIN_SNIPPET\n{get_titling_snippet(doc)}\nEND_SNIPPET' for doc in docs] + ) + messages: list[ChatMessage] = [ + ChatMessage(role='system', content=TITLE_SYSTEM_PROMPT), + ChatMessage(role='user', content=EXAMPLE_SNIPPETS), + ChatMessage(role='assistant', content=EXAMPLE_TITLE), + ChatMessage(role='user', content=snippets), + ] + request.chats.append(messages) + + title_prefix = 'title: ' + + # TODO(smilkov): Add retry logic. 
+ def request_with_retries() -> list[str]: + response_dict = remote_fn(request.model_dump()) + response = MistralInstructResponse.model_validate(response_dict) + result: list[str] = [] + for title in response.outputs: + title = title.strip() + if title.lower().startswith(title_prefix): + title = title[len(title_prefix) :] + result.append(title) + return result + + return request_with_retries() + + +@functools.cache +def _openai_client() -> Any: + """Get an OpenAI client.""" + try: + import openai + + except ImportError: + raise ImportError( + 'Could not import the "openai" python package. ' + 'Please install it with `pip install openai`.' + ) + + # OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10 + # mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set + # `max_retries` to 0 to disable internal retries so we handle retries ourselves. + return instructor.patch(openai.OpenAI(timeout=7, max_retries=0)) + + +class Title(BaseModel): + """A 4-5 word title for the group of related snippets.""" + + title: str + + +def generate_title_openai(ranked_docs: list[tuple[str, float]]) -> str: + """Generate a short title for a set of documents using OpenAI.""" + # Get the top `_TOP_K_CENTRAL_DOCS` documents. + docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_DOCS]] + texts = [f'BEGIN_SNIPPET\n{get_titling_snippet(doc)}\nEND_SNIPPET' for doc in docs] + input = '\n'.join(texts) + try: + import openai + + except ImportError: + raise ImportError( + 'Could not import the "openai" python package. ' + 'Please install it with `pip install openai`.' 
+ ) + + @retry( + retry=retry_if_exception_type( + ( + openai.RateLimitError, + openai.APITimeoutError, + openai.APIConnectionError, + openai.ConflictError, + openai.InternalServerError, + ) + ), + wait=wait_random_exponential(multiplier=0.5, max=60), + stop=stop_after_attempt(_NUM_RETRIES), + ) + def request_with_retries() -> str: + max_tokens = _OPENAI_INITIAL_MAX_TOKENS + while max_tokens <= _OPENAI_FINAL_MAX_TOKENS: + try: + title = _openai_client().chat.completions.create( + model='gpt-3.5-turbo-1106', + response_model=Title, + temperature=0.0, + max_tokens=max_tokens, + messages=[ + { + 'role': 'system', + 'content': TITLE_SYSTEM_PROMPT, + }, + {'role': 'user', 'content': input}, + ], + ) + return title.title + except IncompleteOutputException: + max_tokens += _OPENAI_INITIAL_MAX_TOKENS + log(f'Retrying with max_tokens={max_tokens}') + log(f'Could not generate a short title for input:\n{input}') + # We return a string instead of None, since None is emitted when the text column is sparse. + return 'FAILED_TO_TITLE' + + return request_with_retries() + + +class Category(BaseModel): + """A short category title.""" + + category: str + + +def generate_category_openai(ranked_docs: list[tuple[str, float]]) -> str: + """Summarize a list of titles in a category.""" + # Get the top `_TOP_K_CENTRAL_TITLES` titles. + docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_TITLES]] + input = '\n'.join(docs) + try: + import openai + + except ImportError: + raise ImportError( + 'Could not import the "openai" python package. ' + 'Please install it with `pip install openai`.' 
+ ) + + @retry( + retry=retry_if_exception_type( + ( + openai.RateLimitError, + openai.APITimeoutError, + openai.APIConnectionError, + openai.ConflictError, + openai.InternalServerError, + ) + ), + wait=wait_random_exponential(multiplier=0.5, max=60), + stop=stop_after_attempt(_NUM_RETRIES), + ) + def request_with_retries() -> str: + max_tokens = _OPENAI_INITIAL_MAX_TOKENS + while max_tokens <= _OPENAI_FINAL_MAX_TOKENS: + try: + category = _openai_client().chat.completions.create( + model='gpt-3.5-turbo-1106', + response_model=Category, + temperature=0.0, + max_tokens=max_tokens, + messages=[ + { + 'role': 'system', + 'content': ( + 'You are a world-class category labeler. Generate a short category name for the ' + 'provided titles. For example, given two titles "translating english to polish" ' + 'and "translating korean to english", generate "Translation".' + ), + }, + {'role': 'user', 'content': input}, + ], + ) + return category.category + except IncompleteOutputException: + max_tokens += _OPENAI_INITIAL_MAX_TOKENS + log(f'Retrying with max_tokens={max_tokens}') + log(f'Could not generate a short category for input:\n{input}') + return 'FAILED_TO_GENERATE' + + return request_with_retries() + + +def compute_titles( + items: Iterator[Item], + text_column: str, + cluster_id_column: str, + membership_column: str, + topic_fn: TopicFn, + batch_size: Optional[int] = None, + task_info: Optional[TaskInfo] = None, +) -> Iterator[str]: + """Compute titles for clusters of documents.""" + + def _compute_title( + batch_docs: list[list[tuple[str, float]]], group_size: list[int] + ) -> list[tuple[int, Optional[str]]]: + if batch_size is None: + topic_fn_no_batch = cast(TopicFnNoBatch, topic_fn) + topics: Sequence[Optional[str]] + if batch_docs and batch_docs[0]: + topics = [topic_fn_no_batch(batch_docs[0])] + else: + topics = [None] + else: + topic_fn_batched = cast(TopicFnBatched, topic_fn) + topics = topic_fn_batched(batch_docs) + return [(group_size, topic) for group_size, 
topic in zip(group_size, topics)] + + def _delayed_compute_all_titles() -> Iterator: + clusters = group_by_sorted_key_iter(items, lambda x: x[cluster_id_column]) + for batch_clusters in chunks(clusters, batch_size or 1): + cluster_sizes: list[int] = [] + batch_docs: list[list[tuple[str, float]]] = [] + for cluster in batch_clusters: + sorted_docs: list[tuple[str, float]] = [] + + for item in cluster: + if not item: + continue + + cluster_id = item.get(cluster_id_column, -1) + if cluster_id < 0: + continue + + text = item.get(text_column) + if not text: + continue + + membership_prob = item.get(membership_column, 0) + if membership_prob == 0: + continue + + sorted_docs.append((text, membership_prob)) + + # Remove any duplicate texts in the cluster. + sorted_docs = list(set(sorted_docs)) + + # Shuffle the cluster to avoid biasing the topic function. + random.shuffle(sorted_docs) + + # Sort the cluster by membership probability after shuffling so that we still choose high + # membership scores but they are still shuffled when the values are equal. 
+ sorted_docs.sort(key=lambda text_score: text_score[1], reverse=True) + cluster_sizes.append(len(cluster)) + batch_docs.append(sorted_docs) + + yield delayed(_compute_title)(batch_docs, cluster_sizes) + + parallel = Parallel(n_jobs=_NUM_THREADS, backend='threading', return_as='generator') + if task_info: + task_info.total_progress = 0 + for batch_result in parallel(_delayed_compute_all_titles()): + for group_size, title in batch_result: + if task_info: + task_info.total_progress += group_size + for _ in range(group_size): + yield title diff --git a/lilac/data/clustering.py b/lilac/data/clustering.py index a12b9011..a9762372 100644 --- a/lilac/data/clustering.py +++ b/lilac/data/clustering.py @@ -1,22 +1,13 @@ """Clustering utilities.""" -import functools import gc import itertools -import random -from typing import Any, Callable, Iterator, Optional, Union, cast +from typing import Callable, Iterator, Optional, Union, cast -import instructor import modal import numpy as np -from instructor.exceptions import IncompleteOutputException -from joblib import Parallel, delayed -from pydantic import ( - BaseModel, -) -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential from tqdm import tqdm -from ..batch_utils import compress_docs, flatten_path_iter, group_by_sorted_key_iter +from ..batch_utils import compress_docs, flatten_path_iter from ..dataset_format import DatasetFormatInputSelector from ..embeddings.jina import JinaV2Small from ..schema import ( @@ -36,22 +27,17 @@ ) from ..tasks import TaskId, TaskInfo, get_task_manager from ..utils import DebugTimer, chunks, log +from .cluster_titling import ( + compute_titles, + generate_category_openai, + generate_title_openai, +) from .dataset import Dataset from .dataset_utils import ( get_callable_name, sparse_to_dense_compute, ) -_SHORTEN_LEN = 400 -_TOP_K_CENTRAL_DOCS = 7 -_TOP_K_CENTRAL_TITLES = 20 -_NUM_THREADS = 32 -_NUM_RETRIES = 16 -# OpenAI rate limits you on `max_tokens` 
so we ideally want to guess the right value. If ChatGPT -# fails to generate a title within the `max_tokens` limit, we will retry with a higher value. -_INITIAL_MAX_TOKENS = 50 -_FINAL_MAX_TOKENS = 200 - CLUSTER_ID = 'cluster_id' CLUSTER_MEMBERSHIP_PROB = 'cluster_membership_prob' CLUSTER_TITLE = 'cluster_title' @@ -70,237 +56,22 @@ BATCH_SOFT_CLUSTER_NOISE = 1024 -@functools.cache -def _openai_client() -> Any: - """Get an OpenAI client.""" - try: - import openai - - except ImportError: - raise ImportError( - 'Could not import the "openai" python package. ' - 'Please install it with `pip install openai`.' - ) - - # OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10 - # mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set - # `max_retries` to 0 to disable internal retries so we handle retries ourselves. - return instructor.patch(openai.OpenAI(timeout=7, max_retries=0)) - - -def _snippet_to_prefix_and_suffix(text: str) -> str: - text = text.strip() - if len(text) <= _SHORTEN_LEN: - return text - prefix_len = _SHORTEN_LEN // 2 - return text[:prefix_len] + '\n...\n' + text[-prefix_len:] - - -class Title(BaseModel): - """A 4-5 word title for the group of related snippets.""" - - title: str - - -def summarize_request(ranked_docs: list[tuple[str, float]]) -> str: - """Summarize a group of requests in a title of at most 5 words.""" - # Get the top 5 documents. - docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_DOCS]] - texts = [f'BEGIN_SNIPPET\n{_snippet_to_prefix_and_suffix(doc)}\nEND_SNIPPET' for doc in docs] - input = '\n'.join(texts) - try: - import openai - - except ImportError: - raise ImportError( - 'Could not import the "openai" python package. ' - 'Please install it with `pip install openai`.' 
- ) - - @retry( - retry=retry_if_exception_type( - ( - openai.RateLimitError, - openai.APITimeoutError, - openai.APIConnectionError, - openai.ConflictError, - openai.InternalServerError, - ) - ), - wait=wait_random_exponential(multiplier=0.5, max=60), - stop=stop_after_attempt(_NUM_RETRIES), - ) - def request_with_retries() -> str: - max_tokens = _INITIAL_MAX_TOKENS - while max_tokens <= _FINAL_MAX_TOKENS: - try: - title = _openai_client().chat.completions.create( - model='gpt-3.5-turbo-1106', - response_model=Title, - temperature=0.0, - max_tokens=max_tokens, - messages=[ - { - 'role': 'system', - 'content': ( - 'You are a world-class short title generator. Ignore the related snippets below ' - 'and generate a short title to describe their common theme. Some examples: ' - '"YA book reviews", "Questions about South East Asia", "Translating English to ' - 'Polish", "Writing product descriptions", etc. Use descriptive words. If the ' - "snippet's language is different than English, mention it in the title, e.g. " - '"Cooking questions in Spanish". Avoid vague words like "various", "assortment", ' - '"comments", "discussion", etc.' - ), - }, - {'role': 'user', 'content': input}, - ], - ) - return title.title - except IncompleteOutputException: - max_tokens += _INITIAL_MAX_TOKENS - log(f'Retrying with max_tokens={max_tokens}') - log(f'Could not generate a short title for input:\n{input}') - # We return a string instead of None, since None is emitted when the text column is sparse. - return 'FAILED_TO_TITLE' - - return request_with_retries() - - -class Category(BaseModel): - """A short category title.""" - - category: str - - -def generate_category(ranked_docs: list[tuple[str, float]]) -> str: - """Summarize a list of titles in a category.""" - # Get the top 5 documents. 
- docs = [doc for doc, _ in ranked_docs[:_TOP_K_CENTRAL_TITLES]] - input = '\n'.join(docs) - try: - import openai - - except ImportError: - raise ImportError( - 'Could not import the "openai" python package. ' - 'Please install it with `pip install openai`.' - ) - - @retry( - retry=retry_if_exception_type( - ( - openai.RateLimitError, - openai.APITimeoutError, - openai.APIConnectionError, - openai.ConflictError, - openai.InternalServerError, - ) - ), - wait=wait_random_exponential(multiplier=0.5, max=60), - stop=stop_after_attempt(_NUM_RETRIES), - ) - def request_with_retries() -> str: - max_tokens = _INITIAL_MAX_TOKENS - while max_tokens <= _FINAL_MAX_TOKENS: - try: - category = _openai_client().chat.completions.create( - model='gpt-3.5-turbo-1106', - response_model=Category, - temperature=0.0, - max_tokens=max_tokens, - messages=[ - { - 'role': 'system', - 'content': ( - 'You are a world-class category labeler. Generate a short category name for the ' - 'provided titles. For example, given two titles "translating english to polish" ' - 'and "translating korean to english", generate "Translation".' 
- ), - }, - {'role': 'user', 'content': input}, - ], - ) - return category.category - except IncompleteOutputException: - max_tokens += _INITIAL_MAX_TOKENS - log(f'Retrying with max_tokens={max_tokens}') - log(f'Could not generate a short category for input:\n{input}') - return 'FAILED_TO_GENERATE' - - return request_with_retries() - - -def _compute_titles( - items: Iterator[Item], - text_column: str, - cluster_id_column: str, - membership_column: str, - topic_fn: TopicFn, - task_info: Optional[TaskInfo] = None, -) -> Iterator[str]: - def _compute_title( - sorted_docs: list[tuple[str, float]], group_size: int - ) -> Optional[tuple[int, Optional[str]]]: - if not sorted_docs: - return group_size, None - return group_size, topic_fn(sorted_docs) - - def _delayed_compute_all_titles() -> Iterator: - for group in group_by_sorted_key_iter(items, lambda x: x[cluster_id_column]): - sorted_docs: list[tuple[str, float]] = [] - - for item in group: - if not item: - continue - - cluster_id = item.get(cluster_id_column, -1) - if cluster_id < 0: - continue - - text = item.get(text_column) - if not text: - continue - - membership_prob = item.get(membership_column, 0) - if membership_prob == 0: - continue - - sorted_docs.append((text, membership_prob)) - - # Remove any duplicate texts in the group. - sorted_docs = list(set(sorted_docs)) - - # Shuffle the group to avoid biasing the topic function. - random.shuffle(sorted_docs) - - # Sort the group by membership probability after shuffling so that we still choose high - # membership scores but they are still shuffled when the values are equal. 
- sorted_docs.sort(key=lambda text_score: text_score[1], reverse=True) - - yield delayed(_compute_title)(sorted_docs, len(group)) - - parallel = Parallel(n_jobs=_NUM_THREADS, backend='threading', return_as='generator') - if task_info: - task_info.total_progress = 0 - for group_size, title in parallel(_delayed_compute_all_titles()): - if task_info: - task_info.total_progress += group_size - for _ in range(group_size): - yield title - - def cluster_impl( dataset: Dataset, input_fn_or_path: Union[Path, Callable[[Item], str], DatasetFormatInputSelector], output_path: Optional[Path] = None, min_cluster_size: int = MIN_CLUSTER_SIZE, - topic_fn: TopicFn = summarize_request, + topic_fn: Optional[TopicFn] = None, + category_fn: Optional[TopicFn] = None, overwrite: bool = False, use_garden: bool = False, task_id: Optional[TaskId] = None, recompute_titles: bool = False, + batch_size_titling: Optional[int] = None, ) -> None: """Compute clusters for a field of the dataset.""" + topic_fn = topic_fn or generate_title_openai + category_fn = category_fn or generate_category_openai task_manager = get_task_manager() task_info: Optional[TaskInfo] = None if task_id: @@ -416,12 +187,13 @@ def cluster_documents(items: Iterator[Item]) -> Iterator[Item]: def title_clusters(items: Iterator[Item]) -> Iterator[Item]: items, items2 = itertools.tee(items) - titles = _compute_titles( + titles = compute_titles( items, text_column=TEXT_COLUMN, cluster_id_column=CLUSTER_ID, membership_column=CLUSTER_MEMBERSHIP_PROB, topic_fn=topic_fn, + batch_size=batch_size_titling, task_info=task_info, ) for item, title in zip(items2, titles): @@ -471,12 +243,13 @@ def cluster_titles(items: Iterator[Item]) -> Iterator[Item]: def title_categories(items: Iterator[Item]) -> Iterator[Item]: items, items2 = itertools.tee(items) - titles = _compute_titles( + titles = compute_titles( items, text_column=CLUSTER_TITLE, cluster_id_column=CATEGORY_ID, membership_column=CATEGORY_MEMBERSHIP_PROB, - topic_fn=generate_category, 
+ topic_fn=category_fn, + batch_size=batch_size_titling, task_info=task_info, ) for item, title in zip(items2, titles): diff --git a/lilac/data/clustering_test.py b/lilac/data/clustering_test.py index 9ec6f517..488e81a4 100644 --- a/lilac/data/clustering_test.py +++ b/lilac/data/clustering_test.py @@ -88,10 +88,11 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: return 'other' mocker.patch.object(clustering, 'MIN_CLUSTER_SIZE_CATEGORY', 2) - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') _mock_jina(mocker) - dataset.cluster('text', min_cluster_size=2, topic_fn=topic_fn) + dataset.cluster( + 'text', min_cluster_size=2, topic_fn=topic_fn, category_fn=lambda _: 'MockCategory' + ) rows = list(dataset.select_rows(['text', 'text__cluster'], combine_columns=True)) assert rows == [ @@ -238,7 +239,6 @@ def test_nested_clusters(make_test_data: TestDataMaker, mocker: MockerFixture) - ], ] mocker.patch.object(clustering, 'MIN_CLUSTER_SIZE_CATEGORY', 2) - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') dataset = make_test_data([{'texts': t} for t in texts]) def topic_fn(docs: list[tuple[str, float]]) -> str: @@ -250,7 +250,9 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: _mock_jina(mocker) - dataset.cluster('texts.*.text', min_cluster_size=2, topic_fn=topic_fn) + dataset.cluster( + 'texts.*.text', min_cluster_size=2, topic_fn=topic_fn, category_fn=lambda _: 'MockCategory' + ) rows = list(dataset.select_rows(['texts_text__cluster'], combine_columns=True)) assert rows == [ @@ -300,7 +302,6 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: def test_path_ending_with_repeated(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: texts: list[list[str]] = [['hello', 'teacher'], ['professor'], ['hi']] dataset = make_test_data([{'texts': t} for t in texts]) - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') def topic_fn(docs: list[tuple[str, float]]) -> str: 
if 'hello' in docs[0][0]: @@ -311,7 +312,9 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: mocker.patch.object(clustering, 'MIN_CLUSTER_SIZE_CATEGORY', 2) _mock_jina(mocker) - dataset.cluster('texts.*', min_cluster_size=2, topic_fn=topic_fn) + dataset.cluster( + 'texts.*', min_cluster_size=2, topic_fn=topic_fn, category_fn=lambda _: 'MockCategory' + ) rows = list(dataset.select_rows(combine_columns=True)) assert rows == [ { @@ -358,7 +361,6 @@ def test_clusters_with_fn(make_test_data: TestDataMaker, mocker: MockerFixture) ['Can you simplify this text'], ] dataset = make_test_data([{'texts': t} for t in texts]) - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') mocker.patch.object(clustering, 'MIN_CLUSTER_SIZE_CATEGORY', 2) def topic_fn(docs: list[tuple[str, float]]) -> str: @@ -383,6 +385,7 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: output_path='cluster', min_cluster_size=2, topic_fn=topic_fn, + category_fn=lambda _: 'MockCategory', ) rows = list(dataset.select_rows(combine_columns=True)) assert rows == [ @@ -442,7 +445,6 @@ def test_clusters_with_fn_output_is_under_a_dict( ['Can you provide a short summary of the following text'], ['Can you simplify this text'], ] - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') dataset = make_test_data([{'texts': t, 'info': {'dummy': True}} for t in texts]) mocker.patch.object(clustering, 'MIN_CLUSTER_SIZE_CATEGORY', 2) @@ -459,6 +461,7 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: output_path=('info', 'cluster'), min_cluster_size=2, topic_fn=topic_fn, + category_fn=lambda _: 'MockCategory', ) rows = list(dataset.select_rows(combine_columns=True)) assert rows == [ @@ -522,8 +525,6 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: def test_clusters_sharegpt(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') - texts: list[Item] = [ { 
'conversations': [ @@ -569,6 +570,7 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: output_path='cluster', min_cluster_size=2, topic_fn=topic_fn, + category_fn=lambda _: 'MockCategory', ) # Sort because topics are shuffled. @@ -649,7 +651,6 @@ def test_clusters_on_enriched_text(make_test_data: TestDataMaker, mocker: Mocker 'Can you provide a short summary of the following text', 'Can you simplify this text', ] - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') dataset = make_test_data([{'text': t} for t in texts]) def topic_fn(docs: list[tuple[str, float]]) -> str: @@ -664,7 +665,9 @@ def topic_fn(docs: list[tuple[str, float]]) -> str: mocker.patch.object(clustering, 'MIN_CLUSTER_SIZE_CATEGORY', 2) _mock_jina(mocker) - dataset.cluster('text', min_cluster_size=2, topic_fn=topic_fn) + dataset.cluster( + 'text', min_cluster_size=2, topic_fn=topic_fn, category_fn=lambda _: 'MockCategory' + ) rows = list(dataset.select_rows(['text', 'text__cluster'], combine_columns=True)) assert rows == [ diff --git a/lilac/data/dataset.py b/lilac/data/dataset.py index c2f737ee..afa3fb61 100644 --- a/lilac/data/dataset.py +++ b/lilac/data/dataset.py @@ -497,6 +497,8 @@ def cluster( overwrite: bool = False, use_garden: bool = False, task_id: Optional[TaskId] = None, + # TODO(0.4.0): colocate with topic_fn. + category_fn: Optional[TopicFn] = None, ) -> None: """Compute clusters for a field of the dataset. @@ -513,6 +515,8 @@ def cluster( use_garden: Whether to run the clustering remotely on Lilac Garden. task_id: The TaskManager `task_id` for this process run. This is used to update the progress of the task. + category_fn: A function that returns a category for a set of related titles. It takes a list + of (doc, membership_score) tuples and returns a single category name. 
""" pass diff --git a/lilac/data/dataset_duckdb.py b/lilac/data/dataset_duckdb.py index 2df96c58..bb2ec06f 100644 --- a/lilac/data/dataset_duckdb.py +++ b/lilac/data/dataset_duckdb.py @@ -120,7 +120,10 @@ log, open_file, ) -from . import clustering, dataset # Imported top-level so they can be mocked. +from . import ( + cluster_titling, + dataset, # Imported top-level so they can be mocked. +) from .clustering import cluster_impl from .dataset import ( BINARY_OPS, @@ -467,6 +470,10 @@ def _recompute_joint_table( The solution is to nuke and recompute the entire cache if anything fails. """ del sqlite_files # Unused. + + self._pivot_cache.clear() + self.stats.cache_clear() + merged_schema = self._source_manifest.data_schema.model_copy(deep=True) self._signal_manifests = [] self._label_schemas = {} @@ -654,6 +661,7 @@ def _clear_joint_table_cache(self) -> None: """Clears the cache for the joint table.""" self._recompute_joint_table.cache_clear() self._pivot_cache.clear() + self.stats.cache_clear() if env('LILAC_USE_TABLE_INDEX', default=False): self.con.close() pathlib.Path(os.path.join(self.dataset_path, DUCKDB_CACHE_FILE)).unlink(missing_ok=True) @@ -3313,14 +3321,24 @@ def cluster( input: Union[Path, Callable[[Item], str], DatasetFormatInputSelector], output_path: Optional[Path] = None, min_cluster_size: int = 5, - topic_fn: Optional[TopicFn] = None, + topic_fn: Optional[TopicFn] = cluster_titling.generate_title_openai, overwrite: bool = False, use_garden: bool = False, task_id: Optional[TaskId] = None, + category_fn: Optional[TopicFn] = cluster_titling.generate_category_openai, ) -> None: - topic_fn = topic_fn or clustering.summarize_request + topic_fn = topic_fn or cluster_titling.generate_title_openai + category_fn = category_fn or cluster_titling.generate_category_openai return cluster_impl( - self, input, output_path, min_cluster_size, topic_fn, overwrite, use_garden, task_id=task_id + self, + input, + output_path, + min_cluster_size=min_cluster_size, + 
topic_fn=topic_fn, + category_fn=category_fn, + overwrite=overwrite, + use_garden=use_garden, + task_id=task_id, ) @override diff --git a/lilac/load.py b/lilac/load.py index a6da5211..38582d01 100644 --- a/lilac/load.py +++ b/lilac/load.py @@ -208,6 +208,8 @@ def load( output_path=c.output_path, min_cluster_size=c.min_cluster_size, use_garden=config.use_garden, + topic_fn=None, + category_fn=None, ) log() diff --git a/lilac/load_test.py b/lilac/load_test.py index 53b6fc4f..0c3ae289 100644 --- a/lilac/load_test.py +++ b/lilac/load_test.py @@ -18,7 +18,7 @@ EmbeddingConfig, SignalConfig, ) -from .data import clustering +from .data import cluster_titling from .data.dataset import DatasetManifest from .db_manager import get_dataset from .embeddings.jina import JinaV2Small @@ -408,6 +408,10 @@ def test_load_clusters( ], ) + _mock_jina(mocker) + mocker.patch.object(cluster_titling, 'generate_title_openai', return_value='title') + mocker.patch.object(cluster_titling, 'generate_category_openai', return_value='category') + _mock_jina(mocker) # Load the project config from a config object. @@ -477,7 +481,7 @@ def yield_items(self) -> Iterable[Item]: def test_load_clusters_format_selector( tmp_path: pathlib.Path, capsys: pytest.CaptureFixture, mocker: MockerFixture ) -> None: - mocker.patch.object(clustering, 'generate_category', return_value='MockCategory') + mocker.patch.object(cluster_titling, 'generate_category_openai', return_value='MockCategory') _mock_jina(mocker) topic_fn_calls: list[list[tuple[str, float]]] = [] @@ -490,7 +494,7 @@ def _test_topic_fn(docs: list[tuple[str, float]]) -> str: return 'time' return 'other' - mocker.patch.object(clustering, 'summarize_request', side_effect=_test_topic_fn) + mocker.patch.object(cluster_titling, 'generate_title_openai', side_effect=_test_topic_fn) set_project_dir(tmp_path) # Initialize the lilac project. init() defaults to the project directory. 
diff --git a/lilac/signal.py b/lilac/signal.py index 3c1a041e..eb00376d 100644 --- a/lilac/signal.py +++ b/lilac/signal.py @@ -59,7 +59,9 @@ def _signal_schema_extra(schema: dict[str, Any], signal: Type['Signal']) -> None OutputType = Optional[Literal['embedding', 'cluster']] # A function that takes a list of (topic, membership_score) tuples and returns a single topic. -TopicFn = Callable[[list[tuple[str, float]]], str] +TopicFnBatched = Callable[[list[list[tuple[str, float]]]], list[str]] +TopicFnNoBatch = Callable[[list[tuple[str, float]]], str] +TopicFn = Union[TopicFnBatched, TopicFnNoBatch] class Signal(BaseModel):