Skip to content

Commit e10bf59

Browse files
authored
Consider time history=0 as plate (#1443)
* Consider time history=0 as plate * lint
1 parent 9fd29ab commit e10bf59

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

numpyro/contrib/control_flow/scan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515

1616
def _subs_wrapper(subs_map, i, length, site):
17+
if site["type"] != "sample":
18+
return
1719
value = None
1820
if isinstance(subs_map, dict) and site["name"] in subs_map:
1921
value = subs_map[site["name"]]

numpyro/contrib/funsor/infer_util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
162162
time_to_init_vars = defaultdict(frozenset) # PP... variables
163163
time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites
164164
sum_vars, prod_vars = frozenset(), frozenset()
165-
history = 1
165+
history = 0
166166
log_measures = {}
167167
for site in model_trace.values():
168168
if site["type"] == "sample":
@@ -186,10 +186,14 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
186186
for dim, name in dim_to_name.items():
187187
if name.startswith("_time"):
188188
time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
189-
time_to_factors[time_dim].append(log_prob_factor)
190189
history = max(
191190
history, max(_get_shift(s) for s in dim_to_name.values())
192191
)
192+
if history == 0:
193+
log_factors.append(log_prob_factor)
194+
prod_vars |= frozenset({name})
195+
else:
196+
time_to_factors[time_dim].append(log_prob_factor)
193197
time_to_init_vars[time_dim] |= frozenset(
194198
s for s in dim_to_name.values() if s.startswith("_PREV_")
195199
)

test/contrib/test_funsor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy.testing import assert_allclose
99
import pytest
1010

11+
import jax
1112
from jax import random
1213
import jax.numpy as jnp
1314

@@ -516,6 +517,29 @@ def transition_fn(carry, y):
516517
assert_allclose(actual_x_curr, expected_x_curr)
517518

518519

520+
def test_scan_enum_history_0():
521+
def model(ys):
522+
z = numpyro.sample("z", dist.Bernoulli(0.2), infer={"enumerate": "parallel"})
523+
524+
def transition_fn(c, y):
525+
numpyro.sample("y", dist.Normal(z, 1), obs=y)
526+
return None, None
527+
528+
scan(transition_fn, None, ys)
529+
530+
actual, trace = log_density(
531+
model=enum(model, first_available_dim=-1),
532+
model_args=(jnp.arange(3),),
533+
model_kwargs={},
534+
params={},
535+
)
536+
z_factor = trace["z"]["fn"].log_prob(trace["z"]["value"])
537+
prev_y_factor = trace["_PREV_y"]["fn"].log_prob(trace["_PREV_y"]["value"])
538+
y_factor = trace["y"]["fn"].log_prob(trace["y"]["value"]).sum(0)
539+
expected = jax.nn.logsumexp(z_factor + prev_y_factor + y_factor)
540+
assert_allclose(actual, expected)
541+
542+
519543
def test_missing_plate(monkeypatch):
520544
K, N = 3, 1000
521545

0 commit comments

Comments
 (0)