55
66from jax import numpy as jnp , random
77
8+ from numpyro import plate , prng_key , sample
89from numpyro .contrib .ecs_proxies import block_update
10+ from numpyro .contrib .module import random_haiku_module
11+ from numpyro .distributions import Cauchy , Normal
12+ from numpyro .handlers import seed
13+ from numpyro .infer import HMC , HMCECS , MCMC , SVI , Trace_ELBO
14+ from numpyro .infer .autoguide import AutoDelta
15+ from numpyro .optim import Adam
916
1017
1118@pytest .mark .parametrize ("num_blocks" , [1 , 2 , 50 , 100 ])
@@ -26,3 +33,42 @@ def test_block_update_partitioning(num_blocks):
2633 )
2734
2835 assert gibbs_state == new_gibbs_state
36+
37+
38+ def test_haiku_compatiable ():
39+ try :
40+ import haiku as hk # noqa: F401
41+
42+ data_points = 6
43+ x_dim = 4
44+
45+ def model (x , y ):
46+ net = random_haiku_module (
47+ "net" ,
48+ hk .transform (lambda x : hk .Linear (1 )(x )),
49+ prior = {"linear.b" : Cauchy (), "linear.w" : Normal ()},
50+ input_shape = (1 , x_dim ),
51+ )
52+
53+ with plate ("data" , data_points , subsample_size = 2 ) as idx :
54+ yb = y [idx ]
55+ xb = x [idx ]
56+ sample ("y" , Normal (net (xb ).squeeze ()), obs = yb )
57+
58+ x = jnp .ones ((data_points , x_dim ))
59+ y = jnp .array ((data_points , 0 ))
60+
61+ with seed (rng_seed = 0 ):
62+ svi = SVI (model , AutoDelta (model ), Adam (step_size = 1e-3 ), Trace_ELBO ())
63+ svi_result = svi .run (prng_key (), 1 , x , y )
64+ ref_params = {
65+ k .removesuffix ("_auto_loc" ): v for k , v in svi_result .params .items ()
66+ }
67+
68+ proxy = HMCECS .taylor_proxy (ref_params , degree = 2 )
69+ kernel = HMCECS (HMC (model ), num_blocks = 2 , proxy = proxy )
70+
71+ mcmc = MCMC (kernel , num_warmup = 2 , num_samples = 2 )
72+ mcmc .run (prng_key (), x , y )
73+ except ImportError :
74+ pass
0 commit comments