Conversation

@xiki-tempula

Summary

This PR adds a new Numba-accelerated backend for solving the MBAR equations, providing significant performance improvements over both the default NumPy/SciPy and JAX implementations on CPU, especially for large datasets.

Changes

New Features

  • Added solver_protocol="numba" option to use the Numba-accelerated L-BFGS-B solver
  • New NUMBA_SOLVER_PROTOCOL constant for direct access to the numba solver configuration

Implementation Details

  • Integrated Numba-accelerated MBAR loss and gradient computation based on the FastMBAR method (https://doi.org/10.1021/acs.jctc.8b01010)
  • The Numba backend uses @njit(parallel=True, fastmath=True) for optimized parallel computation (see the sketch after this list)
  • Graceful error handling: a ParameterError is raised if solver_protocol="numba" is requested but Numba is not installed
  • Maintains full compatibility with the existing JAX and NumPy backends
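
For orientation, here is a minimal sketch of what a Numba-parallel MBAR objective/gradient kernel can look like. It is illustrative only, not the PR's actual code: the function name mbar_loss_grad and its signature are made up for this sketch. The objective is the standard MBAR likelihood, L(f) = Σ_n log Σ_k N_k exp(f_k - u_kn) - Σ_k N_k f_k, whose gradient Σ_n w_in - N_i (with w_in the MBAR weights) vanishes at the self-consistent solution.

import numpy as np
from numba import njit, prange

@njit(parallel=True, fastmath=True)
def mbar_loss_grad(f_k, u_kn, N_k):
    # Illustrative sketch (not pymbar's code): MBAR objective and gradient,
    # parallelized over samples, with a stable log-sum-exp over states.
    # Assumes all N_k > 0.
    K, N = u_kn.shape
    log_N_k = np.log(N_k.astype(np.float64))
    loss = 0.0
    grad = np.zeros(K)
    for n in prange(N):
        # stable log-sum-exp of (log N_k + f_k - u_kn) over states k
        vmax = log_N_k[0] + f_k[0] - u_kn[0, n]
        for k in range(1, K):
            v = log_N_k[k] + f_k[k] - u_kn[k, n]
            if v > vmax:
                vmax = v
        s = 0.0
        for k in range(K):
            s += np.exp(log_N_k[k] + f_k[k] - u_kn[k, n] - vmax)
        lse = vmax + np.log(s)
        loss += lse  # scalar reduction across the parallel loop
        # per-sample MBAR weights accumulate into the gradient
        w = np.empty(K)
        for k in range(K):
            w[k] = np.exp(log_N_k[k] + f_k[k] - u_kn[k, n] - lse)
        grad += w  # whole-array reduction, supported inside prange
    loss -= np.sum(N_k * f_k)
    grad -= N_k
    return loss, grad

A function with this shape can be handed straight to scipy.optimize.minimize (with jac=True, method="L-BFGS-B", and args carrying u_kn and N_k), which matches the "L-BFGS-B-numba" solver method described above.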

Files Modified

  • pymbar/mbar_solvers.py: Added Numba-accelerated loss/gradient functions and L-BFGS-B-numba solver method
  • pymbar/mbar.py: Added "numba" as a recognized solver protocol option
  • pymbar/tests/test_mbar.py: Added test to verify Numba solver matches default solver results

Usage

from pymbar import MBAR

# Use the numba-accelerated solver
mbar = MBAR(u_kn, N_k, solver_protocol="numba")

# Or use the protocol constant directly
from pymbar import NUMBA_SOLVER_PROTOCOL
mbar = MBAR(u_kn, N_k, solver_protocol=NUMBA_SOLVER_PROTOCOL)

Performance Benchmarks

Benchmarks were performed on ABFE (absolute binding free energy) calculation data from GROMACS simulations.

Hardware & Software Specifications

  • Model: MacBook Pro
  • Chip: Apple M1 Pro
  • Cores: 10 (8 performance + 2 efficiency)
  • Memory: 16 GB
  • Environment: CPU-only (no GPU acceleration)
  • JAX: 0.7.2 (conda-forge, pyhd8ed1ab_0)
  • jaxlib: 0.7.2 (conda-forge, cpu_py311ha32d189_2)

Free Leg (48 windows, 5001 samples/window, ~240K total samples)

Backend             Time       Speedup (vs Numba)
JAX (BFGS)          231.59s    37.2x slower
NumPy (L-BFGS-B)    75.57s     12.2x slower
Numba (L-BFGS-B)    6.22s      1.0x (baseline)

Bound Leg (64 windows, 5001 samples/window, ~320K total samples)

Backend             Time       Speedup (vs Numba)
JAX (BFGS)          273.02s    49.1x slower
NumPy (L-BFGS-B)    191.94s    34.5x slower
Numba (L-BFGS-B)    5.56s      1.0x (baseline)

Notes

  • Numba is 12-35x faster than NumPy and 37-49x faster than JAX on CPU-only setups
  • First call includes JIT compilation overhead (~2-3s); subsequent calls are faster
  • Results agree across all backends within numerical tolerance (max |Δf_k| < 1e-3)
  • JAX backend uses BFGS optimizer while NumPy/Numba use L-BFGS-B
  • JAX may perform better on GPU; these benchmarks are CPU-only
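
On the JIT-overhead note above: Numba can also persist compiled machine code to disk via cache=True, so the one-time compilation cost is paid once per environment rather than once per process. Whether this PR enables caching is not stated here; the snippet below only illustrates the option.

import numpy as np
from numba import njit, prange

# Illustration only: cache=True writes the compiled kernel to disk so that
# later processes skip recompilation; parallel/fastmath mirror the PR's decorator.
@njit(parallel=True, fastmath=True, cache=True)
def scaled_sum(x):
    total = 0.0
    for i in prange(x.shape[0]):
        total += 2.0 * x[i]
    return total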

Dependencies

The Numba backend requires the numba package:

pip install numba
# or
conda install numba

If Numba is not installed and solver_protocol="numba" is requested, a ParameterError is raised with installation instructions.
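
A minimal sketch of the kind of import guard this implies (the exact wording and its placement in pymbar/mbar_solvers.py may differ, and the helper name _require_numba is made up):

from pymbar.utils import ParameterError

try:
    import numba  # noqa: F401
    HAVE_NUMBA = True
except ImportError:
    HAVE_NUMBA = False

def _require_numba():
    # Called when solver_protocol="numba" is requested.
    if not HAVE_NUMBA:
        raise ParameterError(
            'solver_protocol="numba" requires the numba package; '
            "install it with `pip install numba` or `conda install numba`."
        )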

Testing

All existing tests pass. A new test, test_mbar_numba_solver_matches_default, verifies that the Numba solver produces results matching the default solver within numerical tolerance.
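
As an illustration of what such a test can look like (a sketch, not the PR's actual test; it uses pymbar's harmonic-oscillator test system, and the tolerance mirrors the |Δf_k| < 1e-3 figure above):

import numpy as np
from pymbar import MBAR
from pymbar.testsystems import harmonic_oscillators

def test_mbar_numba_solver_matches_default():
    # Three overlapping harmonic-oscillator states with analytical free energies
    testcase = harmonic_oscillators.HarmonicOscillatorsTestCase(
        O_k=[0.0, 1.0, 2.0], K_k=[1.0, 1.0, 1.0]
    )
    x_n, u_kn, N_k, s_n = testcase.sample(N_k=[500, 600, 700], mode="u_kn")
    mbar_default = MBAR(u_kn, N_k)
    mbar_numba = MBAR(u_kn, N_k, solver_protocol="numba")
    np.testing.assert_allclose(mbar_numba.f_k, mbar_default.f_k, atol=1e-3)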

@xiki-tempula xiki-tempula marked this pull request as ready for review January 27, 2026 14:16
@xiki-tempula
Author

@msuruzhon

@mrshirts
Collaborator

mrshirts commented Jan 27, 2026

I think the biggest issue we've had with any acceleration of pymbar is installation problems. @mikemhenry @Lnaden, any thoughts about difficulties installing with conda, especially around differences between GPU and CPU installs?

@mrshirts mrshirts requested review from Lnaden and mikemhenry January 27, 2026 19:35
@xiki-tempula
Author

Numba is a fairly standard NumPy acceleration package. The conda package supports:

  • win-64
  • osx-arm64
  • osx-64
  • linux-aarch64
  • linux-ppc64le
  • linux-64

So pretty much any architecture can use it.

@mrshirts
Collaborator

@xiki-tempula I understand it SHOULD work easily. The question is whether it DOES work easily. Since there is a dependence on JIT compilation, there are some questions about how the conda installation should work. It would be good to get insights from @mikemhenry and/or @Lnaden.

@ijpulidos
Contributor

I think this is definitely interesting to have. It would be nice to have access to the data you used for benchmarking and to verify that it indeed delivers these speedups.

Also, we probably would benefit from modularizing the solvers and backends themselves, instead of having them in the same mbar_solvers.py module.

I don't foresee any big issues with it, other than that we would have some new code to support, but the code should be pretty stable and easy to work with (maybe we can discuss that further in a more detailed review of the changes). Just my two cents.

@mrshirts
Collaborator

Also, we probably would benefit from modularizing the solvers and backends themselves, instead of having them in the same mbar_solvers.py module.

@ijpulidos I may have some time this semester - can you file an issue on this with some of your ideas, even if half-baked?

@xiki-tempula
Author

Here is a synthetic benchmark comparing the Numba and NumPy solvers; the script is followed by its output from the same machine:

import time
import numpy as np
from pymbar import MBAR

# Generate synthetic test data
np.random.seed(42)
K = 96  # number of states
N = 5000  # samples per state

print('='*70)
print('BENCHMARK: Synthetic Data (96 states, 5000 samples/state)')
print('='*70)

print('\nGenerating synthetic u_kn...', flush=True)
# Harmonic-oscillator-like potential energies:
# u_kn[k, n] = 0.5 * (x_n - k)^2, where x_n was drawn from the state that generated it
N_k = np.array([N] * K)
total_samples = np.sum(N_k)
u_kn = np.zeros((K, total_samples))

# Generate samples and evaluate each one in every state (vectorized over states)
k_centers = np.arange(K)
idx = 0
for k_orig in range(K):
    # Samples from state k_orig, centered at k_orig
    x_samples = np.random.normal(loc=k_orig, scale=1.0, size=N)
    u_kn[:, idx:idx + N] = 0.5 * (x_samples[None, :] - k_centers[:, None]) ** 2
    idx += N

print(f'  {K} states, {N} samples/state, {total_samples} total samples')
print(f'  u_kn shape: {u_kn.shape}')

NUMPY_PROTOCOL = ({'method': 'L-BFGS-B', 'continuation': True, 'options': {'maxiter': 10000}},)

# Numba benchmark
print('\n[Numba] warmup...', flush=True)
t0 = time.perf_counter()
mbar_numba = MBAR(u_kn, N_k, solver_protocol='numba')
t_numba_warmup = time.perf_counter() - t0
print(f'  Warmup: {t_numba_warmup:.2f}s')

print('[Numba] timed run...', flush=True)
t0 = time.perf_counter()
mbar_numba = MBAR(u_kn, N_k, solver_protocol='numba')
t_numba = time.perf_counter() - t0
print(f'  Time: {t_numba:.2f}s')

# NumPy benchmark
print('\n[NumPy] warmup...', flush=True)
t0 = time.perf_counter()
mbar_numpy = MBAR(u_kn, N_k, solver_protocol=NUMPY_PROTOCOL)
t_numpy_warmup = time.perf_counter() - t0
print(f'  Warmup: {t_numpy_warmup:.2f}s')

print('[NumPy] timed run...', flush=True)
t0 = time.perf_counter()
mbar_numpy = MBAR(u_kn, N_k, solver_protocol=NUMPY_PROTOCOL)
t_numpy = time.perf_counter() - t0
print(f'  Time: {t_numpy:.2f}s')

print('\n' + '='*70)
print('SUMMARY')
print('='*70)
print(f'Numba: {t_numba:.2f}s')
print(f'NumPy: {t_numpy:.2f}s')
print(f'Speedup (Numba vs NumPy): {t_numpy/t_numba:.1f}x')
print(f'\nVerification: f_k[-1] Numba={mbar_numba.f_k[-1]:.6f}, NumPy={mbar_numpy.f_k[-1]:.6f}')

======================================================================
BENCHMARK: Synthetic Data (96 states, 5000 samples/state)
======================================================================

Generating synthetic u_kn...
  96 states, 5000 samples/state, 480000 total samples
  u_kn shape: (96, 480000)

[Numba] warmup...
  Warmup: 13.43s
[Numba] timed run...
  Time: 7.00s

[NumPy] warmup...
  Warmup: 28.14s
[NumPy] timed run...
  Time: 31.96s

======================================================================
SUMMARY
======================================================================
Numba: 7.00s
NumPy: 31.96s
Speedup (Numba vs NumPy): 4.6x

@mikemhenry
Contributor

Just got back from vacation, will check this out tomorrow!
