
How to use window_adaptation in a multi-chain setting? #370

Answered by rlouf
bkozyrskiy asked this question in Q&A

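I believe you can already do this by vectorizing the whole sampling scheme (window adaptation followed by sampling) with `jax.vmap`, so that each chain runs its own adaptation: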

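```python
import jax
import jax.numpy as jnp
import blackjax

# `logposterior_fn` is your model's log-density, assumed to be defined elsewhere.
# API as used in this thread; later BlackJAX releases may change the signatures
# of `window_adaptation` and `.run`.
window_adaptation = blackjax.window_adaptation(
    blackjax.nuts,
    logposterior_fn,
    1000,  # number of adaptation (warmup) steps
)


# `inference_loop` is not part of BlackJAX; this is an assumed helper whose
# argument order matches the call below.
def inference_loop(kernel, num_samples, rng_key, initial_state):
    def one_step(state, key):
        state, _ = kernel(key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


def adapt_and_sample_one_chain(rng_key, initial_position):
    # Separate keys for the warmup phase and the sampling phase.
    adapt_key, sample_key = jax.random.split(rng_key)

    # Each chain runs its own adaptation and gets its own tuned kernel.
    state, kernel, _ = window_adaptation.run(
        adapt_key,
        initial_position,
    )

    chain = inference_loop(
        kernel,
        1000,
        sample_key,
        state,
    )

    return chain


num_chains = 10
rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_chains)
# One initial position per chain: every leaf has a leading axis of size `num_chains`.
initial_positions = {"scale": jnp.ones(num_chains), "coefs": 2. * jnp.ones(num_chains)}

res = jax.vmap(adapt_and_sample_one_chain)(keys, initial_positions)
```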
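To see the vectorization pattern in isolation, here is a minimal sketch that does not depend on BlackJAX. It swaps in a hand-written random-walk Metropolis sampler (a different algorithm from NUTS, with no step-size adaptation) for the per-chain function, and uses a hypothetical standard-normal target; `logdensity`, `sample_one_chain`, `NUM_CHAINS` and `NUM_SAMPLES` are illustrative names, not part of any library. What it shows is that `jax.vmap` slices every leaf of `initial_positions` along its leading axis, so each chain sees scalar positions and its own key.

```python
import jax
import jax.numpy as jnp

NUM_CHAINS = 10
NUM_SAMPLES = 1000


def logdensity(position):
    # Hypothetical toy target: independent standard normals on every leaf.
    return sum(-0.5 * jnp.sum(x**2) for x in jax.tree_util.tree_leaves(position))


def sample_one_chain(rng_key, initial_position, step_size=0.5):
    """Random-walk Metropolis for a single chain (stand-in for NUTS + adaptation)."""

    def one_step(position, key):
        prop_key, accept_key = jax.random.split(key)
        leaves, treedef = jax.tree_util.tree_flatten(position)
        noise_keys = jax.random.split(prop_key, len(leaves))
        # Gaussian random-walk proposal, one independent key per leaf.
        proposal = jax.tree_util.tree_unflatten(
            treedef,
            [x + step_size * jax.random.normal(k, jnp.shape(x))
             for x, k in zip(leaves, noise_keys)],
        )
        # Metropolis accept/reject on the log scale.
        log_ratio = logdensity(proposal) - logdensity(position)
        accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
        new_position = jax.tree_util.tree_map(
            lambda p, q: jnp.where(accept, p, q), proposal, position
        )
        return new_position, new_position

    keys = jax.random.split(rng_key, NUM_SAMPLES)
    _, samples = jax.lax.scan(one_step, initial_position, keys)
    return samples


keys = jax.random.split(jax.random.PRNGKey(0), NUM_CHAINS)
# Batched pytree: the leading axis of every leaf indexes the chain.
initial_positions = {
    "scale": jnp.ones(NUM_CHAINS),
    "coefs": 2.0 * jnp.ones(NUM_CHAINS),
}

# vmap maps over axis 0 of `keys` and of every leaf of `initial_positions`.
res = jax.vmap(sample_one_chain)(keys, initial_positions)
```

Each leaf of `res` comes back with shape `(NUM_CHAINS, NUM_SAMPLES)`, i.e. `(10, 1000)`: chains first, then draws. In the BlackJAX version above, the stacked NUTS states should carry the same layout, e.g. `res.position["scale"]` indexed by chain along axis 0.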

Answer selected by rlouf