Skip to content

Commit a9ca303

Browse files
Improve hessian matrix numerical precision implementing function using jax
1 parent 0466258 commit a9ca303

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
numpy
22
scipy
3+
jax
34
matplotlib
45
dwave-ocean-sdk
56
scikit-learn

src/qunfold/utils.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import sys
2+
import jax
23
import numpy as np
34
import scipy as sp
45
from tqdm import tqdm
5-
from scipy.optimize import approx_fprime
66
from qunfold import QUnfolder
77

88
try:
@@ -28,13 +28,17 @@ def compute_chi2(observed, expected):
2828
return chi2_red
2929

3030

31-
def approx_hessian(func, *point):
32-
precision = np.sqrt(np.finfo(dtype=np.float32).eps)
33-
epsilon = precision * np.array([max(1, x) for x in point])
34-
function = lambda point: func(*point)
35-
gradient = lambda point: approx_fprime(xk=point, f=function, epsilon=epsilon)
36-
xk = np.array([x for x in point])
37-
return approx_fprime(xk=xk, f=gradient, epsilon=epsilon)
31+
def approx_hessian(f, x):
32+
x = jax.numpy.array(x, dtype=float)
33+
n = len(x)
34+
hessian = jax.numpy.zeros(shape=(n, n))
35+
grad_f = jax.grad(f)
36+
for i in range(n):
37+
v = jax.numpy.zeros_like(x)
38+
v = v.at[i].set(1.0)
39+
_, hvp = jax.jvp(grad_f, primals=(x,), tangents=(v,))
40+
hessian = hessian.at[:, i].set(hvp)
41+
return hessian
3842

3943

4044
def lambda_optimizer(response, measured, truth, binning, num_reps=30, verbose=False, seed=None):

tests/test_utils.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,26 @@
55
from qunfold.utils import approx_hessian
66

77

8-
deg = st.integers(min_value=0, max_value=5)
8+
deg = st.integers(min_value=0, max_value=4)
99
xi = st.integers(min_value=0, max_value=1e6)
1010

1111

1212
@settings(deadline=None)
1313
@given(degrees=st.tuples(deg, deg, deg), point=st.tuples(xi, xi, xi))
1414
def test_approx_hessian(degrees, point):
1515
varlist = [sympy.Symbol(f"x{i}") for i in range(len(degrees))]
16-
poly_sympy = 0
16+
poly_sympy = 0.0
1717
for deg, var in zip(degrees, varlist):
18-
for exp in range(deg):
19-
if np.random.rand() < 0.3:
20-
coeff = np.random.rand() * 2 - 1
21-
poly_sympy += coeff * var**exp
18+
for exp in range(deg + 1):
19+
coeff = np.random.rand() * 5 - 2
20+
poly_sympy += coeff * var**exp
2221
n = len(varlist)
2322
hess_sympy = sympy.Matrix(np.zeros(shape=(n, n)))
2423
for i in range(n):
2524
for j in range(n):
2625
hess_sympy[i, j] = sympy.diff(poly_sympy, varlist[i], varlist[j])
2726
func = sympy.lambdify(args=varlist, expr=poly_sympy, modules="numpy")
2827
hess = sympy.lambdify(args=varlist, expr=hess_sympy, modules="numpy")
29-
hessian1 = approx_hessian(func, *point)
28+
hessian1 = approx_hessian(f=lambda x: func(*x), x=point)
3029
hessian2 = hess(*point)
31-
assert np.allclose(hessian1, hessian2, rtol=0.01, atol=1)
30+
assert np.allclose(hessian1, hessian2)

0 commit comments

Comments
 (0)