 
 from torchmetrics.classification.base import _ClassificationTaskWrapper
 from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
-from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce
+from torchmetrics.functional.classification.precision_recall import (
+    _precision_recall_reduce,
+)
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
@@ -42,7 +44,7 @@ class BinaryPrecision(BinaryStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
     respectively. The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
-    encountered a score of 0 is returned.
+    encountered, a score of ``zero_division`` (0 or 1, default is 0) is returned.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -73,6 +75,7 @@ class BinaryPrecision(BinaryStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be ``0`` or ``1``. The value returned when :math:`\text{TP} + \text{FP} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -112,7 +115,14 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
+            "precision",
+            tp,
+            fp,
+            tn,
+            fn,
+            average="binary",
+            multidim_average=self.multidim_average,
+            zero_division=self.zero_division,
         )
 
     def plot(
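
A minimal usage sketch of the new ``zero_division`` argument on ``BinaryPrecision``, assuming this branch is installed; the tensors and expected values below are illustrative, not taken from the test suite:

```python
import torch
from torchmetrics.classification import BinaryPrecision

preds = torch.tensor([0, 0, 0, 0])   # no positive predictions -> TP + FP == 0
target = torch.tensor([0, 1, 0, 0])

print(BinaryPrecision()(preds, target))                 # undefined case, defaults to 0.0
print(BinaryPrecision(zero_division=1)(preds, target))  # same case now reported as 1.0
```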
@@ -165,8 +175,8 @@ class MulticlassPrecision(MulticlassStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
     respectively. The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
-    encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any class, the metric for that class will be set to ``zero_division`` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -217,6 +227,7 @@ class MulticlassPrecision(MulticlassStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be ``0`` or ``1``. The value returned when :math:`\text{TP} + \text{FP} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -269,7 +280,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+            "precision",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            top_k=self.top_k,
+            zero_division=self.zero_division,
         )
 
     def plot(
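
For the multiclass case, a hedged sketch of how a never-predicted class picks up the fallback and how that propagates into the macro average (values are illustrative only):

```python
import torch
from torchmetrics.classification import MulticlassPrecision

preds = torch.tensor([0, 1, 1, 0])    # class 2 is never predicted -> TP + FP == 0 for it
target = torch.tensor([0, 1, 2, 0])

per_class = MulticlassPrecision(num_classes=3, average=None, zero_division=1)
print(per_class(preds, target))       # class 2 reported as zero_division instead of 0

macro = MulticlassPrecision(num_classes=3, average="macro", zero_division=1)
print(macro(preds, target))           # the macro average shifts with the chosen fallback
```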
@@ -322,8 +341,8 @@ class MultilabelPrecision(MultilabelStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
     respectively. The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
-    encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any label, the metric for that label will be set to ``zero_division`` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -373,6 +392,7 @@ class MultilabelPrecision(MultilabelStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be ``0`` or ``1``. The value returned when :math:`\text{TP} + \text{FP} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -423,7 +443,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+            "precision",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            multilabel=True,
+            zero_division=self.zero_division,
        )
 
     def plot(
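
And the multilabel variant, where a label column with no positive predictions hits the undefined case; again a sketch under the assumption that this patch is applied:

```python
import torch
from torchmetrics.classification import MultilabelPrecision

preds = torch.tensor([[1, 0, 0], [0, 1, 0]])    # last label is never predicted positive
target = torch.tensor([[1, 0, 1], [0, 1, 0]])

metric = MultilabelPrecision(num_labels=3, average=None, zero_division=1)
print(metric(preds, target))    # the last label falls back to zero_division
```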
@@ -476,7 +504,7 @@ class BinaryRecall(BinaryStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
     respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
-    encountered a score of 0 is returned.
+    encountered, a score of ``zero_division`` (0 or 1, default is 0) is returned.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -507,6 +535,7 @@ class BinaryRecall(BinaryStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be ``0`` or ``1``. The value returned when :math:`\text{TP} + \text{FN} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -546,7 +575,14 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
+            "recall",
+            tp,
+            fp,
+            tn,
+            fn,
+            average="binary",
+            multidim_average=self.multidim_average,
+            zero_division=self.zero_division,
         )
 
     def plot(
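
For recall the undefined case is the mirror image: no positive targets rather than no positive predictions. A small illustrative sketch (outputs assumed, not verified against the test suite):

```python
import torch
from torchmetrics.classification import BinaryRecall

preds = torch.tensor([0, 1, 0, 0])
target = torch.tensor([0, 0, 0, 0])   # no positive targets -> TP + FN == 0

print(BinaryRecall()(preds, target))                 # undefined case, defaults to 0.0
print(BinaryRecall(zero_division=1)(preds, target))  # same case now reported as 1.0
```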
@@ -599,8 +635,8 @@ class MulticlassRecall(MulticlassStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
     respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
-    encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any class, the metric for that class will be set to ``zero_division`` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -650,6 +686,7 @@ class MulticlassRecall(MulticlassStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be ``0`` or ``1``. The value returned when :math:`\text{TP} + \text{FN} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -702,7 +739,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+            "recall",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            top_k=self.top_k,
+            zero_division=self.zero_division,
         )
 
     def plot(
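
Analogous multiclass sketch: a class that never occurs in the target has ``TP + FN == 0`` and is reported as the fallback (illustrative values, assuming this patch):

```python
import torch
from torchmetrics.classification import MulticlassRecall

preds = torch.tensor([0, 1, 2, 0])
target = torch.tensor([0, 1, 1, 0])   # class 2 never occurs in the target -> TP + FN == 0

metric = MulticlassRecall(num_classes=3, average=None, zero_division=1)
print(metric(preds, target))          # class 2 reported as zero_division instead of 0
```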
@@ -755,8 +800,8 @@ class MultilabelRecall(MultilabelStatScores):
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
     respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
-    encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
-    affected in turn.
+    encountered for any label, the metric for that label will be set to ``zero_division`` (0 or 1, default is 0) and
+    the overall metric may therefore be affected in turn.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -805,6 +850,7 @@ class MultilabelRecall(MultilabelStatScores):
             Specifies a target value that is ignored and does not contribute to the metric calculation
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
+        zero_division: Should be ``0`` or ``1``. The value returned when :math:`\text{TP} + \text{FN} = 0`.
 
     Example (preds is int tensor):
         >>> from torch import tensor
@@ -855,7 +901,15 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+            "recall",
+            tp,
+            fp,
+            tn,
+            fn,
+            average=self.average,
+            multidim_average=self.multidim_average,
+            multilabel=True,
+            zero_division=self.zero_division,
         )
 
     def plot(
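
Finally, the multilabel recall case, where a label column with no positive targets takes the fallback; same assumptions as the sketches above:

```python
import torch
from torchmetrics.classification import MultilabelRecall

preds = torch.tensor([[1, 0, 1], [0, 1, 0]])
target = torch.tensor([[1, 0, 0], [0, 1, 0]])   # last label has no positive targets

metric = MultilabelRecall(num_labels=3, average=None, zero_division=1)
print(metric(preds, target))    # the last label falls back to zero_division
```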