Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import typing
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, overload
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, overload


def _depth_first_search(lst: List[str], visited: Set[str], graph: Dict[str, List[str]], node: str):
Expand Down Expand Up @@ -84,6 +84,10 @@ def append(self, annotation: T) -> None:
annotation.set_target(getattr(self._document, self._target))
self._annotations.append(annotation)

def extend(self, annotations: Iterable[T]) -> None:
for annotation in annotations:
self.append(annotation)

def __repr__(self) -> str:
return f"BaseAnnotationList({str(self._annotations)})"

Expand All @@ -102,6 +106,12 @@ def __init__(self, document: "Document", target: "str"):
def predictions(self) -> BaseAnnotationList[T]:
return self._predictions

def integrate_predictions(self, overwrite: bool = False):
if overwrite:
self.clear()
self.extend(self.predictions)
self._predictions._annotations = []

def __eq__(self, other: object) -> bool:
if not isinstance(other, AnnotationList):
return NotImplemented
Expand Down
33 changes: 33 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,36 @@ class TestDocument(TextDocument):

# TODO: revisit when we decided how to handle serialization of predictions
# assert document1 == TestDocument.fromdict(document1.asdict())


@pytest.mark.parametrize("overwrite", [False, True])
def test_integrate_annotations(overwrite):
@dataclasses.dataclass
class TestDocument(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")

document = TestDocument(text="Entity A works at B.")

entity1 = LabeledSpan(start=0, end=8, label="PER")
entity2 = LabeledSpan(start=18, end=19, label="ORG")

document.entities.append(entity1)
document.entities.predictions.append(entity2)

assert len(document.entities) == 1
assert document.entities[0] == entity1

assert len(document.entities.predictions) == 1
assert document.entities.predictions[0] == entity2

document.entities.integrate_predictions(overwrite=overwrite)

if overwrite:
assert len(document.entities) == 1
assert document.entities[0] == entity2
else:
assert len(document.entities) == 2
assert document.entities[0] == entity1
assert document.entities[1] == entity2

assert len(document.entities.predictions) == 0