diff --git a/src/pytorch_ie/metrics/f1.py b/src/pytorch_ie/metrics/f1.py index 651c3dff..2efa02a4 100644 --- a/src/pytorch_ie/metrics/f1.py +++ b/src/pytorch_ie/metrics/f1.py @@ -112,7 +112,9 @@ def _update(self, document: Document): ) self.add_counts(new_counts, label="MICRO") if self.infer_labels: - for ann in document[self.layer]: + layer = document[self.layer] + # collect labels from gold data and predictions + for ann in list(layer) + list(layer.predictions): label = getattr(ann, self.label_field) if label not in self.labels: self.labels.append(label) diff --git a/tests/metrics/test_f1.py b/tests/metrics/test_f1.py index bd1ae579..8a31ed2e 100644 --- a/tests/metrics/test_f1.py +++ b/tests/metrics/test_f1.py @@ -78,11 +78,17 @@ def test_f1_per_label_inferred(documents): metric = F1Metric(layer="entities", labels="INFERRED") metric(documents) # tp, fp, fn for micro and per label - assert dict(metric.counts) == {"MICRO": (3, 2, 0), "animal": (2, 0, 0), "company": (1, 1, 0)} + assert dict(metric.counts) == { + "MICRO": (3, 2, 0), + "animal": (2, 0, 0), + "company": (1, 1, 0), + "cat": (0, 1, 0), + } assert metric.compute() == { - "MACRO": {"f1": 0.8333333333333333, "p": 0.75, "r": 1.0}, + "MACRO": {"f1": 0.5555555555555556, "p": 0.5, "r": 0.6666666666666666}, "MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0}, "animal": {"f1": 1.0, "p": 1.0, "r": 1.0}, + "cat": {"f1": 0.0, "p": 0.0, "r": 0.0}, "company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0}, }