Skip to content

Commit fa3f731

Browse files
Do not create plates for observed sites in AutoGuide. (#1972)
1 parent d6ba568 commit fa3f731

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

numpyro/infer/autoguide.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def _setup_prototype(self, *args, **kwargs):
179179
# raise support errors early for discrete sites
180180
with helpful_support_errors(site):
181181
biject_to(site["fn"].support)
182+
# Do not create plates for observed sites because they may be subsampled
183+
# with a different size during prototype setup and training.
184+
if site["is_observed"]:
185+
continue
182186
for frame in site["cond_indep_stack"]:
183187
if frame.name in self._prototype_frames:
184188
assert frame == self._prototype_frames[frame.name], (

test/infer/test_autoguide.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,3 +1345,38 @@ def model(x):
13451345
# Check delta distributions are fine if observed.
13461346
guide = AutoDiagonalNormal(lambda: model(3.0))
13471347
numpyro.handlers.seed(guide, 9)()
1348+
1349+
1350+
@pytest.mark.parametrize(
1351+
"guide_cls",
1352+
[
1353+
AutoBNAFNormal,
1354+
AutoDAIS,
1355+
AutoDelta,
1356+
AutoDiagonalNormal,
1357+
AutoLaplaceApproximation,
1358+
AutoLowRankMultivariateNormal,
1359+
AutoMultivariateNormal,
1360+
AutoNormal,
1361+
],
1362+
)
1363+
def test_subsample(guide_cls) -> None:
1364+
def model(n: int, x: jnp.ndarray):
1365+
mu = numpyro.sample("mu", dist.Normal(0, 1))
1366+
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
1367+
with numpyro.plate("n", n, subsample_size=x.size):
1368+
numpyro.sample("x", dist.Normal(mu, sigma), obs=x)
1369+
1370+
n = 20
1371+
x = 5 + jax.random.normal(jax.random.key(1), (20,))
1372+
subset = x[: n // 2]
1373+
1374+
svi = numpyro.infer.SVI(
1375+
model,
1376+
guide_cls(model),
1377+
numpyro.optim.Adam(0.1),
1378+
numpyro.infer.Trace_ELBO(),
1379+
n=n,
1380+
)
1381+
state = svi.init(jax.random.key(2), x=x)
1382+
svi.update(state, x=subset)

0 commit comments

Comments
 (0)