Skip to content

Commit 89ca117

Browse files
authored
Allow to use NeuTra on models with plates (#1826)
* allow to use NeuTra with plate * Fix typo in reparam.py
1 parent 2ed9f92 commit 89ca117

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

numpyro/infer/reparam.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ class NeuTraReparam(Reparam):
226226
227227
# Step 2. Use trained guide in NeuTra MCMC
228228
neutra = NeuTraReparam(guide)
229-
model = netra.reparam(model)
229+
model = neutra.reparam(model)
230230
nuts = NUTS(model)
231231
# ...now use the model in HMC or NUTS...
232232
@@ -281,9 +281,15 @@ def __call__(self, name, fn, obs):
281281
compute_density = numpyro.get_mask() is not False
282282
if not self._x_unconstrained: # On first sample site.
283283
# Sample a shared latent.
284+
model_plates = {
285+
msg["name"]
286+
for msg in self.guide.prototype_trace.values()
287+
if msg["type"] == "plate"
288+
}
284289
z_unconstrained = numpyro.sample(
285290
"{}_shared_latent".format(self.guide.prefix),
286291
self.guide.get_base_dist().mask(False),
292+
infer={"block_plates": model_plates},
287293
)
288294

289295
# Differentiably transform.

numpyro/primitives.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,12 @@ def process_message(self, msg):
530530
)
531531
return
532532

533+
if (
534+
"block_plates" in msg.get("infer", {})
535+
and self.name in msg["infer"]["block_plates"]
536+
):
537+
return
538+
533539
cond_indep_stack = msg["cond_indep_stack"]
534540
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
535541
cond_indep_stack.append(frame)

test/infer/test_reparam.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from numpyro.distributions.transforms import AffineTransform, ExpTransform
1616
import numpyro.handlers as handlers
1717
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
18-
from numpyro.infer.autoguide import AutoIAFNormal
18+
from numpyro.infer.autoguide import AutoDiagonalNormal, AutoIAFNormal
1919
from numpyro.infer.reparam import (
2020
CircularReparam,
2121
ExplicitReparam,
@@ -228,6 +228,22 @@ def test_neutra_reparam_unobserved_model():
228228
reparam_model(data=None)
229229

230230

231+
def test_neutra_reparam_with_plate():
232+
def model():
233+
with numpyro.plate("N", 3, dim=-1):
234+
x = numpyro.sample("x", dist.Normal(0, 1))
235+
assert x.shape == (3,)
236+
237+
guide = AutoDiagonalNormal(model)
238+
svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
239+
svi_state = svi.init(random.PRNGKey(0))
240+
params = svi.get_params(svi_state)
241+
neutra = NeuTraReparam(guide, params)
242+
reparam_model = neutra.reparam(model)
243+
with handlers.seed(rng_seed=0):
244+
reparam_model()
245+
246+
231247
@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
232248
@pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, None])
233249
@pytest.mark.parametrize("dist_type", ["Normal", "StudentT"])

0 commit comments

Comments
 (0)