|
9 | 9 | from pytorch_ie import Dataset, IterableDataset |
10 | 10 | from pytorch_ie.annotations import BinaryRelation |
11 | 11 | from pytorch_ie.core import AnnotationList, Document |
12 | | -from pytorch_ie.core.document import BaseAnnotationList |
13 | 12 | from pytorch_ie.utils.span import is_contained_in |
14 | 13 |
|
15 | 14 | from pie_utils.document.processors.common import EnterDatasetMixin, ExitDatasetMixin |
|
20 | 19 | D = TypeVar("D", bound=Document) |
21 | 20 |
|
22 | 21 |
|
23 | | -def target_layers(layer: BaseAnnotationList) -> dict[str, AnnotationList]: |
24 | | - return { |
25 | | - target_layer_name: layer._document[target_layer_name] |
26 | | - for target_layer_name in layer._targets |
27 | | - if target_layer_name in layer._document |
28 | | - } |
29 | | - |
30 | | - |
31 | 22 | class CandidateRelationAdder(EnterDatasetMixin, ExitDatasetMixin): |
32 | 23 | """CandidateRelationAdder adds binary relations to a document based on various parameters. It |
33 | 24 | goes through combinations of available entity pairs as possible candidates for new relations. |
@@ -149,13 +140,8 @@ def __call__(self, document: D) -> D: |
149 | 140 | available_partitions = document[self.partition_layer] |
150 | 141 | else: |
151 | 142 | available_partitions = [None] |
152 | | - rel_target_layers = target_layers(layer=rel_layer) |
153 | | - if not len(rel_target_layers) == 1: |
154 | | - raise ValueError( |
155 | | - f"Relation layer must have exactly one target layer but found the following target layers: " |
156 | | - f"{list(rel_target_layers)}" |
157 | | - ) |
158 | | - entity_layer = list(rel_target_layers.values())[0] |
| 143 | + |
| 144 | + entity_layer = rel_layer.target_layer |
159 | 145 | if self.use_predictions: |
160 | 146 | entity_layer = entity_layer.predictions |
161 | 147 |
|
|
0 commit comments