 from textwrap import dedent
 from typing import List, Optional

-from chonkie import BaseChunker, BaseEmbeddings
+import nest_asyncio
+from loguru import logger
+from pydantic import Field, field_validator
+from symai import Import, Symbol
+from symai.components import FileReader, Function
+from symai.core_ext import bind
+from symai.models import LLMDataModel
 from tenacity import (
     before_sleep_log,
     retry,
+    retry_if_exception_type,
     stop_after_attempt,
     wait_exponential_jitter,
-    retry_if_exception_type,
 )
-import nest_asyncio
-from loguru import logger
-from pydantic import Field, field_validator
-import tiktoken
 from tiktoken import Encoding
 from tokenizers import Tokenizer
-from chonkie import RecursiveChunker
-from symai.components import FileReader, Function
-from symai.core_ext import bind
-from symai.models import LLMDataModel

 from .functions import ValidatedFunction
 from .types import TYPE_SPECIFIC_PROMPTS, DocumentType

+# Load the chunker
+ChonkieChunker = Import.load_expression("ExtensityAI/chonkie-symai", "ChonkieChunker")
+

 class Summary(LLMDataModel):
     summary: str = Field(

@@ -112,8 +113,8 @@ def __init__(
         max_output_tokens: int = 10000,
         user_prompt: str = None,
         include_quotes: bool = False,
-        tokenizer: str | BaseEmbeddings | Encoding = "gpt2",
-        chunker: BaseChunker = RecursiveChunker,
+        tokenizer_name: str = "gpt2",
+        chunker_name: str = "RecursiveChunker",
         seed: int = 42,
         *args,
         **kwargs,

@@ -152,21 +153,9 @@ def __init__(
         self.content = f"[[DOCUMENT::{file_name}]]: <<<\n{str(file_content)}\n>>>\n"
         self.content_only = str(file_content)

-        # init tokenizer
-        if isinstance(tokenizer, str):
-            try:
-                self.tokenizer = tiktoken.encoding_for_model(tokenizer)
-            except:
-                try:
-                    self.tokenizer = Tokenizer.from_pretrained(tokenizer)
-                except:
-                    logger.warning(
-                        f"Tokenizer {tokenizer} not found, using o200k_base tokenizer instead."
-                    )
-                    self.tokenizer = tiktoken.get_encoding('o200k_base')
-        else:
-            self.tokenizer = tokenizer
-        self.chunker = chunker
+        # init chunker
+        self.chunker = ChonkieChunker(tokenizer_name=tokenizer_name)
+        self.chunker_type = chunker_name

         # Content type is unknown at initialization
         self.document_type = None

@@ -296,7 +285,7 @@ def split_words(self, text):
     def chunk_by_token_count(self, text, chunk_size, include_context=False):
         # prepare results
         logger.debug(f"Chunking with chunk size: {chunk_size}")
-        chunks = self.chunker(self.tokenizer, chunk_size=chunk_size)(text)
+        chunks = self.chunker(data=Symbol(text), chunker_name=self.chunker_type, chunk_size=chunk_size)
         logger.debug(f"Number of chunks: {len(chunks)}")
         return chunks
