Skip to content

Commit 0071f58

Browse files
authored
Do not rename plates under scope handler (#1338)
* only allow sample and deterministic by default * change argument to skip_param * fix lint * fix isort * fix hsgp example * still allow param * add a hide_types argument for scope * hide plate in scan enum init scope * clean typo:
1 parent 0bff074 commit 0071f58

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

numpyro/contrib/control_flow/scan.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def body_fn(wrapped_carry, x, prefix=None):
145145

146146
if init:
147147
# handler the name to match the pattern of sakkar_bilmes product
148-
with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
148+
with handlers.scope(
149+
prefix="_PREV_" * (unroll_steps - i), divider="", hide_types=["plate"]
150+
):
149151
new_carry, y = config_enumerate(seeded_fn)(carry, x)
150152
trace = {}
151153
else:

numpyro/handlers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -630,18 +630,20 @@ class scope(Messenger):
630630
:param fn: Python callable with NumPyro primitives.
631631
:param str prefix: a string to prepend to sample names
632632
:param str divider: a string to join the prefix and sample name; default to `'/'`
633+
:param list hide_types: an optional list of side types to skip renaming.
633634
"""
634635

635-
def __init__(self, fn=None, prefix="", divider="/"):
636+
def __init__(self, fn=None, prefix="", divider="/", *, hide_types=None):
636637
self.prefix = prefix
637638
self.divider = divider
639+
self.hide_types = [] if hide_types is None else hide_types
638640
super().__init__(fn)
639641

640642
def process_message(self, msg):
641-
if msg.get("name"):
643+
if msg.get("name") and msg["type"] not in self.hide_types:
642644
msg["name"] = f"{self.prefix}{self.divider}{msg['name']}"
643645

644-
if msg.get("cond_indep_stack"):
646+
if msg.get("cond_indep_stack") and "plate" not in self.hide_types:
645647
msg["cond_indep_stack"] = [
646648
CondIndepStackFrame(
647649
f"{self.prefix}{self.divider}{i.name}", i.dim, i.size

test/test_handlers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -578,18 +578,18 @@ def test_block():
578578

579579
def test_scope():
580580
def fn():
581-
return numpyro.sample("x", dist.Normal())
581+
with numpyro.plate("N", 10):
582+
return numpyro.sample("x", dist.Normal())
582583

583584
with handlers.trace() as trace:
584585
with handlers.seed(rng_seed=1):
585586
with handlers.scope(prefix="a"):
586587
fn()
587588
with handlers.scope(prefix="b"):
588-
with handlers.scope(prefix="a"):
589+
with handlers.scope(prefix="a", hide_types=["plate"]):
589590
fn()
590591

591-
assert "a/x" in trace
592-
assert "b/a/x" in trace
592+
assert set(trace) == {"a/x", "b/a/x", "a/N", "b/N"}
593593

594594

595595
def test_scope_frames():

0 commit comments

Comments
 (0)