
Commit 823c230

i-aki-y, Borda, pre-commit-ci[bot], and SkafteNicki authored and committed
Add zero_division option to the precision, recall, f1, fbeta. (Lightning-AI#2198)
* Add support of zero_division parameter
* fix overlooked
* Fix type error
* Fix type error
* Fix missing comma
* Doc fix wrong math expression
* Fixed StatScores to have zero_division
* fix missing zero_division arg
* fix device mismatch
* use scikit-learn 1.4.0
* fix scikit-learn min ver
* fix for new sklearn version
* fix scikit-learn requirements
* fix incorrect requirements condition
* fix test code to pass in multiple sklearn versions
* changelog
* better docstring
* add jaccardindex
* fix tests
* skip for old sklearn versions

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka <[email protected]>
1 parent 72bb3fa commit 823c230
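In practice, the change means the affected metrics accept a `zero_division` keyword that decides the score returned when the metric's denominator is zero. A minimal usage sketch based on the diffs below (assuming a torchmetrics build that includes this commit; values in comments are expected results, not verified output):

```python
from torch import tensor
from torchmetrics.classification import BinaryPrecision

# No positive predictions and no positive targets: TP + FP == 0, so
# precision is undefined and falls back to `zero_division`.
preds = tensor([0, 0, 0, 0])
target = tensor([0, 0, 0, 0])

print(BinaryPrecision(zero_division=0)(preds, target))  # expected: tensor(0.)
print(BinaryPrecision(zero_division=1)(preds, target))  # expected: tensor(1.)
```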

File tree

17 files changed (+606, -158 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483))
 
 
+- Added `zero_division` argument to selected classification metrics ([#2198](https://github.com/Lightning-AI/torchmetrics/pull/2198))
+
+
 ### Changed
 
 - Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))

requirements/_tests.txt

Lines changed: 2 additions & 1 deletion
@@ -15,5 +15,6 @@ pyGithub ==2.3.0
 fire <=0.6.0
 
 cloudpickle >1.3, <=3.0.0
-scikit-learn >=1.1.1, <1.4.0
+scikit-learn >=1.1.1, <1.3.0; python_version < "3.9"
+scikit-learn >=1.4.0, <1.5.0; python_version >= "3.9"
 cachier ==3.0.0
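The split pin relies on PEP 508 environment markers, so pip installs exactly one of the two `scikit-learn` lines per interpreter. A quick way to see how such markers resolve (a sketch using the third-party `packaging` library, which is not part of this diff):

```python
from packaging.markers import Marker

# Each marker evaluates against the running interpreter; on Python 3.9+
# the first prints False and the second True, so only the newer pin applies.
print(Marker('python_version < "3.9"').evaluate())
print(Marker('python_version >= "3.9"').evaluate())
```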

src/torchmetrics/classification/f_beta.py

Lines changed: 76 additions & 14 deletions
Large diffs are not rendered by default.
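The f_beta.py diff is collapsed here, but per the commit title the F-score classes gain the same option. A hedged sketch, assuming `BinaryF1Score` forwards the keyword the same way as the classes shown below (expected values unverified):

```python
from torch import tensor
from torchmetrics.classification import BinaryF1Score

# No positives in preds or target: TP + FP + FN == 0, so F1 is undefined
# and `zero_division` decides the returned score.
preds = tensor([0, 0, 0])
target = tensor([0, 0, 0])

print(BinaryF1Score()(preds, target))                 # expected: tensor(0.)
print(BinaryF1Score(zero_division=1)(preds, target))  # expected: tensor(1.)
```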

src/torchmetrics/classification/jaccard.py

Lines changed: 17 additions & 3 deletions
@@ -65,6 +65,8 @@ class BinaryJaccardIndex(BinaryConfusionMatrix):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division:
+            Value to replace when there is a division by zero. Should be `0` or `1`.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Example (preds is int tensor):
@@ -97,15 +99,17 @@ def __init__(
         threshold: float = 0.5,
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
+        zero_division: float = 0,
         **kwargs: Any,
     ) -> None:
         super().__init__(
             threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs
         )
+        self.zero_division = zero_division
 
     def compute(self) -> Tensor:
         """Compute metric."""
-        return _jaccard_index_reduce(self.confmat, average="binary")
+        return _jaccard_index_reduce(self.confmat, average="binary", zero_division=self.zero_division)
 
     def plot(  # type: ignore[override]
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -187,6 +191,8 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division:
+            Value to replace when there is a division by zero. Should be `0` or `1`.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Example (pred is integer tensor):
@@ -224,6 +230,7 @@ def __init__(
         average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
+        zero_division: float = 0,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -233,10 +240,13 @@
         _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average)
         self.validate_args = validate_args
         self.average = average
+        self.zero_division = zero_division
 
     def compute(self) -> Tensor:
         """Compute metric."""
-        return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index)
+        return _jaccard_index_reduce(
+            self.confmat, average=self.average, ignore_index=self.ignore_index, zero_division=self.zero_division
+        )
 
     def plot(  # type: ignore[override]
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -319,6 +329,8 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division:
+            Value to replace when there is a division by zero. Should be `0` or `1`.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Example (preds is int tensor):
@@ -354,6 +366,7 @@ def __init__(
         average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
+        zero_division: float = 0,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -368,10 +381,11 @@
         _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average)
         self.validate_args = validate_args
         self.average = average
+        self.zero_division = zero_division
 
     def compute(self) -> Tensor:
         """Compute metric."""
-        return _jaccard_index_reduce(self.confmat, average=self.average)
+        return _jaccard_index_reduce(self.confmat, average=self.average, zero_division=self.zero_division)
 
     def plot(  # type: ignore[override]
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
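For the Jaccard index the undefined case is an empty union (all true negatives). A minimal sketch of the new argument on `BinaryJaccardIndex`, matching the diff above (expected values in comments, not verified output):

```python
from torch import tensor
from torchmetrics.classification import BinaryJaccardIndex

# All-negative preds and target: intersection and union are both empty,
# so IoU is 0/0 and `zero_division` decides the score.
preds = tensor([0, 0, 0])
target = tensor([0, 0, 0])

print(BinaryJaccardIndex(zero_division=0)(preds, target))  # expected: tensor(0.)
print(BinaryJaccardIndex(zero_division=1)(preds, target))  # expected: tensor(1.)
```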

src/torchmetrics/classification/precision_recall.py

Lines changed: 71 additions & 17 deletions
@@ -18,7 +18,9 @@
 
 from torchmetrics.classification.base import _ClassificationTaskWrapper
 from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
-from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce
+from torchmetrics.functional.classification.precision_recall import (
+    _precision_recall_reduce,
+)
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
@@ -42,7 +44,7 @@ class BinaryPrecision(BinaryStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
     respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
-    encountered a score of 0 is returned.
+    encountered a score of `zero_division` (0 or 1, default is 0) is returned.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -73,6 +75,7 @@ class BinaryPrecision(BinaryStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -112,7 +115,14 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
+            "precision",
+            tp,
+            fp,
+            tn,
+            fn,
+            average="binary",
+            multidim_average=self.multidim_average,
+            zero_division=self.zero_division,
         )
 
     def plot(
@@ -165,8 +175,8 @@ class MulticlassPrecision(MulticlassStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
     respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
-    encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -217,6 +227,7 @@ class MulticlassPrecision(MulticlassStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -269,7 +280,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+            "precision",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            top_k=self.top_k,
+            zero_division=self.zero_division,
         )
 
     def plot(
@@ -322,8 +341,8 @@ class MultilabelPrecision(MultilabelStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
     respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
-    encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -373,6 +392,7 @@ class MultilabelPrecision(MultilabelStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -423,7 +443,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+            "precision",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            multilabel=True,
+            zero_division=self.zero_division,
         )
 
     def plot(
@@ -476,7 +504,7 @@ class BinaryRecall(BinaryStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
     respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
-    encountered a score of 0 is returned.
+    encountered a score of `zero_division` (0 or 1, default is 0) is returned.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -507,6 +535,7 @@ class BinaryRecall(BinaryStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -546,7 +575,14 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
+            "recall",
+            tp,
+            fp,
+            tn,
+            fn,
+            average="binary",
+            multidim_average=self.multidim_average,
+            zero_division=self.zero_division,
         )
 
     def plot(
@@ -599,8 +635,8 @@ class MulticlassRecall(MulticlassStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
     respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
-    encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -650,6 +686,7 @@ class MulticlassRecall(MulticlassStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -702,7 +739,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+            "recall",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            top_k=self.top_k,
+            zero_division=self.zero_division,
         )
 
     def plot(
@@ -755,8 +800,8 @@ class MultilabelRecall(MultilabelStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
     respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
-    encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -805,6 +850,7 @@ class MultilabelRecall(MultilabelStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -855,7 +901,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+            "recall",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            multilabel=True,
+            zero_division=self.zero_division,
        )
 
     def plot(
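As the docstrings above note, with `average="macro"` an undefined per-class score is replaced by `zero_division` and then folded into the average. A sketch of that interplay: class 2 is predicted once but never occurs in `target`, so its recall is 0/0 (expected values in comments, unverified):

```python
from torch import tensor
from torchmetrics.classification import MulticlassRecall

preds = tensor([0, 1, 2, 1])
target = tensor([0, 1, 0, 1])  # class 2 has no true instances

# Per-class recall: class 0 -> 1/2, class 1 -> 2/2, class 2 -> 0/0.
r0 = MulticlassRecall(num_classes=3, average="macro", zero_division=0)
r1 = MulticlassRecall(num_classes=3, average="macro", zero_division=1)
print(r0(preds, target))  # expected: tensor(0.5000) -> (0.5 + 1 + 0) / 3
print(r1(preds, target))  # expected: tensor(0.8333) -> (0.5 + 1 + 1) / 3
```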

src/torchmetrics/classification/stat_scores.py

Lines changed: 13 additions & 3 deletions
@@ -169,13 +169,15 @@ def __init__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
+        zero_division = kwargs.pop("zero_division", 0)
         super(_AbstractStatScores, self).__init__(**kwargs)
         if validate_args:
-            _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
+            _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
         self.threshold = threshold
         self.multidim_average = multidim_average
         self.ignore_index = ignore_index
         self.validate_args = validate_args
+        self.zero_division = zero_division
 
         self._create_state(size=1, multidim_average=multidim_average)
 
@@ -313,15 +315,19 @@ def __init__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
+        zero_division = kwargs.pop("zero_division", 0)
         super(_AbstractStatScores, self).__init__(**kwargs)
         if validate_args:
-            _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
+            _multiclass_stat_scores_arg_validation(
+                num_classes, top_k, average, multidim_average, ignore_index, zero_division
+            )
         self.num_classes = num_classes
         self.top_k = top_k
         self.average = average
         self.multidim_average = multidim_average
         self.ignore_index = ignore_index
         self.validate_args = validate_args
+        self.zero_division = zero_division
 
         self._create_state(
             size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average
@@ -461,15 +467,19 @@ def __init__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
+        zero_division = kwargs.pop("zero_division", 0)
         super(_AbstractStatScores, self).__init__(**kwargs)
         if validate_args:
-            _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
+            _multilabel_stat_scores_arg_validation(
+                num_labels, threshold, average, multidim_average, ignore_index, zero_division
+            )
         self.num_labels = num_labels
         self.threshold = threshold
         self.average = average
         self.multidim_average = multidim_average
         self.ignore_index = ignore_index
         self.validate_args = validate_args
+        self.zero_division = zero_division
 
         self._create_state(size=num_labels, multidim_average=multidim_average)
 
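Because the base `__init__` methods pop `zero_division` out of `**kwargs` before forwarding the rest to `Metric.__init__`, every StatScores-derived class accepts the keyword without each subclass redeclaring it. A minimal sketch (attribute name per the diff above):

```python
from torchmetrics.classification import BinaryStatScores

# `zero_division` is consumed via kwargs.pop(...) and stored on the instance;
# subclasses such as BinaryPrecision read it in their compute() methods.
metric = BinaryStatScores(zero_division=1)
print(metric.zero_division)  # 1
```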
