Skip to content

Commit 57424f2

Browse files
fehiepsineerajprad
authored andcommitted
add diverging to extra_fields by default (#433)
1 parent a2e5990 commit 57424f2

File tree

5 files changed

+15
-16
lines changed

5 files changed

+15
-16
lines changed

README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Let us infer the values of the unknown parameters in our model by running MCMC u
5555
We can print the summary of the MCMC run, and examine if we observed any divergences during inference:
5656

5757
```python
58-
mcmc.print_summary()
58+
>>> mcmc.print_summary()
5959

6060
mean std median 5.0% 95.0% n_eff r_hat
6161
mu 3.94 2.81 3.16 0.03 9.28 114.51 1.06
@@ -69,8 +69,6 @@ mcmc.print_summary()
6969
theta[6] 5.74 4.67 4.34 -1.92 13.25 58.42 1.05
7070
theta[7] 4.29 4.63 3.23 -2.14 12.37 342.50 1.02
7171

72-
>>> print("Number of divergences: {}".format(sum(mcmc.get_extra_fields()['diverging'])))
73-
7472
Number of divergences: 139
7573
```
7674

@@ -104,8 +102,6 @@ The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates tha
104102
theta[5] 3.92 4.43 4.06 -2.41 11.09 1179.74 1.00
105103
theta[6] 5.88 4.84 5.34 -1.45 13.11 881.38 1.00
106104
theta[7] 4.63 4.86 4.64 -3.57 11.80 1065.27 1.00
107-
108-
>>> print("Number of divergences: {}".format(sum(mcmc.get_extra_fields()['diverging'])))
109105

110106
Number of divergences: 0
111107
```

docs/source/diagnostics.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ HPDI
2929

3030
Summary
3131
-------
32-
.. autofunction:: numpyro.diagnostics.summary
32+
.. autofunction:: numpyro.diagnostics.print_summary

examples/neutra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import numpyro
1515
from numpyro import optim
1616
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoIAFNormal
17-
from numpyro.diagnostics import summary
17+
from numpyro.diagnostics import print_summary
1818
import numpyro.distributions as dist
1919
from numpyro.distributions import constraints
2020
from numpyro.infer import MCMC, NUTS, SVI
@@ -90,7 +90,7 @@ def main(args):
9090
zs = mcmc.get_samples()
9191
print("Transform samples into unwarped space...")
9292
samples = vmap(transformed_constrain_fn)(zs)
93-
summary(tree_map(lambda x: x[None, ...], samples))
93+
print_summary(tree_map(lambda x: x[None, ...], samples))
9494
samples = samples['x'].copy()
9595

9696
# make plots

numpyro/diagnostics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
'gelman_rubin',
1717
'hpdi',
1818
'split_gelman_rubin',
19-
'summary',
19+
'print_summary',
2020
]
2121

2222

@@ -212,7 +212,7 @@ def hpdi(x, prob=0.90, axis=0):
212212
return onp.concatenate([hpd_left, hpd_right], axis=axis)
213213

214214

215-
def summary(samples, prob=0.90, group_by_chain=True):
215+
def print_summary(samples, prob=0.90, group_by_chain=True):
216216
"""
217217
Prints a summary table displaying diagnostics of ``samples`` from the
218218
posterior. The diagnostics displayed are mean, standard deviation, median,
@@ -241,7 +241,7 @@ def summary(samples, prob=0.90, group_by_chain=True):
241241
header_format = name_format + ' {:>9} {:>9} {:>9} {:>9} {:>9} {:>9} {:>9}'
242242
columns = ['', 'mean', 'std', 'median', '{:.1f}%'.format(50 * (1 - prob)),
243243
'{:.1f}%'.format(50 * (1 + prob)), 'n_eff', 'r_hat']
244-
print('\n')
244+
print()
245245
print(header_format.format(*columns))
246246

247247
# XXX: consider to expose digits, depending on user requests
@@ -264,4 +264,4 @@ def summary(samples, prob=0.90, group_by_chain=True):
264264
idx_str = '[{}]'.format(','.join(map(str, idx)))
265265
print(row_format.format(name + idx_str, mean[idx], sd[idx], median[idx],
266266
hpd[0][idx], hpd[1][idx], n_eff[idx], r_hat[idx]))
267-
print('\n')
267+
print()

numpyro/infer/mcmc.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from jax.random import PRNGKey
1414
from jax.tree_util import tree_flatten, tree_map, tree_multimap
1515

16-
from numpyro.diagnostics import summary
16+
from numpyro.diagnostics import print_summary
1717
from numpyro.infer.hmc_util import (
1818
IntegratorState,
1919
build_tree,
@@ -598,10 +598,10 @@ def _single_chain_mcmc(self, init, collect_fields=('z',), collect_warmup=False,
598598
if len(collect_fields) == 1:
599599
states = (states,)
600600
states = dict(zip(collect_fields, states))
601-
states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states)[0]) > 0 else states['z']
601+
states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states['z'])[0]) > 0 else states['z']
602602
return states
603603

604-
def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
604+
def run(self, rng_key, *args, extra_fields=('diverging',), collect_warmup=False, init_params=None, **kwargs):
605605
"""
606606
Run the MCMC samplers and collect samples.
607607
@@ -693,4 +693,7 @@ def get_extra_fields(self, group_by_chain=False):
693693
return {k: v for k, v in states.items() if k != 'z'}
694694

695695
def print_summary(self, prob=0.9):
696-
summary(self._states['z'], prob=prob)
696+
print_summary(self._states['z'], prob=prob)
697+
extra_fields = self.get_extra_fields()
698+
if 'diverging' in extra_fields:
699+
print("Number of divergences: {}".format(np.sum(extra_fields['diverging'])))

0 commit comments

Comments
 (0)