Feature Summary
Recently, the group I work with proposed a couple of samplers (https://arxiv.org/abs/2503.01707, https://arxiv.org/abs/2212.08549) as alternatives to NUTS HMC (see issue #1662). Since they appear to be considerably faster than NUTS (at least on the benchmark problems I've tried) and are relatively simple, I'm interested in adding them to NumPyro, but wanted to get some advice first.
Implementations currently exist in Blackjax. Ideally, I'd add a new class such as `AdjustedMicrocanonical(numpyro.infer.mcmc.MCMCKernel)` that simply wraps Blackjax.
Beyond the kernel itself, my eventual goal is to add the tuning scheme as well, since it is key to good performance. Is there a straightforward way to do that?
Motivation
While it's easy to write a model in NumPyro, extract its density, and then run inference with Blackjax, we'd like to give users direct access to these samplers from within NumPyro, mainly to make them easier to discover.
One issue with relying on a Blackjax implementation is that the NumPyro implementation will tend to break every so often as Blackjax is refactored. That has certainly been our experience with wrapping jaxns: https://num.pyro.ai/en/latest/contrib.html#nested-sampling