Skip to content

Commit 1a5c16a

Browse files
fehiepsineerajprad
authored andcommitted
add diverging state (#330)
1 parent 2ddb42c commit 1a5c16a

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

numpyro/mcmc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity
2929

3030
HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
31-
'mean_accept_prob', 'adapt_state', 'rng'])
31+
'mean_accept_prob', 'diverging', 'adapt_state', 'rng'])
3232
"""
3333
A :func:`~collections.namedtuple` consisting of the following fields:
3434
@@ -42,6 +42,7 @@
4242
does not correspond to the proposal if it is rejected.
4343
- **mean_accept_prob** - Mean acceptance probability until current iteration
4444
during warmup adaptation or sampling (for diagnostics).
45+
- **diverging** - A boolean value to indicate whether the current trajectory is diverging.
4546
- **adapt_state** - A ``AdaptState`` namedtuple which contains adaptation information
4647
during warmup:
4748
@@ -163,6 +164,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
163164
momentum_generator = None
164165
wa_update = None
165166
wa_steps = None
167+
max_delta_energy = 1000.
166168
if algo not in {'HMC', 'NUTS'}:
167169
raise ValueError('`algo` must be one of `HMC` or `NUTS`.')
168170

@@ -235,7 +237,7 @@ def init_kernel(init_params,
235237
r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
236238
vv_state = vv_init(z, r)
237239
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
238-
wa_state, rng_hmc)
240+
False, wa_state, rng_hmc)
239241

240242
# TODO: Remove; this should be the responsibility of the MCMC class.
241243
if run_warmup and num_warmup > 0:
@@ -259,23 +261,25 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
259261
delta_energy = energy_new - energy_old
260262
delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
261263
accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
264+
diverging = delta_energy > max_delta_energy
262265
transition = random.bernoulli(rng, accept_prob)
263266
vv_state = cond(transition,
264267
vv_state_new, lambda state: state,
265268
vv_state, lambda state: state)
266-
return vv_state, num_steps, accept_prob
269+
return vv_state, num_steps, accept_prob, diverging
267270

268271
def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
269272
binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
270273
inverse_mass_matrix, step_size, rng,
274+
max_delta_energy=max_delta_energy,
271275
max_tree_depth=max_treedepth)
272276
accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
273277
num_steps = binary_tree.num_proposals
274278
vv_state = IntegratorState(z=binary_tree.z_proposal,
275279
r=vv_state.r,
276280
potential_energy=binary_tree.z_proposal_pe,
277281
z_grad=binary_tree.z_proposal_grad)
278-
return vv_state, num_steps, accept_prob
282+
return vv_state, num_steps, accept_prob, binary_tree.diverging
279283

280284
_next = _nuts_next if algo == 'NUTS' else _hmc_next
281285

@@ -292,9 +296,9 @@ def sample_kernel(hmc_state):
292296
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
293297
r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_momentum)
294298
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
295-
vv_state, num_steps, accept_prob = _next(hmc_state.adapt_state.step_size,
296-
hmc_state.adapt_state.inverse_mass_matrix,
297-
vv_state, rng_transition)
299+
vv_state, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
300+
hmc_state.adapt_state.inverse_mass_matrix,
301+
vv_state, rng_transition)
298302
# not update adapt_state after warmup phase
299303
adapt_state = cond(hmc_state.i < wa_steps,
300304
(hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
@@ -307,7 +311,7 @@ def sample_kernel(hmc_state):
307311
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
308312

309313
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
310-
accept_prob, mean_accept_prob, adapt_state, rng)
314+
accept_prob, mean_accept_prob, diverging, adapt_state, rng)
311315

312316
# Make `init_kernel` and `sample_kernel` visible from the global scope once
313317
# `hmc` is called for sphinx doc generation.

test/test_mcmc_interface.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,26 @@ def model(data):
243243
assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
244244

245245

246+
@pytest.mark.parametrize('kernel_cls', [HMC, NUTS])
247+
@pytest.mark.parametrize('adapt_step_size', [True, False])
248+
def test_diverging(kernel_cls, adapt_step_size):
249+
data = random.normal(random.PRNGKey(0), (1000,))
250+
251+
def model(data):
252+
loc = numpyro.sample('loc', dist.Normal(0., 1.))
253+
numpyro.sample('obs', dist.Normal(loc, 1), obs=data)
254+
255+
kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size, adapt_mass_matrix=False)
256+
num_warmup = num_samples = 1000
257+
mcmc = MCMC(kernel, num_warmup, num_samples)
258+
mcmc.run(random.PRNGKey(1), data, collect_fields=('z', 'diverging'), collect_warmup=True)
259+
num_divergences = mcmc.get_samples()[1].sum()
260+
if adapt_step_size:
261+
assert num_divergences <= num_warmup
262+
else:
263+
assert_allclose(num_divergences, num_warmup + num_samples)
264+
265+
246266
@pytest.mark.parametrize('use_init_params', [False, True])
247267
@pytest.mark.parametrize('chain_method', ['parallel', 'sequential', 'vectorized'])
248268
@pytest.mark.filterwarnings("ignore:There are not enough devices:UserWarning")

0 commit comments

Comments
 (0)