From dce0117ed8b2f7a324c319432a1a32771033aaff Mon Sep 17 00:00:00 2001 From: MothNik Date: Mon, 20 May 2024 00:35:29 +0200 Subject: [PATCH] test: added test for `_datacopied` --- tests/test_for_utils/test_banded_linalg.py | 57 ++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/test_for_utils/test_banded_linalg.py b/tests/test_for_utils/test_banded_linalg.py index d3ece0a..ba40498 100644 --- a/tests/test_for_utils/test_banded_linalg.py +++ b/tests/test_for_utils/test_banded_linalg.py @@ -6,11 +6,14 @@ ### Imports ### +from typing import List, Union + import numpy as np import pytest from scipy.linalg import solve_banded as scipy_solve_banded from chemotools.utils.banded_linalg import ( + _datacopied, conv_upper_chol_banded_to_lu_banded_storage, lu_banded, lu_solve_banded, @@ -18,9 +21,63 @@ ) from tests.test_for_utils.utils_funcs import get_banded_slogdet +### Constants ### + +_ARRAY_TO_VIEW: np.ndarray = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +_VIEW = _ARRAY_TO_VIEW[::] + ### Test Suite ### +@pytest.mark.parametrize( + "arr, original, expected", + [ + ( # Number 0 Different arrays + np.array([1, 2, 3]), + np.array([1, 2, 3]), + True, + ), + ( # Number 1 Array and list + np.array([1, 2, 3]), + [1, 2, 3], + True, + ), + ( # Number 2 Different data types + np.array([1, 2, 3]), + np.array([1, 2, 3], dtype=np.float64), + True, + ), + ( # Number 3 Different view and array + _ARRAY_TO_VIEW[0:3], + np.array([1, 2, 3]), + False, + ), + ( # Number 4 Same array + _ARRAY_TO_VIEW, + _ARRAY_TO_VIEW, + False, + ), + ( # Number 5 Same view and array + _VIEW, + _ARRAY_TO_VIEW, + False, + ), + ], +) +def test_datacopied( + arr: np.ndarray, + original: Union[np.ndarray, List], + expected: bool, +) -> None: + """ + Tests the function that checks if a NumPy array has been copied from another array + or list. + + """ + + assert _datacopied(arr, original) == expected + + @pytest.mark.parametrize("with_finite_check", [True, False]) @pytest.mark.parametrize("overwrite_b", [True, False]) @pytest.mark.parametrize("n_rhs", [0, 1, 2])