src/torchmetrics/functional/regression/kl_divergence.py (+18 -0)
@@ -20,6 +20,7 @@

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_xlogy
from torchmetrics.utilities.prints import rank_zero_warn


def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]:
@@ -91,6 +92,14 @@ def kl_divergence(
over data and :math:`Q` is often a prior or approximation of :math:`P`. It should be noted that the KL divergence
is a non-symmetrical metric i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`.

.. warning::
The input order and naming in metric ``kl_divergence`` is set to be deprecated in v1.4 and changed in v1.5.
Input argument ``p`` will be renamed to ``target`` and will be moved to be the second argument of the metric.
Input argument ``q`` will be renamed to ``preds`` and will be moved to the first argument of the metric.
Thus, a current call ``kl_divergence(p, q)`` will correspond to ``kl_divergence(preds=q, target=p)`` in the future to be consistent
with the rest of ``torchmetrics``. From v1.4 the two new arguments will be added as keyword arguments and
from v1.5 the two old arguments will be removed.

Args:
p: data distribution with shape ``[N, d]``
q: prior or approximate distribution with shape ``[N, d]``
@@ -111,5 +120,14 @@
tensor(0.0853)

"""
rank_zero_warn(
"The input order and naming in metric `kl_divergence` is set to be deprecated in v1.4 and changed in v1.5."
" Input argument `p` will be renamed to `target` and will be moved to be the second argument of the metric."
" Input argument `q` will be renamed to `preds` and will be moved to the first argument of the metric."
" Thus, a current call `kl_divergence(p, q)` will correspond to `kl_divergence(preds=q, target=p)` in the future"
" to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword"
" arguments and from v1.5 the two old arguments will be removed.",
DeprecationWarning,
)
measures, total = _kld_update(p, q, log_prob)
return _kld_compute(measures, total, reduction)
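Since the notice hinges on the asymmetry of the KL divergence and on the argument swap, here is a minimal sketch of the current functional call next to its expected future keyword form; the tensor values are made up for illustration:

```python
import torch
from torchmetrics.functional.regression import kl_divergence

# Hypothetical distributions over 3 classes (each row sums to 1).
p = torch.tensor([[0.5, 0.3, 0.2]])  # data distribution, currently the first argument
q = torch.tensor([[0.3, 0.3, 0.4]])  # approximate distribution, currently the second argument

# KL divergence is not symmetric: D_KL(P||Q) != D_KL(Q||P).
print(kl_divergence(p, q))  # emits the DeprecationWarning added in this PR
print(kl_divergence(q, p))  # generally a different value

# Per the deprecation notice, the equivalent call after the change is expected to be:
# kl_divergence(preds=q, target=p)
```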
src/torchmetrics/regression/kl_divergence.py (+18 -0)
@@ -22,6 +22,7 @@
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.utilities.prints import rank_zero_warn

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["KLDivergence.plot"]
@@ -46,6 +47,14 @@ class KLDivergence(Metric):

- ``kl_divergence`` (:class:`~torch.Tensor`): A tensor with the KL divergence

.. warning::
The input order and naming in metric ``KLDivergence`` is set to be deprecated in v1.4 and changed in v1.5.
Input argument ``p`` will be renamed to ``target`` and will be moved to be the second argument of the metric.
Input argument ``q`` will be renamed to ``preds`` and will be moved to the first argument of the metric.
Thus, a current call ``KLDivergence(p, q)`` will correspond to ``KLDivergence(preds=q, target=p)`` in the future to be consistent
with the rest of ``torchmetrics``. From v1.4 the two new arguments will be added as keyword arguments and
from v1.5 the two old arguments will be removed.

Args:
log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities,
will normalize to make sure the distributions sum to 1.
@@ -92,6 +101,15 @@ def __init__(
reduction: Literal["mean", "sum", "none", None] = "mean",
**kwargs: Any,
) -> None:
rank_zero_warn(
"The input order and naming in metric `KLDivergence` is set to be deprecated in v1.4 and changed in v1.5."
" Input argument `p` will be renamed to `target` and will be moved to be the second argument of the metric."
" Input argument `q` will be renamed to `preds` and will be moved to the first argument of the metric."
" Thus, a current call `KLDivergence(p, q)` will correspond to `KLDivergence(preds=q, target=p)` in the future"
" to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword"
" arguments and from v1.5 the two old arguments will be removed.",
DeprecationWarning,
)
super().__init__(**kwargs)
if not isinstance(log_prob, bool):
raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}")
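For the module-based API, the warning fires at construction time rather than at the update call. A minimal sketch, with illustrative values:

```python
import torch
from torchmetrics.regression import KLDivergence

kl = KLDivergence()  # the DeprecationWarning added in this PR fires here

p = torch.tensor([[0.5, 0.3, 0.2]])  # data distribution
q = torch.tensor([[0.3, 0.3, 0.4]])  # approximate distribution

kl.update(p, q)      # current order; the expected future order is (preds, target) = (q, p)
print(kl.compute())  # tensor with the accumulated KL divergence
```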
tests/unittests/test_deprecated.py (+16 -0)
@@ -0,0 +1,16 @@
import pytest
import torch
from torchmetrics.functional.regression import kl_divergence
from torchmetrics.regression import KLDivergence


def test_deprecated_kl_divergence_input_order():
"""Ensure that the deprecated input order for kl_divergence raises a warning."""
preds = torch.rand(10, 2)  # non-negative entries, so the rows form valid (unnormalized) distributions
target = torch.rand(10, 2)

with pytest.deprecated_call(match="The input order and naming in metric `kl_divergence` is set to be deprecated.*"):
kl_divergence(preds, target)

with pytest.deprecated_call(match="The input order and naming in metric `KLDivergence` is set to be deprecated.*"):
KLDivergence()
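Downstream code that cannot migrate immediately may want to silence this warning at its own call sites until the new keyword arguments land in v1.4. A hedged sketch using the standard library's warnings machinery (values again illustrative):

```python
import warnings

import torch
from torchmetrics.functional.regression import kl_divergence

p = torch.tensor([[0.5, 0.3, 0.2]])
q = torch.tensor([[0.3, 0.3, 0.4]])

# Suppress only DeprecationWarning, and only around this call, while migrating.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    val = kl_divergence(p, q)
print(val)
```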