numpyro.deterministic static on infer.Predictive #1772
Sorry for the breakage! Could you try the dev branch of lightweight_mmm? If it works, I will ping a dev there about a release.
I think it's related to NumPyro: the problematic function is `numpyro.deterministic`.
Do you mean that
@fehiepsi I saw your fix on lightweight_mmm. I believe the long-term fix here is two-fold:
If these are unfeasible for deeper reasons, then at least mention the `pop` trick here: https://num.pyro.ai/en/v0.2.0/utilities.html, as the current behavior is a bit counterintuitive.
I'm running into the same issue. Here's a reproducible example:

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random

X = np.random.normal(0, 1, size=1000)
y = 5 + 1.2 * X + np.random.normal(size=1000)

def model(X, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    with numpyro.plate("data", len(X)):
        eta = numpyro.deterministic("eta", alpha + beta * X)
        obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Make predictions where X is a different shape
posterior_samples = mcmc.get_samples()
# posterior_samples.pop("eta")  # this fixes the issue
pred_func = Predictive(model, posterior_samples=posterior_samples)
preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)
```

Traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[... skipping hidden 1 frame]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:290, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
289 else:
--> 290 return cached(config.trace_context(), *args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:283, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
281 @functools.lru_cache(max_size)
282 def cached(_, *args, **kwargs):
--> 283 return f(*args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
153 @cache()
154 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 155 return _broadcast_shapes_uncached(*shapes)
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
170 if result_shape is None:
--> 171 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
172 return result_shape
ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[1], line 26
24 # Make predictions where X is a different shape
25 pred_func = Predictive(model, posterior_samples=mcmc.get_samples())
---> 26 preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:1011, in Predictive.__call__(self, rng_key, *args, **kwargs)
1001 """
1002 Returns dict of samples from the predictive distribution. By default, only sample sites not
1003 contained in `posterior_samples` are returned. This can be modified by changing the
(...)
1008 :param kwargs: model kwargs.
1009 """
1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1011 return self._call_with_params(rng_key, self.params, args, kwargs)
1012 elif self.batch_ndims == 1: # batch over parameters
1013 batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:988, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
977 posterior_samples = _predictive(
978 guide_rng_key,
979 guide,
(...)
985 model_kwargs=kwargs,
986 )
987 model = substitute(self.model, self.params)
--> 988 return _predictive(
989 rng_key,
990 model,
991 posterior_samples,
992 self._batch_shape,
993 return_sites=self.return_sites,
994 infer_discrete=self.infer_discrete,
995 parallel=self.parallel,
996 model_args=args,
997 model_kwargs=kwargs,
998 )
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:825, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
823 rng_key = rng_key.reshape(batch_shape + key_shape)
824 chunk_size = num_samples if parallel else 1
--> 825 return soft_vmap(
826 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
827 )
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/util.py:419, in soft_vmap(fn, xs, batch_ndims, chunk_size)
413 xs = tree_map(
414 lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
415 xs,
416 )
417 fn = vmap(fn)
--> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
421 ys = tree_map(
422 lambda y: jnp.reshape(
423 y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
424 )[:batch_size],
425 ys,
426 )
[... skipping hidden 12 frame]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:798, in _predictive.<locals>.single_prediction(val)
789 pred_samples = _sample_posterior(
790 config_enumerate(condition(model, samples)),
791 first_available_dim,
(...)
795 **model_kwargs,
796 )
797 else:
--> 798 model_trace = trace(
799 seed(substitute(masked_model, samples), rng_key)
800 ).get_trace(*model_args, **model_kwargs)
801 pred_samples = {name: site["value"] for name, site in model_trace.items()}
803 if return_sites is not None:
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (2 times)]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[1], line 17, in model(X, y)
15 with numpyro.plate("data", len(X)):
16 eta = numpyro.deterministic("eta", alpha + beta*X)
---> 17 obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 "type": "sample",
209 "name": name,
(...)
218 "infer": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg["value"]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:546, in plate.process_message(self, msg)
544 overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
545 trailing_shape = expected_shape[overlap_idx:]
--> 546 broadcast_shape = lax.broadcast_shapes(
547 trailing_shape, tuple(dist_batch_shape)
548 )
549 batch_shape = expected_shape[:overlap_idx] + broadcast_shape
550 msg["fn"] = msg["fn"].expand(batch_shape)
[... skipping hidden 1 frame]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
169 result_shape = _try_broadcast_shapes(shape_list)
170 if result_shape is None:
--> 171 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
172 return result_shape
ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]

I understand that supplying samples for a deterministic site pins the model to the shape those samples were computed with, but it does seem awkward that the typical prediction workflow requires extra work whenever deterministic sites are involved. I wonder if something like this is possible? https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/mcmc.py#L714C61-L714C62
Hi @nikisix and @kylejcaron, really sorry for the breakage! I think a good action is to introduce
Something like that sounds reasonable. The change in behavior was probably a mistake...
@fehiepsi @martinjankowiak should the

For example, the following workflow has the same problem:

```python
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.01), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10000, X=X, y=y)
params = guide.sample_posterior(random.PRNGKey(0), params=svi_result.params)
pred_func = Predictive(model, params=params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X=X[:250], y=None)
```

The solution here seems to be to include the guide and use the SVI params instead, but I imagine some users may be relying on the pattern above:

```python
pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X[:n_preds])["eta"]
```
I think this pattern could be used with an
@kylejcaron I think we can fix this in `Predictive`. The breakage happens because we allow substituting deterministic sites in the `substitute` handler. We can create a subclass of
Got it, that makes sense to me. It seems like it would involve just replacing the `substitute` call in this line and L987, but let me know if I'm missing anything. I'm happy to make an attempt at this; any name recommendations for the new effect handler?
The substitute logic is at this line. You can change
to something like
Nice idea with the `substitute_fn`; I just opened a PR!
For some reason, after fitting the model, the shape of a numpyro.deterministic site stays fixed, so trying to predict with inputs of a different shape throws a shape error.
Example in lightweight_mmm:
This throws a size error, see:
google/lightweight_mmm#309
and
google/lightweight_mmm#308