-
Notifications
You must be signed in to change notification settings - Fork 95
Add Numba-accelerated MBAR solver backend #570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
I think the biggest issue we've had in any acceleration of pymbar is installation problems - @mikemhenry @Lnaden any thoughts about difficulties in installation with conda? Especially on differences when people are installing with GPU vs. CPU. |
|
numba is quite standard numpy acceleration package. The conda version supports
So pretty any arch could use it. |
|
@xiki-tempula I understand it SHOULD work easily. The question is whether it DOES work easily. Since it looks like there is dependence on JIT, then there's some questions as to how the conda installation should work - would be good to get insights from @mikemhenry and/or @Lnaden |
|
I think this is definitely interesting to have. It would be nice to have access to the current data you used for your benchmarking and verify that this indeed have the speedups with this data. Also, we probably would benefit from modularizing the solvers and backends themselves, instead of having them in the same I don't foresee any big issues with it, other than we would be having some new code to support but the code should be pretty stable and easy to work with (maybe we can discuss that further in a more detailed review of the changes). Just my two cents. |
@ijpulidos I may have some time this semester - can you file an issue on this with some of your ideas, even if half-baked? |
======================================================================
|
|
Just got back from vacation, will check this out tomorrow! |
Summary
This PR adds a new Numba-accelerated backend for solving the MBAR equations, providing significant performance improvements over both the default NumPy/SciPy and JAX implementations on CPU, especially for large datasets.
Changes
New Features
solver_protocol="numba"option to use the Numba-accelerated L-BFGS-B solverNUMBA_SOLVER_PROTOCOLconstant for direct access to the numba solver configurationImplementation Details
@njit(parallel=True, fastmath=True)for optimized parallel computationParameterErrorif Numba is not installed whensolver_protocol="numba"is requestedFiles Modified
pymbar/mbar_solvers.py: Added Numba-accelerated loss/gradient functions andL-BFGS-B-numbasolver methodpymbar/mbar.py: Added"numba"as a recognized solver protocol optionpymbar/tests/test_mbar.py: Added test to verify Numba solver matches default solver resultsUsage
Performance Benchmarks
Benchmarks performed on ABFE (Absolute Binding Free Energy) calculation data from GROMACS simulations.
Hardware & Software Specifications
pyhd8ed1ab_0)cpu_py311ha32d189_2)Free Leg (48 windows, 5001 samples/window, ~240K total samples)
Bound Leg (64 windows, 5001 samples/window, ~320K total samples)
Notes
Dependencies
The Numba backend requires the
numbapackage:pip install numba # or conda install numbaIf Numba is not installed and
solver_protocol="numba"is requested, aParameterErroris raised with installation instructions.Testing
All existing tests pass, plus a new test
test_mbar_numba_solver_matches_defaultverifies that the Numba solver produces results matching the default solver within numerical tolerance.