allow tokenizer name or path besides instance
Signed-off-by: Panos Vagenas <[email protected]>
vagenas committed Dec 5, 2024
1 parent 56bbba8 commit f3064e8
Showing 2 changed files with 42 additions and 14 deletions.
docling_core/transforms/chunker/token_aware_chunker.py (20 changes: 13 additions & 7 deletions)
@@ -13,7 +13,7 @@

 try:
     import semchunk
-    from transformers import PreTrainedTokenizerBase
+    from transformers import AutoTokenizer, PreTrainedTokenizerBase
 except ImportError:
     raise RuntimeError(
         "Module requires 'chunking' extra; to install, run: "
@@ -35,25 +35,31 @@ class TokenAwareChunker(BaseChunker):
     r"""Token-aware chunker implementation leveraging the document layout.

     Args:
-        tokenizer: The tokenerizer to use.
+        tokenizer: The tokenizer to use; either instantiated object or name or path of
+            respective pretrained model
         max_tokens: The maximum number of tokens per chunk. If not set, limit is
             resolved from the tokenizer
         merge_peers: Whether to merge undersized chunks sharing same relevant metadata
     """

     model_config = ConfigDict(arbitrary_types_allowed=True)

-    tokenizer: PreTrainedTokenizerBase
+    tokenizer: Union[PreTrainedTokenizerBase, str]
     max_tokens: int = None  # type: ignore[assignment]
     merge_peers: bool = True

     _inner_chunker: HierarchicalChunker = HierarchicalChunker()

     @model_validator(mode="after")
-    def _patch_max_tokens(self) -> Self:
+    def _patch_tokenizer_and_max_tokens(self) -> Self:
+        self._tokenizer = (
+            self.tokenizer
+            if isinstance(self.tokenizer, PreTrainedTokenizerBase)
+            else AutoTokenizer.from_pretrained(self.tokenizer)
+        )
         if self.max_tokens is None:
             self.max_tokens = TypeAdapter(PositiveInt).validate_python(
-                self.tokenizer.model_max_length
+                self._tokenizer.model_max_length
             )
         return self

Expand All @@ -65,7 +71,7 @@ def _count_tokens(self, text: Optional[Union[str, list[str]]]):
for t in text:
total += self._count_tokens(t)
return total
return len(self.tokenizer.tokenize(text, max_length=None))
return len(self._tokenizer.tokenize(text, max_length=None))

class _ChunkLengthInfo(BaseModel):
total_len: int
@@ -180,7 +186,7 @@ def _split_using_plain_text(
         # captions:
         available_length = self.max_tokens - lengths.other_len
         sem_chunker = semchunk.chunkerify(
-            self.tokenizer, chunk_size=available_length
+            self._tokenizer, chunk_size=available_length
         )
         if available_length <= 0:
             warnings.warn(
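For context, a minimal usage sketch of what this change enables (not part of the commit; the import path is taken from the file path shown above, and the model id below is an illustrative placeholder, not the EMBED_MODEL_ID used in the tests):

# Minimal sketch, assuming the import path from the file above and a
# placeholder Hugging Face model id.
from transformers import AutoTokenizer

from docling_core.transforms.chunker.token_aware_chunker import TokenAwareChunker

MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder model id

# Pre-existing usage: pass an instantiated tokenizer object.
chunker_from_instance = TokenAwareChunker(
    tokenizer=AutoTokenizer.from_pretrained(MODEL_ID),
    max_tokens=64,
)

# New usage enabled by this commit: pass the model name (or local path);
# the model validator resolves it via AutoTokenizer.from_pretrained and
# stores the result in the private _tokenizer attribute.
chunker_from_name = TokenAwareChunker(
    tokenizer=MODEL_ID,
    max_tokens=64,
)

# Either chunker is then used the same way, e.g.:
#     chunks = list(chunker_from_name.chunk(dl_doc=dl_doc))
# where dl_doc is a DoclingDocument, as in the tests below.

Keeping the resolved tokenizer in a private attribute leaves the public tokenizer field as the user declared it, while the rest of the class consistently works with _tokenizer.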
test/test_token_aware_chunker.py (36 changes: 29 additions & 7 deletions)
@@ -15,17 +15,18 @@
 MAX_TOKENS = 64
 INPUT_FILE = "test/data/chunker/2_inp_dl_doc.json"

+TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)
+

 def test_chunk_merge_peers():
     EXPECTED_OUT_FILE = "test/data/chunker/2a_out_chunks.json"

-    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)
     with open(INPUT_FILE) as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)

     chunker = TokenAwareChunker(
-        tokenizer=tokenizer,
+        tokenizer=TOKENIZER,
         max_tokens=MAX_TOKENS,
         merge_peers=True,
     )
@@ -43,13 +44,12 @@ def test_chunk_merge_peers():
 def test_chunk_no_merge_peers():
     EXPECTED_OUT_FILE = "test/data/chunker/2b_out_chunks.json"

-    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)
     with open(INPUT_FILE) as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)

     chunker = TokenAwareChunker(
-        tokenizer=tokenizer,
+        tokenizer=TOKENIZER,
         max_tokens=MAX_TOKENS,
         merge_peers=False,
     )
@@ -66,13 +66,12 @@ def test_chunk_no_merge_peers():
 def test_serialize():
     EXPECTED_OUT_FILE = "test/data/chunker/2a_out_ser_chunks.json"

-    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)
     with open(INPUT_FILE) as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)

     chunker = TokenAwareChunker(
-        tokenizer=tokenizer,
+        tokenizer=TOKENIZER,
         max_tokens=MAX_TOKENS,
         merge_peers=True,
     )
@@ -84,11 +83,34 @@ def test_serialize():
             dict(
                 text=chunk.text,
                 ser_text=(ser_text := chunker.serialize(chunk)),
-                num_tokens=len(tokenizer.tokenize(ser_text, max_length=None)),
+                num_tokens=len(TOKENIZER.tokenize(ser_text, max_length=None)),
             )
             for chunk in chunks
         ]
     )
     with open(EXPECTED_OUT_FILE) as f:
         exp_data = json.load(fp=f)
     assert exp_data == act_data
+
+
+def test_chunk_with_model_name():
+    EXPECTED_OUT_FILE = "test/data/chunker/2a_out_chunks.json"
+
+    with open(INPUT_FILE) as f:
+        data_json = f.read()
+    dl_doc = DLDocument.model_validate_json(data_json)
+
+    chunker = TokenAwareChunker(
+        tokenizer=EMBED_MODEL_ID,
+        max_tokens=MAX_TOKENS,
+        merge_peers=True,
+    )
+
+    chunk_iter = chunker.chunk(dl_doc=dl_doc)
+    chunks = list(chunk_iter)
+    act_data = dict(
+        root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
+    )
+    with open(EXPECTED_OUT_FILE) as f:
+        exp_data = json.load(fp=f)
+    assert exp_data == act_data
