Skip to content

Commit f997da2

Browse files
authored
fixed haiku for ecs (#1750)
1 parent e6c187c commit f997da2

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

numpyro/contrib/ecs_proxies.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ def construct_proxy_fn(
116116
num_blocks=1,
117117
):
118118
ref_params = {
119-
name: biject_to(prototype_trace[name]["fn"].support).inv(value)
119+
name: (
120+
biject_to(prototype_trace[name]["fn"].support).inv(value)
121+
if prototype_trace[name]["type"] == "sample"
122+
else value
123+
)
120124
for name, value in reference_params.items()
121125
}
122126

@@ -131,7 +135,11 @@ def log_likelihood(params_flat, subsample_indices=None):
131135
with warnings.catch_warnings():
132136
warnings.simplefilter("ignore")
133137
params = {
134-
name: biject_to(prototype_trace[name]["fn"].support)(value)
138+
name: (
139+
biject_to(prototype_trace[name]["fn"].support)(value)
140+
if prototype_trace[name]["type"] == "sample"
141+
else value
142+
)
135143
for name, value in params.items()
136144
}
137145
with (
@@ -167,9 +175,7 @@ def log_likelihood_sum(params_flat, subsample_indices=None):
167175
elif 1:
168176
TPState = TaylorOneProxyState
169177
else:
170-
raise ValueError(
171-
"Taylor proxy only defined for first and second degree."
172-
)
178+
raise ValueError("Taylor proxy only defined for first and second degree.")
173179

174180
# those stats are dict keyed by subsample names
175181
ref_sum_log_lik = log_likelihood_sum(ref_params_flat)

test/contrib/test_esc_proxies.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55

66
from jax import numpy as jnp, random
77

8+
from numpyro import plate, prng_key, sample
89
from numpyro.contrib.ecs_proxies import block_update
10+
from numpyro.contrib.module import random_haiku_module
11+
from numpyro.distributions import Cauchy, Normal
12+
from numpyro.handlers import seed
13+
from numpyro.infer import HMC, HMCECS, MCMC, SVI, Trace_ELBO
14+
from numpyro.infer.autoguide import AutoDelta
15+
from numpyro.optim import Adam
916

1017

1118
@pytest.mark.parametrize("num_blocks", [1, 2, 50, 100])
@@ -26,3 +33,42 @@ def test_block_update_partitioning(num_blocks):
2633
)
2734

2835
assert gibbs_state == new_gibbs_state
36+
37+
38+
def test_haiku_compatiable():
39+
try:
40+
import haiku as hk # noqa: F401
41+
42+
data_points = 6
43+
x_dim = 4
44+
45+
def model(x, y):
46+
net = random_haiku_module(
47+
"net",
48+
hk.transform(lambda x: hk.Linear(1)(x)),
49+
prior={"linear.b": Cauchy(), "linear.w": Normal()},
50+
input_shape=(1, x_dim),
51+
)
52+
53+
with plate("data", data_points, subsample_size=2) as idx:
54+
yb = y[idx]
55+
xb = x[idx]
56+
sample("y", Normal(net(xb).squeeze()), obs=yb)
57+
58+
x = jnp.ones((data_points, x_dim))
59+
y = jnp.array((data_points, 0))
60+
61+
with seed(rng_seed=0):
62+
svi = SVI(model, AutoDelta(model), Adam(step_size=1e-3), Trace_ELBO())
63+
svi_result = svi.run(prng_key(), 1, x, y)
64+
ref_params = {
65+
k.removesuffix("_auto_loc"): v for k, v in svi_result.params.items()
66+
}
67+
68+
proxy = HMCECS.taylor_proxy(ref_params, degree=2)
69+
kernel = HMCECS(HMC(model), num_blocks=2, proxy=proxy)
70+
71+
mcmc = MCMC(kernel, num_warmup=2, num_samples=2)
72+
mcmc.run(prng_key(), x, y)
73+
except ImportError:
74+
pass

0 commit comments

Comments
 (0)