Skip to content

Commit 7bcf70f

Browse files
authored
Document.add_all_annotations_from_other() returns added original annotations (#390)
1 parent ced53f6 commit 7bcf70f

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

src/pytorch_ie/core/document.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -714,8 +714,13 @@ def add_all_annotations_from_other(
714714
process_predictions: bool = True,
715715
strict: bool = True,
716716
verbose: bool = True,
717-
) -> None:
718-
"""Adds all annotations from another document to this document.
717+
) -> Dict[str, List[Annotation]]:
718+
"""Adds all annotations from another document to this document. It allows to blacklist annotations
719+
and also to override annotations. It returns the original annotations for which a new annotation was
720+
added to the current document.
721+
722+
The method is useful if e.g. a text-based document is converted to a token-based document and the
723+
annotations should be added to the token-based document.
719724
720725
Args:
721726
other: The document to add annotations from.
@@ -744,6 +749,11 @@ def add_all_annotations_from_other(
744749
verbose: Whether to print a warning if the other document contains annotations that reference
745750
annotations that are not present in the current document (see parameter removed_annotations).
746751
752+
Returns:
753+
A mapping from annotation field names to the set of annotations from the original document for which
754+
a new annotation was added to the current document. This can be useful to check if all original
755+
annotations were added (possibly to multiple target documents).
756+
747757
Example:
748758
```
749759
@dataclasses.dataclass(frozen=True)
@@ -787,6 +797,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
787797
```
788798
"""
789799
removed_annotations = defaultdict(set, removed_annotations or dict())
800+
added_annotations = defaultdict(list)
790801

791802
annotation_store: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
792803
named_annotation_fields = {field.name: field for field in self.annotation_fields()}
@@ -829,6 +840,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
829840
if ann._id != new_ann._id:
830841
annotation_store[field_name][ann._id] = new_ann
831842
self[field_name].append(new_ann)
843+
added_annotations[field_name].append(ann)
832844
else:
833845
if strict:
834846
raise ValueError(
@@ -853,6 +865,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
853865
if ann._id != new_ann._id:
854866
annotation_store[field_name][ann._id] = new_ann
855867
self[field_name].predictions.append(new_ann)
868+
added_annotations[field_name].append(ann)
856869
else:
857870
if strict:
858871
raise ValueError(
@@ -868,6 +881,8 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
868881
# The annotation was removed, so we need to make sure that it is not referenced by any other
869882
removed_annotations[field_name].add(ann._id)
870883

884+
return dict(added_annotations)
885+
871886

872887
def resolve_annotation(
873888
id_or_annotation: Union[int, Annotation],

tests/test_document.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,20 @@ class TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(TextDocument):
644644

645645
def test_document_extend_from_other_full_copy(text_document):
646646
doc_new = type(text_document)(text=text_document.text)
647-
doc_new.add_all_annotations_from_other(text_document)
647+
added_annotations = doc_new.add_all_annotations_from_other(text_document)
648648

649649
assert text_document.asdict() == doc_new.asdict()
650+
assert set(added_annotations) == {
651+
"entities1",
652+
"entities2",
653+
"relations",
654+
"relation_attributes",
655+
"labels",
656+
}
657+
for layer_name, annotation_set in added_annotations.items():
658+
assert len(annotation_set) > 0
659+
available_annotations = text_document[layer_name]
660+
assert annotation_set == list(available_annotations)
650661

651662

652663
def test_document_extend_from_other_wrong_override_annotation_mapping(text_document):
@@ -683,9 +694,15 @@ class TestDocument2(TokenBasedDocument):
683694
token_document.entities1.append(e1_new)
684695
token_document.entities2.append(e2_new)
685696
# ... and the remaining annotations
686-
token_document.add_all_annotations_from_other(
697+
added_annotations = token_document.add_all_annotations_from_other(
687698
text_document, override_annotations=annotation_mapping
688699
)
700+
# check that the added annotations are as expected (the entity annotations are already there)
701+
assert added_annotations == {
702+
"relations": list(text_document.relations),
703+
"relation_attributes": list(text_document.relation_attributes),
704+
"labels": list(text_document.labels),
705+
}
689706

690707
assert (
691708
len(token_document.entities1)
@@ -710,12 +727,18 @@ class TestDocument2(TokenBasedDocument):
710727

711728
def test_document_extend_from_other_remove(text_document):
712729
doc_new = type(text_document)(text=text_document.text)
713-
doc_new.add_all_annotations_from_other(
730+
added_annotations = doc_new.add_all_annotations_from_other(
714731
text_document,
715732
removed_annotations={"entities1": {text_document.entities1[0]._id}},
716733
strict=False,
717734
)
718735

736+
# the only entity in entities1 is removed and since the relation has it as head, the relation is removed as well
737+
assert added_annotations == {
738+
"entities2": list(text_document.entities2),
739+
"labels": list(text_document.labels),
740+
}
741+
719742
assert len(doc_new.entities1) == 0
720743
assert len(doc_new.entities2) == 1
721744
assert len(doc_new.relations) == 0

0 commit comments

Comments
 (0)