Skip to content

Commit 3dcbb4c

Browse files
authored
implement target getters for BaseAnnotationList (#297)
* implement target_layers and target_fields properties for BaseAnnotationList * rename target_fields to targets; remove hasattr check; add target and target_layer * add tests for targets, target, target_layers, and target_layer
1 parent 92c47ca commit 3dcbb4c

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

src/pytorch_ie/core/document.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,39 @@ def pop(self, index: int = -1) -> T:
351351
ann.set_targets(None)
352352
return ann
353353

354+
@property
355+
def targets(self) -> dict[str, Any]:
356+
return {
357+
target_field_name: getattr(self._document, target_field_name)
358+
for target_field_name in self._targets
359+
}
360+
361+
@property
362+
def target(self) -> Any:
363+
tgts = self.targets
364+
if len(tgts) != 1:
365+
raise ValueError(
366+
f"The annotation layer has more or less than one target: {self._targets}"
367+
)
368+
return list(tgts.values())[0]
369+
370+
@property
371+
def target_layers(self) -> dict[str, "AnnotationList"]:
372+
return {
373+
target_name: target
374+
for target_name, target in self.targets.items()
375+
if isinstance(target, AnnotationList)
376+
}
377+
378+
@property
379+
def target_layer(self) -> "AnnotationList":
380+
tgt_layers = self.target_layers
381+
if len(tgt_layers) != 1:
382+
raise ValueError(
383+
f"The annotation layer has more or less than one target layer: {list(tgt_layers.keys())}"
384+
)
385+
return list(tgt_layers.values())[0]
386+
354387

355388
class AnnotationList(BaseAnnotationList[T]):
356389
def __init__(self, document: "Document", targets: List["str"]):

tests/test_document.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,56 @@ class TestDocument(Document):
504504
),
505505
):
506506
doc = TestDocument(text="text1")
507+
508+
509+
def test_annotation_list_targets():
510+
@dataclasses.dataclass
511+
class TestDocument(Document):
512+
text: str
513+
entities1: AnnotationList[LabeledSpan] = annotation_field(target="text")
514+
entities2: AnnotationList[LabeledSpan] = annotation_field(target="text")
515+
relations1: AnnotationList[BinaryRelation] = annotation_field(target="entities1")
516+
relations2: AnnotationList[BinaryRelation] = annotation_field(
517+
targets=["entities1", "entities2"]
518+
)
519+
520+
doc = TestDocument(text="text1")
521+
522+
# test getting all targets
523+
assert doc.entities1.targets == {"text": doc.text}
524+
assert doc.entities2.targets == {"text": doc.text}
525+
assert doc.relations1.targets == {"entities1": doc.entities1}
526+
assert doc.relations2.targets == {"entities1": doc.entities1, "entities2": doc.entities2}
527+
528+
# test getting a single target
529+
assert doc.entities1.target == doc.text
530+
assert doc.entities2.target == doc.text
531+
assert doc.relations1.target == doc.entities1
532+
# check that the target of relations2 is not set because it has more than one target
533+
with pytest.raises(ValueError) as excinfo:
534+
doc.relations2.target
535+
assert (
536+
str(excinfo.value)
537+
== "The annotation layer has more or less than one target: ['entities1', 'entities2']"
538+
)
539+
540+
# test getting all target layers
541+
assert doc.entities1.target_layers == {}
542+
assert doc.entities2.target_layers == {}
543+
assert doc.relations1.target_layers == {"entities1": doc.entities1}
544+
assert doc.relations2.target_layers == {"entities1": doc.entities1, "entities2": doc.entities2}
545+
546+
# test getting a single target layer
547+
with pytest.raises(ValueError) as excinfo:
548+
doc.entities1.target_layer
549+
assert str(excinfo.value) == "The annotation layer has more or less than one target layer: []"
550+
with pytest.raises(ValueError) as excinfo:
551+
doc.entities2.target_layer
552+
assert str(excinfo.value) == "The annotation layer has more or less than one target layer: []"
553+
assert doc.relations1.target_layer == doc.entities1
554+
with pytest.raises(ValueError) as excinfo:
555+
doc.relations2.target_layer
556+
assert (
557+
str(excinfo.value)
558+
== "The annotation layer has more or less than one target layer: ['entities1', 'entities2']"
559+
)

0 commit comments

Comments
 (0)