2828from numpyro .util import cond , copy_docs_from , fori_collect , fori_loop , identity
2929
3030HMCState = namedtuple ('HMCState' , ['i' , 'z' , 'z_grad' , 'potential_energy' , 'num_steps' , 'accept_prob' ,
31- 'mean_accept_prob' , 'adapt_state' , 'rng' ])
31+ 'mean_accept_prob' , 'diverging' , ' adapt_state' , 'rng' ])
3232"""
3333A :func:`~collections.namedtuple` consisting of the following fields:
3434
4242 does not correspond to the proposal if it is rejected.
4343 - **mean_accept_prob** - Mean acceptance probability until current iteration
4444 during warmup adaptation or sampling (for diagnostics).
45+ - **diverging** - A boolean value to indicate whether the current trajectory is diverging.
4546 - **adapt_state** - A ``AdaptState`` namedtuple which contains adaptation information
4647 during warmup:
4748
@@ -163,6 +164,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
163164 momentum_generator = None
164165 wa_update = None
165166 wa_steps = None
167+ max_delta_energy = 1000.
166168 if algo not in {'HMC' , 'NUTS' }:
167169 raise ValueError ('`algo` must be one of `HMC` or `NUTS`.' )
168170
@@ -235,7 +237,7 @@ def init_kernel(init_params,
235237 r = momentum_generator (wa_state .mass_matrix_sqrt , rng )
236238 vv_state = vv_init (z , r )
237239 hmc_state = HMCState (0 , vv_state .z , vv_state .z_grad , vv_state .potential_energy , 0 , 0. , 0. ,
238- wa_state , rng_hmc )
240+ False , wa_state , rng_hmc )
239241
240242 # TODO: Remove; this should be the responsibility of the MCMC class.
241243 if run_warmup and num_warmup > 0 :
@@ -259,23 +261,25 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
259261 delta_energy = energy_new - energy_old
260262 delta_energy = np .where (np .isnan (delta_energy ), np .inf , delta_energy )
261263 accept_prob = np .clip (np .exp (- delta_energy ), a_max = 1.0 )
264+ diverging = delta_energy > max_delta_energy
262265 transition = random .bernoulli (rng , accept_prob )
263266 vv_state = cond (transition ,
264267 vv_state_new , lambda state : state ,
265268 vv_state , lambda state : state )
266- return vv_state , num_steps , accept_prob
269+ return vv_state , num_steps , accept_prob , diverging
267270
268271 def _nuts_next (step_size , inverse_mass_matrix , vv_state , rng ):
269272 binary_tree = build_tree (vv_update , kinetic_fn , vv_state ,
270273 inverse_mass_matrix , step_size , rng ,
274+ max_delta_energy = max_delta_energy ,
271275 max_tree_depth = max_treedepth )
272276 accept_prob = binary_tree .sum_accept_probs / binary_tree .num_proposals
273277 num_steps = binary_tree .num_proposals
274278 vv_state = IntegratorState (z = binary_tree .z_proposal ,
275279 r = vv_state .r ,
276280 potential_energy = binary_tree .z_proposal_pe ,
277281 z_grad = binary_tree .z_proposal_grad )
278- return vv_state , num_steps , accept_prob
282+ return vv_state , num_steps , accept_prob , binary_tree . diverging
279283
280284 _next = _nuts_next if algo == 'NUTS' else _hmc_next
281285
@@ -292,9 +296,9 @@ def sample_kernel(hmc_state):
292296 rng , rng_momentum , rng_transition = random .split (hmc_state .rng , 3 )
293297 r = momentum_generator (hmc_state .adapt_state .mass_matrix_sqrt , rng_momentum )
294298 vv_state = IntegratorState (hmc_state .z , r , hmc_state .potential_energy , hmc_state .z_grad )
295- vv_state , num_steps , accept_prob = _next (hmc_state .adapt_state .step_size ,
296- hmc_state .adapt_state .inverse_mass_matrix ,
297- vv_state , rng_transition )
299+ vv_state , num_steps , accept_prob , diverging = _next (hmc_state .adapt_state .step_size ,
300+ hmc_state .adapt_state .inverse_mass_matrix ,
301+ vv_state , rng_transition )
298302 # not update adapt_state after warmup phase
299303 adapt_state = cond (hmc_state .i < wa_steps ,
300304 (hmc_state .i , accept_prob , vv_state .z , hmc_state .adapt_state ),
@@ -307,7 +311,7 @@ def sample_kernel(hmc_state):
307311 mean_accept_prob = hmc_state .mean_accept_prob + (accept_prob - hmc_state .mean_accept_prob ) / n
308312
309313 return HMCState (itr , vv_state .z , vv_state .z_grad , vv_state .potential_energy , num_steps ,
310- accept_prob , mean_accept_prob , adapt_state , rng )
314+ accept_prob , mean_accept_prob , diverging , adapt_state , rng )
311315
312316 # Make `init_kernel` and `sample_kernel` visible from the global scope once
313317 # `hmc` is called for sphinx doc generation.
0 commit comments