Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predictive fix when deterministic sites are present #1789

Merged
merged 9 commits into from
May 2, 2024
11 changes: 10 additions & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,17 @@ def single_prediction(val):
**model_kwargs,
)
else:

def _samples_wo_deterministic(msg):
return (
samples.get(msg["name"]) if msg["type"] != "deterministic" else None
)

model_trace = trace(
seed(substitute(masked_model, samples), rng_key)
seed(
substitute(masked_model, substitute_fn=_samples_wo_deterministic),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help me change line 777 to condition so that? It would be nice to add an argument like exclude_deterministic_from_posterior to Predictive to maintain two behaviors. We will pass such argument to this _predictive function to control the behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the conditional logic, I tried the name exclude_deterministic_params because it was shorter but let me know if that's insufficient

For this could you elaborate?

Could you help me change line 777 to condition so that?

Would I be changing L777 from substitute to condition?

And should I add the deterministic fix there as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, using condition there is fine because we don't substitute deterministic sites under thecondition handler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok just switched it to condition

rng_key,
)
).get_trace(*model_args, **model_kwargs)
pred_samples = {name: site["value"] for name, site in model_trace.items()}

Expand Down
39 changes: 39 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ def model(data=None):
return model, data, true_probs


def linear_regression():
N = 800
X = dist.Normal(0, 1).sample(random.PRNGKey(0), (N,))
y = 1.5 + X * 0.7

def model(X, y=None):
alpha = numpyro.sample("alpha", dist.Normal(0.0, 5))
beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
sigma = numpyro.sample("sigma", dist.Exponential(1.0))
with numpyro.plate("plate", len(X)):
mu = numpyro.deterministic("mu", alpha + X * beta)
numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

return model, X, y


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive(parallel):
model, data, true_probs = beta_bernoulli()
Expand All @@ -74,6 +90,29 @@ def test_predictive(parallel):
assert_allclose(obs.mean(0), true_probs, rtol=0.1)


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive_with_deterministic(parallel):
"""Tests that the default behavior when predicting from models with
deterministic sites doesn't lead to static deterministic sites in the predictive.
"""
n_preds = 400
model, X, y = linear_regression()
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0), X=X, y=y)
samples = mcmc.get_samples()
predictive = Predictive(model, samples, parallel=parallel)
# change the input (X) shape to make sure the deterministic shape changes
predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds])
assert predictive_samples.keys() == {"mu", "obs"}

predictive.return_sites = ["beta", "mu", "obs"]
# change the input (X) shape to make sure the deterministic shape changes
predictive_samples = predictive(random.PRNGKey(1), X=X[:n_preds])
# check shapes
assert predictive_samples["mu"].shape == (100,) + X[:n_preds].shape
assert predictive_samples["obs"].shape == (100,) + X[:n_preds].shape


def test_predictive_with_guide():
data = jnp.array([1] * 8 + [0] * 2)

Expand Down
Loading