Skip to content

Commit

Permalink
Document.add_all_annotations_from_other() returns added original annotations (#390)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Dec 11, 2023
1 parent ced53f6 commit 7bcf70f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
19 changes: 17 additions & 2 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,13 @@ def add_all_annotations_from_other(
process_predictions: bool = True,
strict: bool = True,
verbose: bool = True,
) -> None:
"""Adds all annotations from another document to this document.
) -> Dict[str, List[Annotation]]:
"""Adds all annotations from another document to this document. It allows to blacklist annotations
and also to override annotations. It returns the original annotations for which a new annotation was
added to the current document.
The method is useful if e.g. a text-based document is converted to a token-based document and the
annotations should be added to the token-based document.
Args:
other: The document to add annotations from.
Expand Down Expand Up @@ -744,6 +749,11 @@ def add_all_annotations_from_other(
verbose: Whether to print a warning if the other document contains annotations that reference
annotations that are not present in the current document (see parameter removed_annotations).
Returns:
A mapping from annotation field names to the list of annotations from the original document for which
a new annotation was added to the current document. This can be useful to check if all original
annotations were added (possibly to multiple target documents).
Example:
```
@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -787,6 +797,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
```
"""
removed_annotations = defaultdict(set, removed_annotations or dict())
added_annotations = defaultdict(list)

annotation_store: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
named_annotation_fields = {field.name: field for field in self.annotation_fields()}
Expand Down Expand Up @@ -829,6 +840,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
if ann._id != new_ann._id:
annotation_store[field_name][ann._id] = new_ann
self[field_name].append(new_ann)
added_annotations[field_name].append(ann)
else:
if strict:
raise ValueError(
Expand All @@ -853,6 +865,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
if ann._id != new_ann._id:
annotation_store[field_name][ann._id] = new_ann
self[field_name].predictions.append(new_ann)
added_annotations[field_name].append(ann)
else:
if strict:
raise ValueError(
Expand All @@ -868,6 +881,8 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
# The annotation was removed, so we need to make sure that it is not referenced by any other
removed_annotations[field_name].add(ann._id)

return dict(added_annotations)


def resolve_annotation(
id_or_annotation: Union[int, Annotation],
Expand Down
29 changes: 26 additions & 3 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,20 @@ class TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(TextDocument):

def test_document_extend_from_other_full_copy(text_document):
doc_new = type(text_document)(text=text_document.text)
doc_new.add_all_annotations_from_other(text_document)
added_annotations = doc_new.add_all_annotations_from_other(text_document)

assert text_document.asdict() == doc_new.asdict()
assert set(added_annotations) == {
"entities1",
"entities2",
"relations",
"relation_attributes",
"labels",
}
for layer_name, annotation_set in added_annotations.items():
assert len(annotation_set) > 0
available_annotations = text_document[layer_name]
assert annotation_set == list(available_annotations)


def test_document_extend_from_other_wrong_override_annotation_mapping(text_document):
Expand Down Expand Up @@ -683,9 +694,15 @@ class TestDocument2(TokenBasedDocument):
token_document.entities1.append(e1_new)
token_document.entities2.append(e2_new)
# ... and the remaining annotations
token_document.add_all_annotations_from_other(
added_annotations = token_document.add_all_annotations_from_other(
text_document, override_annotations=annotation_mapping
)
# check that the added annotations are as expected (the entity annotations are already there)
assert added_annotations == {
"relations": list(text_document.relations),
"relation_attributes": list(text_document.relation_attributes),
"labels": list(text_document.labels),
}

assert (
len(token_document.entities1)
Expand All @@ -710,12 +727,18 @@ class TestDocument2(TokenBasedDocument):

def test_document_extend_from_other_remove(text_document):
doc_new = type(text_document)(text=text_document.text)
doc_new.add_all_annotations_from_other(
added_annotations = doc_new.add_all_annotations_from_other(
text_document,
removed_annotations={"entities1": {text_document.entities1[0]._id}},
strict=False,
)

# the only entity in entities1 is removed and since the relation has it as head, the relation is removed as well
assert added_annotations == {
"entities2": list(text_document.entities2),
"labels": list(text_document.labels),
}

assert len(doc_new.entities1) == 0
assert len(doc_new.entities2) == 1
assert len(doc_new.relations) == 0
Expand Down

0 comments on commit 7bcf70f

Please sign in to comment.