From 4c2677495631d80084af14bae2608119121c2792 Mon Sep 17 00:00:00 2001
From: ArneBinder
Date: Fri, 25 Aug 2023 17:38:09 +0200
Subject: [PATCH] integrate eval_counts_for_layer() into F1Metric.calculate_counts() (#318)

* integrate eval_counts_for_layer() into F1Metric.calculate_counts() so that
  derived classes can overwrite it
* remove annotation_mapper because it is never used
---
 src/pytorch_ie/metrics/f1.py | 41 +++++++++++++++---------------------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/src/pytorch_ie/metrics/f1.py b/src/pytorch_ie/metrics/f1.py
index 08561493..8b1b00b9 100644
--- a/src/pytorch_ie/metrics/f1.py
+++ b/src/pytorch_ie/metrics/f1.py
@@ -10,26 +10,6 @@
 logger = logging.getLogger(__name__)
 
 
-def eval_counts_for_layer(
-    document: Document,
-    layer: str,
-    annotation_filter: Optional[Callable[[Annotation], bool]] = None,
-    annotation_mapper: Optional[Callable[[Annotation], Hashable]] = None,
-) -> Tuple[int, int, int]:
-    annotation_filter = annotation_filter or (lambda ann: True)
-    annotation_mapper = annotation_mapper or (lambda ann: ann)
-    predicted_annotations = {
-        annotation_mapper(ann) for ann in document[layer].predictions if annotation_filter(ann)
-    }
-    gold_annotations = {
-        annotation_mapper(ann) for ann in document[layer] if annotation_filter(ann)
-    }
-    tp = len([ann for ann in predicted_annotations & gold_annotations])
-    fn = len([ann for ann in gold_annotations - predicted_annotations])
-    fp = len([ann for ann in predicted_annotations - gold_annotations])
-    return tp, fp, fn
-
-
 def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
     return getattr(ann, label_field) in labels
 
@@ -74,6 +54,21 @@ def __init__(
     def reset(self):
         self.counts = defaultdict(lambda: (0, 0, 0))
 
+    def calculate_counts(
+        self,
+        document: Document,
+        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
+    ) -> Tuple[int, int, int]:
+        annotation_filter = annotation_filter or (lambda ann: True)
+        predicted_annotations = {
+            ann for ann in document[self.layer].predictions if annotation_filter(ann)
+        }
+        gold_annotations = {ann for ann in document[self.layer] if annotation_filter(ann)}
+        tp = len([ann for ann in predicted_annotations & gold_annotations])
+        fn = len([ann for ann in gold_annotations - predicted_annotations])
+        fp = len([ann for ann in predicted_annotations - gold_annotations])
+        return tp, fp, fn
+
     def add_counts(self, counts: Tuple[int, int, int], label: str):
         self.counts[label] = (
             self.counts[label][0] + counts[0],
@@ -82,9 +77,8 @@ def add_counts(self, counts: Tuple[int, int, int], label: str):
         )
 
     def _update(self, document: Document):
-        new_counts = eval_counts_for_layer(
+        new_counts = self.calculate_counts(
             document=document,
-            layer=self.layer,
             annotation_filter=partial(
                 has_one_of_the_labels, label_field=self.label_field, labels=self.labels
             )
@@ -93,9 +87,8 @@ def _update(self, document: Document):
         )
         self.add_counts(new_counts, label="MICRO")
         for label in self.labels:
-            new_counts = eval_counts_for_layer(
+            new_counts = self.calculate_counts(
                 document=document,
-                layer=self.layer,
                 annotation_filter=partial(
                     has_this_label, label_field=self.label_field, label=label
                 ),
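
The point of the change is that derived classes can override the counting logic. Below is a minimal sketch of such an override, not part of the patch: the class name SpanBoundaryF1Metric is hypothetical, the import paths are assumed from the package layout shown above, and it presumes the target layer holds span-like annotations with start/end offsets.

from typing import Callable, Optional, Tuple

from pytorch_ie.core import Annotation, Document  # assumed import path
from pytorch_ie.metrics.f1 import F1Metric


class SpanBoundaryF1Metric(F1Metric):
    # Hypothetical subclass: a prediction counts as a true positive if its
    # (start, end) offsets match a gold span, regardless of equality of the
    # full annotation (e.g. ignoring label differences).
    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
    ) -> Tuple[int, int, int]:
        annotation_filter = annotation_filter or (lambda ann: True)
        # Compare spans by their offsets instead of full annotation equality.
        # Assumes the annotations in self.layer expose .start and .end.
        predicted = {
            (ann.start, ann.end)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold = {
            (ann.start, ann.end)
            for ann in document[self.layer]
            if annotation_filter(ann)
        }
        tp = len(predicted & gold)
        fp = len(predicted - gold)
        fn = len(gold - predicted)
        return tp, fp, fn

Because _update() now routes both the MICRO and the per-label computation through calculate_counts(), such a subclass inherits the existing bookkeeping (add_counts, reset) unchanged.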