Skip to content

Commit

Permalink
fixing docs
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Oct 8, 2023
1 parent e77af92 commit 4390ef3
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions diffusion_models/losses/kl_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def gaussian_kl(
"""Calculate KL Divergence of 2 Gaussian distributions.
KL divergence between two univariate Gaussians, as derived in [1], with k=1 (dimensionality).
.. math::
D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]
Parameters
----------
Expand All @@ -30,11 +32,7 @@ def gaussian_kl(
References
----------
.. math::
D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]
.. [1] https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/
"""
return 0.5 * (torch.log(torch.abs(q_var) / torch.abs(p_var)) - 1.0 + ((p_mean-q_mean)**2)/q_var + p_var/q_var)

Expand All @@ -47,6 +45,8 @@ def log_gaussian_kl(
"""Calculate KL Divergence of 2 Gaussian distributions.
KL divergence between two univariate Gaussians, as derived in [1], with k=1 (dimensionality) and log variances.
.. math::
D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]
Parameters
----------
Expand All @@ -66,10 +66,6 @@ def log_gaussian_kl(
References
----------
.. math::
D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]
.. [1] https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/
"""
return 0.5 * (q_logvar - p_logvar - 1.0 + torch.exp(p_logvar - q_logvar) + ((p_mean - q_mean)**2)*torch.exp(-q_logvar))

0 comments on commit 4390ef3

Please sign in to comment.