Skip to content

Commit 7c3ec50

Browse files
kylejcaronkylejcaron
andauthored
Predictive fix when deterministic sites are present (#1789)
* added custom effect handler for predictive * added test, fixed predictive_substitute * fixed typo, removed unneeded custom substitute calls * removed custom effect handler, improved readability * reverted formatting of imports * added conditional arg for handling deterministic sites to predictive * changed arg name to exclude_deterministic * updated exclude_deterministic description * changed substitute to condition in infer_discrete _predctive workflow --------- Co-authored-by: kylejcaron <[email protected]>
1 parent 0fd8c2e commit 7c3ec50

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

numpyro/infer/util.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ def _predictive(
761761
return_sites=None,
762762
infer_discrete=False,
763763
parallel=True,
764+
exclude_deterministic: bool = True,
764765
model_args=(),
765766
model_kwargs={},
766767
):
@@ -774,7 +775,7 @@ def _predictive(
774775
posterior_samples,
775776
)
776777
prototype_trace = trace(
777-
seed(substitute(masked_model, prototype_sample), subkey)
778+
seed(condition(masked_model, prototype_sample), subkey)
778779
).get_trace(*model_args, **model_kwargs)
779780
first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1
780781

@@ -795,9 +796,20 @@ def single_prediction(val):
795796
**model_kwargs,
796797
)
797798
else:
798-
model_trace = trace(
799-
seed(substitute(masked_model, samples), rng_key)
800-
).get_trace(*model_args, **model_kwargs)
799+
800+
def _samples_wo_deterministic(msg):
801+
return (
802+
samples.get(msg["name"]) if msg["type"] != "deterministic" else None
803+
)
804+
805+
substituted_model = (
806+
substitute(masked_model, substitute_fn=_samples_wo_deterministic)
807+
if exclude_deterministic
808+
else substitute(masked_model, samples)
809+
)
810+
model_trace = trace(seed(substituted_model, rng_key)).get_trace(
811+
*model_args, **model_kwargs
812+
)
801813
pred_samples = {name: site["value"] for name, site in model_trace.items()}
802814

803815
if return_sites is not None:
@@ -870,6 +882,7 @@ class Predictive(object):
870882
871883
+ set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters
872884
with shapes `(num_samples x batch_size x ...)`
885+
:param exclude_deterministic: indicates whether to ignore deterministic sites from the posterior samples.
873886
874887
:return: dict of samples from the predictive distribution.
875888
@@ -907,6 +920,7 @@ def __init__(
907920
infer_discrete: bool = False,
908921
parallel: bool = False,
909922
batch_ndims: Optional[int] = None,
923+
exclude_deterministic: bool = True,
910924
):
911925
if posterior_samples is None and num_samples is None:
912926
raise ValueError(
@@ -967,6 +981,7 @@ def __init__(
967981
self.parallel = parallel
968982
self.batch_ndims = batch_ndims
969983
self._batch_shape = batch_shape
984+
self.exclude_deterministic = exclude_deterministic
970985

971986
def _call_with_params(self, rng_key, params, args, kwargs):
972987
posterior_samples = self.posterior_samples
@@ -983,6 +998,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
983998
parallel=self.parallel,
984999
model_args=args,
9851000
model_kwargs=kwargs,
1001+
exclude_deterministic=self.exclude_deterministic,
9861002
)
9871003
model = substitute(self.model, self.params)
9881004
return _predictive(
@@ -995,6 +1011,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
9951011
parallel=self.parallel,
9961012
model_args=args,
9971013
model_kwargs=kwargs,
1014+
exclude_deterministic=self.exclude_deterministic,
9981015
)
9991016

10001017
def __call__(self, rng_key, *args, **kwargs):

test/infer/test_infer_util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,22 @@ def model(data=None):
5353
return model, data, true_probs
5454

5555

56+
def linear_regression():
57+
N = 800
58+
X = dist.Normal(0, 1).sample(random.PRNGKey(0), (N,))
59+
y = 1.5 + X * 0.7
60+
61+
def model(X, y=None):
62+
alpha = numpyro.sample("alpha", dist.Normal(0.0, 5))
63+
beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
64+
sigma = numpyro.sample("sigma", dist.Exponential(1.0))
65+
with numpyro.plate("plate", len(X)):
66+
mu = numpyro.deterministic("mu", alpha + X * beta)
67+
numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
68+
69+
return model, X, y
70+
71+
5672
@pytest.mark.parametrize("parallel", [True, False])
5773
def test_predictive(parallel):
5874
model, data, true_probs = beta_bernoulli()
@@ -74,6 +90,29 @@ def test_predictive(parallel):
7490
assert_allclose(obs.mean(0), true_probs, rtol=0.1)
7591

7692

93+
@pytest.mark.parametrize("parallel", [True, False])
94+
def test_predictive_with_deterministic(parallel):
95+
"""Tests that the default behavior when predicting from models with
96+
deterministic sites doesn't lead to static deterministic sites in the predictive.
97+
"""
98+
n_preds = 400
99+
model, X, y = linear_regression()
100+
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
101+
mcmc.run(random.PRNGKey(0), X=X, y=y)
102+
samples = mcmc.get_samples()
103+
predictive = Predictive(model, samples, parallel=parallel)
104+
# change the input (X) shape to make sure the deterministic shape changes
105+
predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds])
106+
assert predictive_samples.keys() == {"mu", "obs"}
107+
108+
predictive.return_sites = ["beta", "mu", "obs"]
109+
# change the input (X) shape to make sure the deterministic shape changes
110+
predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds])
111+
# check shapes
112+
assert predictive_samples["mu"].shape == (100,) + X[:n_preds].shape
113+
assert predictive_samples["obs"].shape == (100,) + X[:n_preds].shape
114+
115+
77116
def test_predictive_with_guide():
78117
data = jnp.array([1] * 8 + [0] * 2)
79118

0 commit comments

Comments
 (0)