|
5 | 5 | from qunfold.utils import approx_hessian
|
6 | 6 |
|
7 | 7 |
|
8 |
| -deg = st.integers(min_value=0, max_value=5) |
| 8 | +deg = st.integers(min_value=0, max_value=4) |
9 | 9 | xi = st.integers(min_value=0, max_value=1e6)
|
10 | 10 |
|
11 | 11 |
|
12 | 12 | @settings(deadline=None)
|
13 | 13 | @given(degrees=st.tuples(deg, deg, deg), point=st.tuples(xi, xi, xi))
|
14 | 14 | def test_approx_hessian(degrees, point):
|
15 | 15 | varlist = [sympy.Symbol(f"x{i}") for i in range(len(degrees))]
|
16 |
| - poly_sympy = 0 |
| 16 | + poly_sympy = 0.0 |
17 | 17 | for deg, var in zip(degrees, varlist):
|
18 |
| - for exp in range(deg): |
19 |
| - if np.random.rand() < 0.3: |
20 |
| - coeff = np.random.rand() * 2 - 1 |
21 |
| - poly_sympy += coeff * var**exp |
| 18 | + for exp in range(deg + 1): |
| 19 | + coeff = np.random.rand() * 5 - 2 |
| 20 | + poly_sympy += coeff * var**exp |
22 | 21 | n = len(varlist)
|
23 | 22 | hess_sympy = sympy.Matrix(np.zeros(shape=(n, n)))
|
24 | 23 | for i in range(n):
|
25 | 24 | for j in range(n):
|
26 | 25 | hess_sympy[i, j] = sympy.diff(poly_sympy, varlist[i], varlist[j])
|
27 | 26 | func = sympy.lambdify(args=varlist, expr=poly_sympy, modules="numpy")
|
28 | 27 | hess = sympy.lambdify(args=varlist, expr=hess_sympy, modules="numpy")
|
29 |
| - hessian1 = approx_hessian(func, *point) |
| 28 | + hessian1 = approx_hessian(f=lambda x: func(*x), x=point) |
30 | 29 | hessian2 = hess(*point)
|
31 |
| - assert np.allclose(hessian1, hessian2, rtol=0.01, atol=1) |
| 30 | + assert np.allclose(hessian1, hessian2) |
0 commit comments