Skip to content

Commit 3e3f2aa

Browse files
committed
deduplicate annotations in pipeline
1 parent 32e6f9e commit 3e3f2aa

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/pytorch_ie/pipeline.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TaskModule,
2222
TaskOutput,
2323
)
24+
from pytorch_ie.utils.document import deduplicate_annotations
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -243,18 +244,22 @@ def postprocess(
243244
self,
244245
model_inputs: Sequence[TaskEncoding],
245246
model_outputs: Sequence[TaskOutput],
247+
deduplicate_annotations: bool = False,
246248
**postprocess_parameters,
247249
) -> Sequence[Document]:
248250
"""
249251
Postprocess will receive the model inputs and (unbatched) model outputs and reformat them into
250252
something more friendly. Generally it will output a list of documents.
251253
"""
252254
# This creates annotations from the model outputs and attaches them to the correct documents.
253-
return self.taskmodule.decode(
255+
result = self.taskmodule.decode(
254256
task_encodings=model_inputs,
255257
task_outputs=model_outputs,
256258
**postprocess_parameters,
257259
)
260+
if deduplicate_annotations:
261+
result = [document.deduplicate_annotations() for document in result]
262+
return result
258263

259264
def get_inference_context(self):
260265
inference_context = (
@@ -308,12 +313,6 @@ def __call__(
308313
postprocess_params,
309314
) = self._sanitize_parameters(**kwargs)
310315

311-
in_place: bool = postprocess_params.get("inplace", True)
312-
if in_place and not isinstance(documents, (MutableSequence, Document)):
313-
raise InplaceNotSupportedException(
314-
"Immutable sequences of Documents (such as Datasets) can't be modified in place. Please set inplace=False."
315-
)
316-
317316
if "TOKENIZERS_PARALLELISM" not in os.environ:
318317
logger.info(
319318
"Disabling tokenizer parallelism, we're using DataLoader multithreading already"
@@ -326,6 +325,16 @@ def __call__(
326325
forward_params = {**self._forward_params, **forward_params}
327326
postprocess_params = {**self._postprocess_params, **postprocess_params}
328327

328+
in_place: bool = postprocess_params.get("inplace", True)
329+
if in_place and not isinstance(documents, (MutableSequence, Document)):
330+
raise InplaceNotSupportedException(
331+
"Immutable sequences of Documents (such as Datasets) can't be modified in place. Please set inplace=False."
332+
)
333+
if postprocess_params.get("deduplicate_annotations", False) and in_place:
334+
raise ValueError(
335+
"Deduplicating annotations requires inplace=False. Please set inplace=False."
336+
)
337+
329338
single_document = False
330339
if isinstance(documents, Document):
331340
single_document = True

0 commit comments

Comments
 (0)