|
13 | 13 | from jax.random import PRNGKey |
14 | 14 | from jax.tree_util import tree_flatten, tree_map, tree_multimap |
15 | 15 |
|
16 | | -from numpyro.diagnostics import summary |
| 16 | +from numpyro.diagnostics import print_summary |
17 | 17 | from numpyro.infer.hmc_util import ( |
18 | 18 | IntegratorState, |
19 | 19 | build_tree, |
@@ -598,10 +598,10 @@ def _single_chain_mcmc(self, init, collect_fields=('z',), collect_warmup=False, |
598 | 598 | if len(collect_fields) == 1: |
599 | 599 | states = (states,) |
600 | 600 | states = dict(zip(collect_fields, states)) |
601 | | - states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states)[0]) > 0 else states['z'] |
| 601 | + states['z'] = vmap(constrain_fn)(states['z']) if len(tree_flatten(states['z'])[0]) > 0 else states['z'] |
602 | 602 | return states |
603 | 603 |
|
604 | | - def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs): |
| 604 | + def run(self, rng_key, *args, extra_fields=('diverging',), collect_warmup=False, init_params=None, **kwargs): |
605 | 605 | """ |
606 | 606 | Run the MCMC samplers and collect samples. |
607 | 607 |
|
@@ -693,4 +693,7 @@ def get_extra_fields(self, group_by_chain=False): |
693 | 693 | return {k: v for k, v in states.items() if k != 'z'} |
694 | 694 |
|
695 | 695 | def print_summary(self, prob=0.9): |
696 | | - summary(self._states['z'], prob=prob) |
| 696 | + print_summary(self._states['z'], prob=prob) |
| 697 | + extra_fields = self.get_extra_fields() |
| 698 | + if 'diverging' in extra_fields: |
| 699 | + print("Number of divergences: {}".format(np.sum(extra_fields['diverging']))) |
0 commit comments