Skip to content

Commit dce0117

Browse files
committed
test: added test for _datacopied
1 parent fc8f698 commit dce0117

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tests/test_for_utils/test_banded_linalg.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,78 @@
66

77
### Imports ###
88

9+
from typing import List, Union
10+
911
import numpy as np
1012
import pytest
1113
from scipy.linalg import solve_banded as scipy_solve_banded
1214

1315
from chemotools.utils.banded_linalg import (
16+
_datacopied,
1417
conv_upper_chol_banded_to_lu_banded_storage,
1518
lu_banded,
1619
lu_solve_banded,
1720
slogdet_lu_banded,
1821
)
1922
from tests.test_for_utils.utils_funcs import get_banded_slogdet
2023

24+
### Constants ###
25+
26+
_ARRAY_TO_VIEW: np.ndarray = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
27+
_VIEW = _ARRAY_TO_VIEW[::]
28+
2129
### Test Suite ###
2230

2331

32+
@pytest.mark.parametrize(
33+
"arr, original, expected",
34+
[
35+
( # Number 0 Different arrays
36+
np.array([1, 2, 3]),
37+
np.array([1, 2, 3]),
38+
True,
39+
),
40+
( # Number 1 Array and list
41+
np.array([1, 2, 3]),
42+
[1, 2, 3],
43+
True,
44+
),
45+
( # Number 2 Different data types
46+
np.array([1, 2, 3]),
47+
np.array([1, 2, 3], dtype=np.float64),
48+
True,
49+
),
50+
( # Number 3 Different view and array
51+
_ARRAY_TO_VIEW[0:3],
52+
np.array([1, 2, 3]),
53+
False,
54+
),
55+
( # Number 4 Same array
56+
_ARRAY_TO_VIEW,
57+
_ARRAY_TO_VIEW,
58+
False,
59+
),
60+
( # Number 5 Same view and array
61+
_VIEW,
62+
_ARRAY_TO_VIEW,
63+
False,
64+
),
65+
],
66+
)
67+
def test_datacopied(
68+
arr: np.ndarray,
69+
original: Union[np.ndarray, List],
70+
expected: bool,
71+
) -> None:
72+
"""
73+
Tests the function that checks if a NumPy array has been copied from another array
74+
or list.
75+
76+
"""
77+
78+
assert _datacopied(arr, original) == expected
79+
80+
2481
@pytest.mark.parametrize("with_finite_check", [True, False])
2582
@pytest.mark.parametrize("overwrite_b", [True, False])
2683
@pytest.mark.parametrize("n_rhs", [0, 1, 2])

0 commit comments

Comments
 (0)