Skip to content

Commit 5eb134d

Browse files
Add Wishart distribution. (#1779)
* Add Wishart distribution. * Reduce dimensionality for bijection tests of positive definite matrices. * Add `WishartCholesky` distribution and use it as base for `Wishart`. * Promote instead of broadcast Wishart parameters. * Assert exactly one of parameters is specified and update shape inference. * Implement `infer_shapes` for `Wishart` and `WishartCholesky`. * Add entropy for Wishart distribution. * Add sampled entropy test for distribution without `scipy` equivalent. * Simplify `logabsdet` evaluation of `scale_tril`. * Remove default `None` argument for concentration of Wishart distribution. * Add `tri_logabsdet` utility function.
1 parent f6d86c6 commit 5eb134d

File tree

9 files changed

+466
-51
lines changed

9 files changed

+466
-51
lines changed

docs/source/distributions.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,22 @@ Weibull
380380
:show-inheritance:
381381
:member-order: bysource
382382

383+
Wishart
384+
^^^^^^^
385+
.. autoclass:: numpyro.distributions.continuous.Wishart
386+
:members:
387+
:undoc-members:
388+
:show-inheritance:
389+
:member-order: bysource
390+
391+
WishartCholesky
392+
^^^^^^^^^^^^^^^
393+
.. autoclass:: numpyro.distributions.continuous.WishartCholesky
394+
:members:
395+
:undoc-members:
396+
:show-inheritance:
397+
:member-order: bysource
398+
383399
ZeroSumNormal
384400
^^^^^^^^^^^^^
385401
.. autoclass:: numpyro.distributions.continuous.ZeroSumNormal

numpyro/distributions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
StudentT,
4848
Uniform,
4949
Weibull,
50+
Wishart,
51+
WishartCholesky,
5052
ZeroSumNormal,
5153
)
5254
from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
@@ -194,6 +196,8 @@
194196
"Unit",
195197
"VonMises",
196198
"Weibull",
199+
"Wishart",
200+
"WishartCholesky",
197201
"ZeroInflatedDistribution",
198202
"ZeroInflatedPoisson",
199203
"ZeroInflatedNegativeBinomial2",

0 commit comments

Comments
 (0)