Skip to content

Commit f87f40e

Browse files
authored
Hm/all deterministic (#1914)
* Skip printing summary if empty. * Post-process when no sample sites present. Current post-processing behaviour skips models with only deterministic variables. Applying this change will return consistent samples regardless of whether `sample` sites are present.
1 parent 0e7bd20 commit f87f40e

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

numpyro/diagnostics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def summary(
257257

258258
summary_dict = {}
259259
for name, value in samples.items():
260+
if len(value) == 0:
261+
continue
260262
value = device_get(value)
261263
value_flat = np.reshape(value, (-1,) + value.shape[2:])
262264
mean = value_flat.mean(axis=0)
@@ -307,6 +309,8 @@ def print_summary(
307309
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
308310
}
309311
summary_dict = summary(samples, prob, group_by_chain=True)
312+
if not summary_dict:
313+
return
310314

311315
row_names = {
312316
k: k + "[" + ",".join(map(lambda x: str(x - 1), v.shape[2:])) + "]"

numpyro/infer/mcmc.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ def collect_and_postprocess(x):
195195
if collect_fields:
196196
fields = nested_attrgetter(*collect_fields)(x[0])
197197
fields = [fields] if len(collect_fields) == 1 else list(fields)
198-
site_values = jax.tree.flatten(fields[0])[0]
199-
if len(site_values) > 0:
200-
fields[0] = postprocess_fn(fields[0], *x[1:])
198+
fields[0] = postprocess_fn(fields[0], *x[1:])
201199

202200
if remove_sites != ():
203201
assert isinstance(fields[0], dict)
@@ -400,13 +398,27 @@ def _get_cached_fns(self):
400398
fns, key = None, None
401399
if fns is None:
402400

401+
def ensure_vmap(fn, batch_size=None):
402+
def wrapper(x):
403+
x_arrays = jax.tree.flatten(x)[0]
404+
if len(x_arrays) > 0:
405+
return vmap(fn)(x)
406+
else:
407+
assert batch_size is not None
408+
return jax.tree.map(
409+
lambda x: jnp.broadcast_to(x, (batch_size,) + jnp.shape(x)),
410+
fn(x),
411+
)
412+
413+
return wrapper
414+
403415
def _postprocess_fn(state, args, kwargs):
404416
if self.postprocess_fn is None:
405417
body_fn = self.sampler.postprocess_fn(args, kwargs)
406418
else:
407419
body_fn = self.postprocess_fn
408420
if self.chain_method == "vectorized" and self.num_chains > 1:
409-
body_fn = vmap(body_fn)
421+
body_fn = ensure_vmap(body_fn, batch_size=self.num_chains)
410422

411423
return body_fn(state)
412424

test/infer/test_mcmc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,3 +1208,29 @@ def model():
12081208
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
12091209
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
12101210
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))
1211+
1212+
1213+
def test_all_deterministic():
1214+
def model1():
1215+
numpyro.deterministic("x", 1.0)
1216+
1217+
def model2():
1218+
numpyro.deterministic("x", jnp.array([1.0, 2.0]))
1219+
1220+
num_samples = 10
1221+
shapes = {model1: (), model2: (2,)}
1222+
1223+
for model, shape in shapes.items():
1224+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
1225+
mcmc.run(random.PRNGKey(0))
1226+
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape
1227+
1228+
1229+
def test_empty_summary():
1230+
def model():
1231+
pass
1232+
1233+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
1234+
mcmc.run(random.PRNGKey(0))
1235+
1236+
mcmc.print_summary()

0 commit comments

Comments
 (0)