-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Blackjax sampler fix for breaking change / enable progress bar under parallel chain_method #7453
Conversation
It's enough if we're compatible with the latest blackjax releases. We can raise a runtime informative error if we know the installed version of blackjax is too old to work, directing users to update it |
|
Thanks for the new release Looks like it needs to get into conda first |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7453 +/- ##
==========================================
- Coverage 92.20% 92.17% -0.03%
==========================================
Files 103 103
Lines 17301 17258 -43
==========================================
- Hits 15952 15908 -44
- Misses 1349 1350 +1
|
…parallel chain_method (pymc-devs#7453) * remove blackjax pmap warning * use gen_scan_fn * remove labels * retrigger checks * retrigger checks
blackjax-devs/blackjax#712 changes the expected jax.lax.scan carry in
progress_bar_scan
. Since pymc's external blackjax sampler directly usesprogress_bar_scan
it will break whenprogressbar=True
. This change switches to use a new wrapper to hide the progress bar details. In addition it enables the use of progress bars underchain_method="parallel"
.I think any breaking issues can be handled by restricting blackjax version numbers. However, I'm not sure how to properly do that?
And of course for now tests are expected to fail until the changes show in a blackjax release.
PRs that are dependencies:
blackjax-devs/blackjax#712
blackjax-devs/blackjax#716
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7453.org.readthedocs.build/en/7453/