Skip to content

Commit

Permalink
add annotation_processor to F1Metric (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Jan 24, 2024
1 parent 8d63ad7 commit 280beb8
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/pytorch_ie/metrics/f1.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from collections import defaultdict
from functools import partial
from typing import Callable, Collection, Dict, Optional, Tuple, Union
from typing import Callable, Collection, Dict, Hashable, Optional, Tuple, Union

import pandas as pd

from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,11 +36,15 @@ def __init__(
labels: Optional[Union[Collection[str], str]] = None,
label_field: str = "label",
show_as_markdown: bool = False,
annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
):
super().__init__()
self.layer = layer
self.label_field = label_field
self.show_as_markdown = show_as_markdown
if isinstance(annotation_processor, str):
annotation_processor = resolve_target(annotation_processor)
self.annotation_processor = annotation_processor

self.per_label = labels is not None
self.infer_labels = False
Expand Down Expand Up @@ -71,12 +76,18 @@ def calculate_counts(
self,
document: Document,
annotation_filter: Optional[Callable[[Annotation], bool]] = None,
annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
) -> Tuple[int, int, int]:
annotation_processor = annotation_processor or (lambda ann: ann)
annotation_filter = annotation_filter or (lambda ann: True)
predicted_annotations = {
ann for ann in document[self.layer].predictions if annotation_filter(ann)
annotation_processor(ann)
for ann in document[self.layer].predictions
if annotation_filter(ann)
}
gold_annotations = {
annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
}
gold_annotations = {ann for ann in document[self.layer] if annotation_filter(ann)}
tp = len([ann for ann in predicted_annotations & gold_annotations])
fn = len([ann for ann in gold_annotations - predicted_annotations])
fp = len([ann for ann in predicted_annotations - gold_annotations])
Expand All @@ -97,6 +108,7 @@ def _update(self, document: Document):
)
if self.per_label and not self.infer_labels
else None,
annotation_processor=self.annotation_processor,
)
self.add_counts(new_counts, label="MICRO")
if self.infer_labels:
Expand All @@ -111,6 +123,7 @@ def _update(self, document: Document):
annotation_filter=partial(
has_this_label, label_field=self.label_field, label=label
),
annotation_processor=self.annotation_processor,
)
self.add_counts(new_counts, label=label)

Expand Down

0 comments on commit 280beb8

Please sign in to comment.