Skip to content

Commit 2e72f2f

Browse files
authored
Enhance LocScaleReparam noncenter dependency (#1384)
* Enhance LocScaleReparam noncenter dependency * fix lint * Use id in the first place * change tree_multimap to tree_map * Fix lint * Fix lint * Fix failing tests due to FutureWarning tree_multimap * Also support ndarray in LocScaleReparam
1 parent 0071f58 commit 2e72f2f

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

numpyro/infer/reparam.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
from abc import ABC, abstractmethod
55
import math
66

7+
import numpy as np
8+
9+
import jax
710
import jax.numpy as jnp
811

912
import numpyro
1013
import numpyro.distributions as dist
1114
from numpyro.distributions import biject_to, constraints
1215
from numpyro.distributions.util import is_identically_one, safe_normalize, sum_rightmost
1316
from numpyro.infer.autoguide import AutoContinuous
17+
from numpyro.util import not_jax_tracer
1418

1519

1620
class Reparam(ABC):
@@ -78,11 +82,19 @@ class LocScaleReparam(Reparam):
7882
"""
7983

8084
def __init__(self, centered=None, shape_params=()):
81-
assert centered is None or isinstance(centered, (int, float))
85+
assert centered is None or isinstance(
86+
centered, (int, float, np.generic, np.ndarray, jnp.ndarray, jax.core.Tracer)
87+
)
8288
assert isinstance(shape_params, (tuple, list))
8389
assert all(isinstance(name, str) for name in shape_params)
84-
if isinstance(centered, (int, float)):
85-
assert 0 <= centered and centered <= 1
90+
if centered is not None:
91+
is_valid = constraints.unit_interval.check(centered)
92+
if not_jax_tracer(is_valid):
93+
if not np.all(is_valid):
94+
raise ValueError(
95+
"`centered` argument does not satisfy `0 <= centered <= 1`."
96+
)
97+
8698
self.centered = centered
8799
self.shape_params = shape_params
88100

@@ -102,8 +114,12 @@ def __call__(self, name, fn, obs):
102114
jnp.full(event_shape, 0.5),
103115
constraint=constraints.unit_interval,
104116
)
105-
params["loc"] = fn.loc * centered
106-
params["scale"] = fn.scale**centered
117+
if isinstance(centered, (int, float, np.generic)) and centered == 0.0:
118+
params["loc"] = jnp.zeros_like(fn.loc)
119+
params["scale"] = jnp.ones_like(fn.scale)
120+
else:
121+
params["loc"] = fn.loc * centered
122+
params["scale"] = fn.scale**centered
107123
decentered_fn = self._wrap(type(fn)(**params), expand_shape, event_dim)
108124

109125
# Draw decentered noise.

test/infer/test_reparam.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ def test_loc_scale(dist_type, centered, shape, event_dim):
233233
loc = np.random.uniform(-1.0, 1.0, shape)
234234
scale = np.random.uniform(0.5, 1.5, shape)
235235
event_dim = min(event_dim, len(shape))
236+
if event_dim == 1 and centered is not None:
237+
centered = jnp.broadcast_to(centered, shape[-event_dim:])
236238

237239
def model(loc, scale):
238240
with numpyro.plate_stack("plates", shape[: len(shape) - event_dim]):
@@ -272,6 +274,22 @@ def get_actual_probe(loc, scale):
272274
assert_allclose(actual_grad[1], expected_grad[1], atol=0.05) # scale grad
273275

274276

277+
@pytest.mark.parametrize("centered", [-0.3, 10.0, np.array([0.1, -2.0])])
278+
def test_loc_scale_centered_invalid(centered):
279+
N = 10
280+
loc = np.random.uniform(-1.0, 1.0, size=(N, 2))
281+
scale = np.random.uniform(0.5, 1.5, size=(N, 2))
282+
283+
def model(loc, scale):
284+
with numpyro.plate("particles", N):
285+
numpyro.sample("x", dist.Normal(loc, scale).to_event(1))
286+
287+
with pytest.raises(ValueError, match=".*does not satisfy.*"):
288+
with handlers.reparam(config=LocScaleReparam(centered)):
289+
with handlers.seed(rng_seed=10):
290+
model(loc, scale)
291+
292+
275293
@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
276294
@pytest.mark.parametrize("dim", [2, 3, 4])
277295
def test_projected_normal(shape, dim):

0 commit comments

Comments
 (0)