Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predictive fix when deterministic sites are present #1789

Merged
merged 9 commits into from
May 2, 2024
25 changes: 21 additions & 4 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def _predictive(
return_sites=None,
infer_discrete=False,
parallel=True,
exclude_deterministic: bool = True,
model_args=(),
model_kwargs={},
):
Expand All @@ -774,7 +775,7 @@ def _predictive(
posterior_samples,
)
prototype_trace = trace(
seed(substitute(masked_model, prototype_sample), subkey)
seed(condition(masked_model, prototype_sample), subkey)
).get_trace(*model_args, **model_kwargs)
first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

Expand All @@ -795,9 +796,20 @@ def single_prediction(val):
**model_kwargs,
)
else:
model_trace = trace(
seed(substitute(masked_model, samples), rng_key)
).get_trace(*model_args, **model_kwargs)

def _samples_wo_deterministic(msg):
return (
samples.get(msg["name"]) if msg["type"] != "deterministic" else None
)

substituted_model = (
substitute(masked_model, substitute_fn=_samples_wo_deterministic)
if exclude_deterministic
else substitute(masked_model, samples)
)
model_trace = trace(seed(substituted_model, rng_key)).get_trace(
*model_args, **model_kwargs
)
pred_samples = {name: site["value"] for name, site in model_trace.items()}

if return_sites is not None:
Expand Down Expand Up @@ -870,6 +882,7 @@ class Predictive(object):

+ set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters
with shapes `(num_samples x batch_size x ...)`
:param exclude_deterministic: if True (the default), deterministic sites found in the posterior samples are ignored instead of being substituted into the model.

:return: dict of samples from the predictive distribution.

Expand Down Expand Up @@ -907,6 +920,7 @@ def __init__(
infer_discrete: bool = False,
parallel: bool = False,
batch_ndims: Optional[int] = None,
exclude_deterministic: bool = True,
):
if posterior_samples is None and num_samples is None:
raise ValueError(
Expand Down Expand Up @@ -967,6 +981,7 @@ def __init__(
self.parallel = parallel
self.batch_ndims = batch_ndims
self._batch_shape = batch_shape
self.exclude_deterministic = exclude_deterministic

def _call_with_params(self, rng_key, params, args, kwargs):
posterior_samples = self.posterior_samples
Expand All @@ -983,6 +998,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
exclude_deterministic=self.exclude_deterministic,
)
model = substitute(self.model, self.params)
return _predictive(
Expand All @@ -995,6 +1011,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
exclude_deterministic=self.exclude_deterministic,
)

def __call__(self, rng_key, *args, **kwargs):
Expand Down
39 changes: 39 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ def model(data=None):
return model, data, true_probs


def linear_regression():
    """Build a noise-free linear regression fixture.

    Draws 800 covariates from ``dist.Normal(0, 1)`` with a fixed PRNG key and
    sets the response to ``1.5 + X * 0.7``.

    :return: tuple ``(model, X, y)``; ``model(X, y=None)`` declares the sample
        sites ``alpha``, ``beta`` and ``sigma``, the deterministic site ``mu``
        and the site ``obs``, which is observed when ``y`` is given.
    """
    num_points = 800
    covariate_dist = dist.Normal(0, 1)
    X = covariate_dist.sample(random.PRNGKey(0), (num_points,))
    y = 1.5 + X * 0.7

    def model(X, y=None):
        # Priors for the terms of the linear predictor and for sigma.
        alpha = numpyro.sample("alpha", dist.Normal(0.0, 5))
        beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        with numpyro.plate("plate", len(X)):
            # "mu" is recorded as a deterministic site computed from X.
            mu = numpyro.deterministic("mu", alpha + X * beta)
            likelihood = dist.Normal(mu, sigma)
            numpyro.sample("obs", likelihood, obs=y)

    return model, X, y


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive(parallel):
model, data, true_probs = beta_bernoulli()
Expand All @@ -74,6 +90,29 @@ def test_predictive(parallel):
assert_allclose(obs.mean(0), true_probs, rtol=0.1)


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive_with_deterministic(parallel):
    """Check the default handling of deterministic sites in ``Predictive``.

    The model is fitted on the full ``X`` and then asked to predict on a
    shorter slice; the returned ``mu`` and ``obs`` must follow the shape of
    the new input rather than the shape seen during fitting.
    """
    n_preds = 400
    model, X, y = linear_regression()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), X=X, y=y)
    predictive = Predictive(model, mcmc.get_samples(), parallel=parallel)

    # Predict on fewer rows than were used for fitting, so that the shape of
    # the deterministic site has to change.
    X_new = X[:n_preds]
    predictive_samples = predictive(random.PRNGKey(1), X=X_new)
    assert predictive_samples.keys() == {"mu", "obs"}

    # Repeat the prediction with an explicit list of returned sites.
    predictive.return_sites = ["beta", "mu", "obs"]
    predictive_samples = predictive(random.PRNGKey(1), X=X_new)
    # 100 is the number of posterior draws requested from MCMC above.
    expected_shape = (100,) + X_new.shape
    assert predictive_samples["mu"].shape == expected_shape
    assert predictive_samples["obs"].shape == expected_shape


def test_predictive_with_guide():
data = jnp.array([1] * 8 + [0] * 2)

Expand Down
Loading