Skip to content

Commit 59de90f

Browse files
committed
Fixing issue Samples are outside the support for DiscreteUniform distribution pyro-ppl#1834
1 parent f6eb6ce commit 59de90f

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

numpyro/infer/hmc_gibbs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,13 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
434434
and site["fn"].has_enumerate_support
435435
and not site["is_observed"]
436436
}
437+
self._support_enumerates = {
438+
name: site["fn"].enumerate_support(False)
439+
for name, site in self._prototype_trace.items()
440+
if site["type"] == "sample"
441+
and site["fn"].has_enumerate_support
442+
and not site["is_observed"]
443+
}
437444
self._gibbs_sites = [
438445
name
439446
for name, site in self._prototype_trace.items()

numpyro/infer/mixed_hmc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from jax import grad, jacfwd, lax, random
88
from jax.flatten_util import ravel_pytree
9+
import jax
910
import jax.numpy as jnp
1011

1112
from numpyro.infer.hmc import momentum_generator
@@ -301,6 +302,11 @@ def body_fn(i, vals):
301302
adapt_state=adapt_state,
302303
)
303304

305+
z_discrete = jax.tree.map(
306+
lambda idx, support: support[idx],
307+
z_discrete,
308+
self._support_enumerates,
309+
)
304310
z = {**z_discrete, **hmc_state.z}
305311
return MixedHMCState(z, hmc_state, rng_key, accept_prob)
306312

0 commit comments

Comments
 (0)