Skip to content

Commit 5692d2d

Browse files
authored
BF: fix ensemble mcmc run after warmup (#1918)
1 parent 07e4c9b commit 5692d2d

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

numpyro/infer/mcmc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,11 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
657657

658658
if self._warmup_state is not None:
659659
self._set_collection_params(0, self.num_samples, self.num_samples, "sample")
660-
init_state = self._warmup_state._replace(rng_key=rng_key)
660+
661+
if self.sampler.is_ensemble_kernel:
662+
init_state = self._warmup_state._replace(rng_key=rng_key[0])
663+
else:
664+
init_state = self._warmup_state._replace(rng_key=rng_key)
661665

662666
if init_params is not None and self.num_chains > 1:
663667
prototype_init_val = jax.tree.flatten(init_params)[0][0]

test/infer/test_ensemble_mcmc.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,20 @@ def test_multirun(kernel_cls):
9696
)
9797
mcmc.run(random.PRNGKey(2), labels)
9898
mcmc.run(random.PRNGKey(3), labels)
99+
100+
101+
@pytest.mark.parametrize("kernel_cls", [AIES, ESS])
102+
def test_warmup(kernel_cls):
103+
n_chains = 10
104+
kernel = kernel_cls(model)
105+
106+
mcmc = MCMC(
107+
kernel,
108+
num_warmup=10,
109+
num_samples=10,
110+
progress_bar=False,
111+
num_chains=n_chains,
112+
chain_method="vectorized",
113+
)
114+
mcmc.warmup(random.PRNGKey(2), labels)
115+
mcmc.run(random.PRNGKey(3), labels)

0 commit comments

Comments
 (0)