Skip to content

Commit 2e31678

Browse files
fehiepsineerajprad
authored andcommitted
add missing docs and remove some TODOs (#529)
1 parent e6b2027 commit 2e31678

File tree

8 files changed

+32
-5
lines changed

8 files changed

+32
-5
lines changed

docs/source/autoguide.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ AutoContinuous
1111
:show-inheritance:
1212
:member-order: bysource
1313

14+
AutoBNAFNormal
15+
--------------
16+
.. autoclass:: numpyro.contrib.autoguide.AutoBNAFNormal
17+
:members:
18+
:undoc-members:
19+
:show-inheritance:
20+
:member-order: bysource
21+
1422
AutoDiagonalNormal
1523
------------------
1624
.. autoclass:: numpyro.contrib.autoguide.AutoDiagonalNormal

docs/source/distributions.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,11 @@ InverseAutoregressiveTransform
563563
:undoc-members:
564564
:show-inheritance:
565565
:member-order: bysource
566+
567+
BlockNeuralAutoregressiveTransform
568+
----------------------------------
569+
.. autoclass:: numpyro.distributions.flows.BlockNeuralAutoregressiveTransform
570+
:members:
571+
:undoc-members:
572+
:show-inheritance:
573+
:member-order: bysource

docs/source/handlers.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ condition
1919
:show-inheritance:
2020
:member-order: bysource
2121

22+
mask
23+
----
24+
.. autoclass:: numpyro.handlers.mask
25+
:members:
26+
:undoc-members:
27+
:show-inheritance:
28+
:member-order: bysource
29+
2230
replay
2331
------
2432
.. autoclass:: numpyro.handlers.replay

docs/source/primitives.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ plate
1515
-----
1616
.. autoclass:: numpyro.primitives.plate
1717

18+
deterministic
19+
-------------
20+
.. autofunction:: numpyro.primitives.deterministic
21+
1822
factor
1923
------
2024
.. autofunction:: numpyro.primitives.factor

numpyro/distributions/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def categorical(key, p, shape=()):
4646
return _categorical(key, p, shape)
4747

4848

49-
# TODO: use this sampler in CategoricalLogits
5049
# TODO: drop this for the next JAX release, see https://github.com/google/jax/pull/1855
5150
def categorical_logits(key, logits, shape=()):
5251
shape = shape or logits.shape[:-1]

numpyro/infer/mcmc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,9 @@ class SA(MCMCKernel):
810810
subset of approximate posterior samples of size num_chains x num_samples
811811
instead of num_chains x num_samples x adapt_state_size.
812812
813+
.. note:: We recommend to use this kernel with `progress_bar=False` in :class:`MCMC`
814+
to reduce JAX's dispatch overhead.
815+
813816
**References:**
814817
815818
1. *Sample Adaptive MCMC* (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc),

numpyro/primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def deterministic(name, value):
151151
"""
152152
Used to designate deterministic sites in the model. Note that most effect
153153
handlers will not operate on deterministic sites (except
154-
:function:`~numpyro.handlers.trace`), so deterministic sites should be
154+
:func:`~numpyro.handlers.trace`), so deterministic sites should be
155155
side-effect free. The use case for deterministic nodes is to record any
156156
values in the model execution trace.
157157

test/test_hmc_util.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,6 @@ def fn(vv_state):
394394
assert tree.num_proposals > 10
395395

396396

397-
# TODO: raise this warning issue upstream, the issue is at this line
398-
# https://github.com/google/jax/blob/master/jax/numpy/lax_numpy.py#L2732
399-
@pytest.mark.filterwarnings('ignore:Explicitly requested dtype float64')
400397
@pytest.mark.parametrize('method', [consensus, parametric_draws])
401398
@pytest.mark.parametrize('diagonal', [True, False])
402399
def test_gaussian_subposterior(method, diagonal):

0 commit comments

Comments
 (0)