How to use window_adaptation in a multi-chain setting? #370
The problem is that it is impossible to vectorize the method
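I believe you can already do it by vectorizing the whole sampling scheme:

```python
import jax
import jax.numpy as jnp
import blackjax

# `logposterior_fn` and `inference_loop` are assumed to be defined elsewhere.
window_adaptation = blackjax.window_adaptation(
    blackjax.nuts,
    logposterior_fn,
    1000,  # number of adaptation steps
)


def adapt_and_sample_one_chain(rng_key, initial_position):
    adapt_key, sample_key = jax.random.split(rng_key)
    # Run the warmup for this chain: returns the final warmup state and a
    # kernel built with the adapted parameters.
    state, kernel, _ = window_adaptation.run(
        adapt_key,
        initial_position,
    )
    # Draw 1000 samples with the adapted kernel.
    chain = inference_loop(
        kernel,
        1000,
        sample_key,
        state,
    )
    return chain


rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, 10)  # one key per chain
# One initial position per chain: the leading axis (size 10) is the chain axis.
initial_positions = {"scale": jnp.ones(10), "coefs": 2. * jnp.ones(10)}
# vmap maps over the leading axis of the keys and of every leaf of the dict.
res = jax.vmap(adapt_and_sample_one_chain)(keys, initial_positions)
```

If you need to do something lower-level with the adaptation, this will be possible once #276 is merged over the next couple of days (see this comment on #276).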
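The snippet above relies on an `inference_loop` helper that is not shown. Below is a minimal, self-contained sketch of the same per-chain `vmap` pattern that runs without BlackJAX. It makes several assumptions: `logdensity_fn` (a 1D standard normal), `make_rwm_kernel` (a toy random-walk Metropolis kernel) and `toy_adaptation` (which just fixes a step size) are hypothetical stand-ins, not NUTS or BlackJAX's window adaptation, and `inference_loop` is an assumed implementation built on `jax.lax.scan` that takes its arguments in the order `(kernel, num_samples, rng_key, initial_state)` used above.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    # Hypothetical target: 1D standard normal (log-density up to a constant).
    return -0.5 * x**2


def make_rwm_kernel(step_size):
    # Hypothetical random-walk Metropolis kernel with the assumed signature
    # kernel(rng_key, state) -> (new_state, info).
    def kernel(rng_key, state):
        move_key, accept_key = jax.random.split(rng_key)
        proposal = state + step_size * jax.random.normal(move_key)
        log_ratio = logdensity_fn(proposal) - logdensity_fn(state)
        accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
        new_state = jnp.where(accept, proposal, state)
        return new_state, accept

    return kernel


def inference_loop(kernel, num_samples, rng_key, initial_state):
    # Apply `kernel` num_samples times with lax.scan and stack every state.
    def one_step(state, key):
        state, _ = kernel(key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


def toy_adaptation(rng_key, initial_position):
    # Hypothetical stand-in for window_adaptation.run: returns a state and a
    # kernel. The "adapted" step size is a fixed constant, for illustration only.
    del rng_key
    return initial_position, make_rwm_kernel(2.4)


def adapt_and_sample_one_chain(rng_key, initial_position):
    adapt_key, sample_key = jax.random.split(rng_key)
    state, kernel = toy_adaptation(adapt_key, initial_position)
    # The kernel closure is used only inside this function; just the sampled
    # array is returned, so the whole function can be vmapped.
    return inference_loop(kernel, 1000, sample_key, state)


num_chains = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
initial_positions = jnp.linspace(-1.0, 1.0, num_chains)
chains = jax.vmap(adapt_and_sample_one_chain)(keys, initial_positions)
print(chains.shape)  # (4, 1000)
```

The design point is that the kernel produced by the adaptation is a Python function, which cannot be a `jax.vmap` output. Keeping it inside the per-chain function and returning only arrays is what lets the adaptation and the sampling be vectorized together.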