Skip to content

Commit

Permalink
integrate eval_counts_for_layer() into F1Metric.calculate_counts() (#318
Browse files Browse the repository at this point in the history
)

* integrate eval_counts_for_layer() into F1Metric.calculate_counts() so that derived classes can overwrite it

* remove annotation_mapper because it is never used
  • Loading branch information
ArneBinder authored Aug 25, 2023
1 parent e85e140 commit 4c26774
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions src/pytorch_ie/metrics/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,6 @@
logger = logging.getLogger(__name__)


def eval_counts_for_layer(
document: Document,
layer: str,
annotation_filter: Optional[Callable[[Annotation], bool]] = None,
annotation_mapper: Optional[Callable[[Annotation], Hashable]] = None,
) -> Tuple[int, int, int]:
annotation_filter = annotation_filter or (lambda ann: True)
annotation_mapper = annotation_mapper or (lambda ann: ann)
predicted_annotations = {
annotation_mapper(ann) for ann in document[layer].predictions if annotation_filter(ann)
}
gold_annotations = {
annotation_mapper(ann) for ann in document[layer] if annotation_filter(ann)
}
tp = len([ann for ann in predicted_annotations & gold_annotations])
fn = len([ann for ann in gold_annotations - predicted_annotations])
fp = len([ann for ann in predicted_annotations - gold_annotations])
return tp, fp, fn


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
return getattr(ann, label_field) in labels

Expand Down Expand Up @@ -74,6 +54,21 @@ def __init__(
def reset(self):
self.counts = defaultdict(lambda: (0, 0, 0))

def calculate_counts(
self,
document: Document,
annotation_filter: Optional[Callable[[Annotation], bool]] = None,
) -> Tuple[int, int, int]:
annotation_filter = annotation_filter or (lambda ann: True)
predicted_annotations = {
ann for ann in document[self.layer].predictions if annotation_filter(ann)
}
gold_annotations = {ann for ann in document[self.layer] if annotation_filter(ann)}
tp = len([ann for ann in predicted_annotations & gold_annotations])
fn = len([ann for ann in gold_annotations - predicted_annotations])
fp = len([ann for ann in predicted_annotations - gold_annotations])
return tp, fp, fn

def add_counts(self, counts: Tuple[int, int, int], label: str):
self.counts[label] = (
self.counts[label][0] + counts[0],
Expand All @@ -82,9 +77,8 @@ def add_counts(self, counts: Tuple[int, int, int], label: str):
)

def _update(self, document: Document):
new_counts = eval_counts_for_layer(
new_counts = self.calculate_counts(
document=document,
layer=self.layer,
annotation_filter=partial(
has_one_of_the_labels, label_field=self.label_field, labels=self.labels
)
Expand All @@ -93,9 +87,8 @@ def _update(self, document: Document):
)
self.add_counts(new_counts, label="MICRO")
for label in self.labels:
new_counts = eval_counts_for_layer(
new_counts = self.calculate_counts(
document=document,
layer=self.layer,
annotation_filter=partial(
has_this_label, label_field=self.label_field, label=label
),
Expand Down

0 comments on commit 4c26774

Please sign in to comment.