Commit c156b3a

implement utils.document.tokenize_document (#497)
Taken from `pie_modules.document.processing`. This implements #462.
1 parent 4de30aa commit c156b3a

2 files changed: +736 −2 lines


src/pytorch_ie/utils/document.py

Lines changed: 135 additions & 2 deletions
```diff
@@ -1,11 +1,21 @@
+import functools
+import json
+import logging
 from collections import defaultdict
-from typing import Dict, Hashable, List, Optional, TypeVar
+from copy import copy
+from typing import Callable, Dict, Hashable, Iterable, List, Optional, Type, TypeVar
 
-from pie_core import Document
+from pie_core import Annotation, Document
 from pie_core.document import BaseAnnotationList
+from pie_documents.annotations import Span
+from pie_documents.document.processing import text_based_document_to_token_based
+from pie_documents.documents import TextBasedDocument, TokenBasedDocument
+from transformers import PreTrainedTokenizer
 
 from pytorch_ie.documents import WithMetadata
 
+logger = logging.getLogger(__name__)
+
 
 def deduplicate_annotation_dicts(
     annotation_dicts: List[Dict[str, Hashable]]
@@ -134,3 +144,126 @@ def merge_annotations_from_documents(
         use_predictions=True,
     )
     return merged_document
+
+
+ToD = TypeVar("ToD", bound=TokenBasedDocument)
+
+
+def tokenize_document(
+    doc: TextBasedDocument,
+    tokenizer: PreTrainedTokenizer,
+    result_document_type: Type[ToD],
+    partition_layer: Optional[str] = None,
+    strip_spans: bool = False,
+    strict_span_conversion: bool = True,
+    added_annotations: Optional[List[Dict[str, Dict[Annotation, Annotation]]]] = None,
+    verbose: bool = True,
+    **tokenize_kwargs,
+) -> List[ToD]:
+    """Tokenize a document with a given tokenizer and return a list of token based documents. The
+    document is tokenized in partitions if a partition layer is provided. The annotations that
+    target the text are converted to target the tokens and also all dependent annotations are
+    converted.
+
+    Args:
+        doc (TextBasedDocument): The document to tokenize.
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        result_document_type (Type[ToD]): The exact type of the token based documents.
+        partition_layer (Optional[str], optional): The layer to use for partitioning the document. If None, the whole
+            document is tokenized. Defaults to None.
+        strip_spans (bool, optional): If True, strip the whitespace from the character spans before converting them to
+            token spans. Defaults to False.
+        strict_span_conversion (bool, optional): If True, raise an error if not all annotations can be converted to
+            token based documents. Defaults to True.
+        added_annotations (Optional[List[Dict[str, Dict[Annotation, Annotation]]]], optional): Pass an empty list to
+            collect the added annotations. Defaults to None.
+        verbose (bool, optional): If True, log warnings if annotations can not be converted. Defaults to True.
+
+    Returns:
+        List[ToD]: The token based documents of type result_document_type with the converted annotations.
+    """
+
+    added_annotation_lists: Dict[str, List[Annotation]] = defaultdict(list)
+    result = []
+    partitions: Iterable[Span]
+    if partition_layer is None:
+        partitions = [Span(start=0, end=len(doc.text))]
+    else:
+        partitions = doc[partition_layer]
+    for partition in partitions:
+        text = doc.text[partition.start : partition.end]
+        current_tokenize_kwargs = copy(tokenize_kwargs)
+        if "text" in tokenize_kwargs:
+            current_tokenize_kwargs["text_pair"] = text
+            sequence_index = 1
+        else:
+            current_tokenize_kwargs["text"] = text
+            sequence_index = 0
+        tokenized_text = tokenizer(**current_tokenize_kwargs)
+        for batch_encoding in tokenized_text.encodings:
+            token_offset_mapping = batch_encoding.offsets
+            char_to_token: Optional[Callable[[int], Optional[int]]]
+            char_to_token = functools.partial(
+                batch_encoding.char_to_token, sequence_index=sequence_index
+            )
+            token_offset_mapping = [
+                offsets if s_id == sequence_index else (0, 0)
+                for s_id, offsets in zip(batch_encoding.sequence_ids, token_offset_mapping)
+            ]
+            if partition.start > 0:
+                token_offset_mapping = [
+                    (start + partition.start, end + partition.start)
+                    for start, end in token_offset_mapping
+                ]
+                char_to_token = None
+            current_added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict)
+            tokenized_document = text_based_document_to_token_based(
+                doc,
+                tokens=batch_encoding.tokens,
+                result_document_type=result_document_type,
+                token_offset_mapping=token_offset_mapping,
+                char_to_token=char_to_token,
+                strict_span_conversion=False,
+                strip_spans=strip_spans,
+                verbose=False,
+                added_annotations=current_added_annotations,
+            )
+            tokenized_document.metadata["tokenizer_encoding"] = batch_encoding
+            result.append(tokenized_document)
+            for k, v in current_added_annotations.items():
+                added_annotation_lists[k].extend(v)
+            if added_annotations is not None:
+                added_annotations.append(current_added_annotations)
+
+    missed_annotations = defaultdict(set)
+    if strict_span_conversion or verbose:
+        # We check the annotations with respect to the layers of the result_document_type.
+        # Note that the original document may have more layers, but since result documents
+        # are of type result_document_type, we only check the layers of this type.
+        for annotation_field in result_document_type.annotation_fields():
+            # do not check the partition layer because the partitions are not required later on
+            # and entries get quite probably removed when windowing is applied, so this just pollutes the logs
+            if annotation_field.name != partition_layer:
+                current_missed_annotations = set(doc[annotation_field.name]) - set(
+                    added_annotation_lists[annotation_field.name]
+                )
+                if len(current_missed_annotations) > 0:
+                    missed_annotations[annotation_field.name] = current_missed_annotations
+
+    if len(missed_annotations) > 0:
+        missed_annotations_simplified = {k: str(v) for k, v in missed_annotations.items()}
+        if strict_span_conversion:
+            raise ValueError(
+                f"could not convert all annotations from document with id={doc.id} to token based documents, "
+                f"but strict_span_conversion is True, so raise an error, "
+                f"missed annotations:\n{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}"
+            )
+        else:
+            if verbose:
+                logger.warning(
+                    f"could not convert all annotations from document with id={doc.id} to token based documents, "
+                    f"missed annotations (disable this message with verbose=False):\n"
+                    f"{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}"
+                )
+
+    return result
```
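For readers evaluating the new helper, here is a minimal usage sketch (not part of the commit). It assumes `TextDocumentWithLabeledSpans`, `TokenDocumentWithLabeledSpans`, and `LabeledSpan` are available from `pie_documents` under their usual pytorch-ie names; adjust the imports if your version places them elsewhere.

```python
# Hedged usage sketch, not part of the commit. The document types and LabeledSpan
# imports below follow the usual pytorch-ie conventions and may live in a
# different module in your pie_documents version.
from pie_documents.annotations import LabeledSpan
from pie_documents.documents import (
    TextDocumentWithLabeledSpans,
    TokenDocumentWithLabeledSpans,
)
from transformers import AutoTokenizer

from pytorch_ie.utils.document import tokenize_document

# A tiny text-based document with two character-level entity spans.
doc = TextDocumentWithLabeledSpans(text="Harry lives in Berlin.")
doc.labeled_spans.append(LabeledSpan(start=0, end=5, label="PERSON"))
doc.labeled_spans.append(LabeledSpan(start=15, end=21, label="LOCATION"))

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Extra keyword arguments are forwarded to the tokenizer call, so windowing can be
# enabled via return_overflowing_tokens; each window becomes one token-based document.
tokenized_docs = tokenize_document(
    doc,
    tokenizer=tokenizer,
    result_document_type=TokenDocumentWithLabeledSpans,
    max_length=128,
    truncation=True,
    return_overflowing_tokens=True,
    strict_span_conversion=False,  # log spans that cannot be mapped instead of raising
)

for tdoc in tokenized_docs:
    print(tdoc.tokens)
    print(tdoc.labeled_spans)  # spans now target token positions
```

With `partition_layer` set to an existing span layer (e.g. sentences), each partition is tokenized separately and the token offset mapping is shifted back to positions in the full document text, as the code above does for `partition.start > 0`.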

0 commit comments
