Skip to content

Commit

Permalink
infer labels also from predicted annotations (#425)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Jun 25, 2024
1 parent 8eaca4b commit a81b182
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/pytorch_ie/metrics/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def _update(self, document: Document):
)
self.add_counts(new_counts, label="MICRO")
if self.infer_labels:
for ann in document[self.layer]:
layer = document[self.layer]
# collect labels from gold data and predictions
for ann in list(layer) + list(layer.predictions):
label = getattr(ann, self.label_field)
if label not in self.labels:
self.labels.append(label)
Expand Down
10 changes: 8 additions & 2 deletions tests/metrics/test_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,17 @@ def test_f1_per_label_inferred(documents):
metric = F1Metric(layer="entities", labels="INFERRED")
metric(documents)
# tp, fp, fn for micro and per label
assert dict(metric.counts) == {"MICRO": (3, 2, 0), "animal": (2, 0, 0), "company": (1, 1, 0)}
assert dict(metric.counts) == {
"MICRO": (3, 2, 0),
"animal": (2, 0, 0),
"company": (1, 1, 0),
"cat": (0, 1, 0),
}
assert metric.compute() == {
"MACRO": {"f1": 0.8333333333333333, "p": 0.75, "r": 1.0},
"MACRO": {"f1": 0.5555555555555556, "p": 0.5, "r": 0.6666666666666666},
"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0},
"animal": {"f1": 1.0, "p": 1.0, "r": 1.0},
"cat": {"f1": 0.0, "p": 0.0, "r": 0.0},
"company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0},
}

Expand Down

0 comments on commit a81b182

Please sign in to comment.