diff --git a/python/cuml/cuml/testing/plugins/__init__.py b/python/cuml/cuml/testing/plugins/__init__.py deleted file mode 100644 index de13692a8a..0000000000 --- a/python/cuml/cuml/testing/plugins/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Copyright (c) 2020-2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/python/cuml/cuml/testing/plugins/memory_profiler.py b/python/cuml/cuml/testing/plugins/memory_profiler.py new file mode 100644 index 0000000000..c488b3b349 --- /dev/null +++ b/python/cuml/cuml/testing/plugins/memory_profiler.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import warnings + +import pytest +from rmm.statistics import get_statistics, statistics + + +class HighMemoryUsageWarning(UserWarning): + """Warning emitted when a test exceeds the memory usage threshold.""" + + pass + + +# Memory threshold in MB for reporting memory usage +MEMORY_REPORT_THRESHOLD_MB = 1024 + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_call(item): + """Wrap test execution with GPU memory profiler.""" + with statistics(): + yield + + # Check memory usage after test completion + stats = get_statistics() + peak_memory_mb = stats.peak_bytes / (1024 * 1024) + + if peak_memory_mb > MEMORY_REPORT_THRESHOLD_MB: + msg = ( + f"Test {item.nodeid} used {peak_memory_mb:.2f} MB of GPU memory, " + f"exceeding threshold of {MEMORY_REPORT_THRESHOLD_MB} MB" + ) + warnings.warn(msg, HighMemoryUsageWarning) diff --git a/python/cuml/cuml/tests/conftest.py b/python/cuml/cuml/tests/conftest.py index a9cbb38742..92fb3e1e5a 100644 --- a/python/cuml/cuml/tests/conftest.py +++ b/python/cuml/cuml/tests/conftest.py @@ -42,7 +42,10 @@ # ============================================================================= # Add the import here for any plugins that should be loaded EVERY TIME -pytest_plugins = "cuml.testing.plugins.quick_run_plugin" +pytest_plugins = [ + "cuml.testing.plugins.quick_run_plugin", + "cuml.testing.plugins.memory_profiler", +] def pytest_sessionstart(session): diff --git a/python/cuml/cuml/tests/test_linear_model.py b/python/cuml/cuml/tests/test_linear_model.py index 9457beaeb4..e404d4e4db 100644 --- a/python/cuml/cuml/tests/test_linear_model.py +++ b/python/cuml/cuml/tests/test_linear_model.py @@ -67,12 +67,10 @@ csr_matrix = cpu_only_import_from("scipy.sparse", "csr_matrix") -_ALGORITHMS = ["svd", "eig", "qr", "svd-qr", "svd-jacobi"] +ALGORITHMS = ["svd", "eig", "qr", "svd-qr", "svd-jacobi"] -algorithms = st.sampled_from(_ALGORITHMS) - -# TODO(24.08): remove this test +# TODO(25.08): remove this test def 
test_logreg_penalty_deprecation(): with pytest.warns( FutureWarning, @@ -523,13 +521,13 @@ def test_logistic_regression( assert np.array_equal(culog.intercept_, sklog.intercept_) +@pytest.mark.parametrize("penalty", [None, "l1", "l2", "elasticnet"]) @given( dtype=floating_dtypes(sizes=(32, 64)), - penalty=st.sampled_from((None, "l1", "l2", "elasticnet")), l1_ratio=st.one_of(st.none(), st.floats(min_value=0.0, max_value=1.0)), ) -@example(dtype=np.float32, penalty=None, l1_ratio=None) -@example(dtype=np.float64, penalty=None, l1_ratio=None) +@example(dtype=np.float32, l1_ratio=None) +@example(dtype=np.float64, l1_ratio=None) def test_logistic_regression_unscaled(dtype, penalty, l1_ratio): if penalty == "elasticnet": assume(l1_ratio is not None) @@ -579,41 +577,13 @@ def test_logistic_regression_model_default(dtype): assert culog.score(X_test, y_test) >= sklog.score(X_test, y_test) - 0.022 -@given( - dtype=st.sampled_from((np.float32, np.float64)), - order=st.sampled_from(("C", "F")), - sparse_input=st.booleans(), - fit_intercept=st.booleans(), - penalty=st.sampled_from((None, "l1", "l2")), -) -@example( - dtype=np.float32, - order="C", - sparse_input=False, - fit_intercept=True, - penalty=None, -) -@example( - dtype=np.float64, - order="C", - sparse_input=False, - fit_intercept=True, - penalty=None, -) -@example( - dtype=np.float32, - order="F", - sparse_input=False, - fit_intercept=True, - penalty=None, -) -@example( - dtype=np.float64, - order="F", - sparse_input=False, - fit_intercept=True, - penalty=None, -) +@pytest.mark.parametrize("order", ["C", "F"]) +@pytest.mark.parametrize("sparse_input", [True, False]) +@pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize("penalty", [None, "l1", "l2"]) +@given(dtype=floating_dtypes(sizes=(32, 64))) +@example(dtype=np.float32) +@example(dtype=np.float64) def test_logistic_regression_model_digits( dtype, order, sparse_input, fit_intercept, penalty ): @@ -640,8 +610,7 @@ def 
test_logistic_regression_model_digits( assert score >= acceptable_score -@given(dtype=st.sampled_from((np.float32, np.float64))) -@example(dtype=np.float32) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_logistic_regression_sparse_only(dtype, nlp_20news): # sklearn score with max_iter = 10000 @@ -662,6 +631,8 @@ def test_logistic_regression_sparse_only(dtype, nlp_20news): assert score >= acceptable_score +@pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize("sparse_input", [True, False]) @given( dataset=split_datasets( standard_classification_datasets( @@ -671,28 +642,12 @@ def test_logistic_regression_sparse_only(dtype, nlp_20news): n_informative=st.just(10), ) ), - fit_intercept=st.booleans(), - sparse_input=st.booleans(), ) @example( dataset=small_classification_dataset(np.float32), - fit_intercept=True, - sparse_input=False, -) -@example( - dataset=small_classification_dataset(np.float32), - fit_intercept=False, - sparse_input=True, -) -@example( - dataset=small_classification_dataset(np.float64), - fit_intercept=True, - sparse_input=False, ) @example( dataset=small_classification_dataset(np.float64), - fit_intercept=False, - sparse_input=True, ) def test_logistic_regression_decision_function( dataset, fit_intercept, sparse_input @@ -1105,21 +1060,17 @@ def test_elasticnet_solvers_eq(datatype, alpha, l1_ratio, nrows, column_info): assert np.corrcoef(cd.coef_, qn.coef_)[0, 1] > 0.98 +@pytest.mark.parametrize("algorithm", ALGORITHMS) +@pytest.mark.parametrize("xp", [np, cp]) +@pytest.mark.parametrize("copy", [True, False, ...]) @given( dataset=standard_regression_datasets( n_features=st.integers(min_value=1, max_value=10), - dtypes=floating_dtypes(sizes=(32, 64)), - ), - algorithm=algorithms, - xp=st.sampled_from([np, cp]), - copy=st.sampled_from((True, False, ...)), + dtypes=st.sampled_from((np.float32, np.float64)), + ) ) -@example(make_regression(n_features=1), "svd", cp, True) 
-@example(make_regression(n_features=1), "svd", cp, False) -@example(make_regression(n_features=1), "svd", cp, ...) -@example(make_regression(n_features=1), "svd", np, False) -@example(make_regression(n_features=2), "svd", cp, False) -@example(make_regression(n_features=2), "eig", np, False) +@example(dataset=make_regression(n_features=1)) +@example(dataset=make_regression(n_features=2)) def test_linear_regression_input_copy(dataset, algorithm, xp, copy): X, y = dataset X, y = xp.asarray(X), xp.asarray(y) diff --git a/wiki/python/DEVELOPER_GUIDE.md b/wiki/python/DEVELOPER_GUIDE.md index b10025c8b2..1b007e5053 100644 --- a/wiki/python/DEVELOPER_GUIDE.md +++ b/wiki/python/DEVELOPER_GUIDE.md @@ -47,31 +47,110 @@ The examples in the documentation are checked through doctest. To skip the check Examples subject to numerical imprecision, or that can't be reproduced consistently should be skipped. ## Testing and Unit Testing -We use [https://docs.pytest.org/en/latest/]() for writing and running tests. To see existing examples, refer to any of the `test_*.py` files in the folder `cuml/tests`. - -Some tests are run against inputs generated with [hypothesis](https://hypothesis.works/). See the `cuml/testing/strategies.py` module for custom strategies that can be used to test cuml estimators with diverse inputs. For example, use the `regression_datasets()` strategy to test random regression problems. - -When using hypothesis for testing, you must include at least one explicit example using the `@example` decorator alongside any `@given` strategies. This ensures that: -1. Every test has at least one deterministic test case that always runs -2. Critical edge cases are documented and tested consistently -3. Test failures can be reproduced reliably - -Note: While the explicit examples will always run in CI, the hypothesis-generated test cases (from `@given` strategies) only run during nightly testing by default. 
This ensures fast CI runs while still maintaining thorough testing coverage. - -Example of a valid hypothesis test: -```python -@example(dtype=np.float32, sparse_input=False) # baseline case, runs as part of PR CI -@example(dtype=np.float64, sparse_input=True) # edge case, runs as part of PR CI -@given( - dtype=st.sampled_from((np.float32, np.float64)), - sparse_input=st.booleans() -) # strategy-based cases, only runs during nightly tests -def test_my_estimator(dtype, sparse_input): - # Test implementation - pass +We use [pytest](https://docs.pytest.org/en/latest/) for writing and running tests. To see existing examples, refer to any of the `test_*.py` files in the folder `cuml/tests`. + +### Test Organization +- Keep all tests for a single estimator in one file, with exceptions for: + - Performance testing/benchmarking + - Generic estimator checks (e.g., `test_base.py`) +- Use small, focused datasets for correctness testing +- Only parametrize scale when it triggers alternate code paths + +### Test Input Generation +We support three main approaches for test input generation: + +1. **Fixtures** (`@pytest.fixture`): + - For shared setup/teardown code and resources + - Examples: random seeds, clients, loading test datasets + ```python + @pytest.fixture(scope="module") + def random_state(): + return 42 + ``` + +2. **Parametrization** (`@pytest.mark.parametrize`): + - For testing specific input combinations + - Good for hyperparameters and configurations + ```python + @pytest.mark.parametrize("solver", ["svd", "eig"]) + def test_estimator(solver): + pass + ``` + +3. 
**Hypothesis** (`@given`): + - For property-based testing with random inputs + - Must include at least one `@example` for deterministic testing + - Preferred for dataset generation + ```python + @example(dataset=small_regression_dataset(np.float32)) + @given(dataset=standard_regression_datasets()) + def test_estimator(dataset): + pass + ``` + +### Test Parameter Levels + +You can mark test parameters for different scales with `unit_param`, `quality_param`, and `stress_param`. + +_Note: For dataset scaling, prefer using hypothesis, e.g. with `standard_regression_datasets()`._ + +We provide three test parameter levels: + +1. **Unit Tests** (`unit_param`): Small values for quick, basic functionality testing + ```python + unit_param(2) # For number of components + ``` + +2. **Quality Tests** (`quality_param`): Medium values for thorough testing + ```python + quality_param(10) # For number of components + ``` + +3. **Stress Tests** (`stress_param`): Large values for performance testing + ```python + stress_param(100) # For number of components + ``` + +Select which parameter levels run with these pytest options: +- `--run_unit`: Unit tests (default) +- `--run_quality`: Quality tests +- `--run_stress`: Stress tests + +### Testing Guidelines + +1. **Accuracy Testing** + - Compare against recorded reference values when possible + - Document origin of reference values + - Use appropriate quality metrics for equivalent but different results + - Ensure reproducibility rather than using retry logic + +2. **Minimize resources** + - Use minimal dataset sizes + - Only test different scales if they would actually hit different code paths + +3. **Best Practices** + - Write small, focused tests + - Avoid duplication between test files + - Choose appropriate input generation method + - Make tests reproducible + - Document test assumptions and requirements + +### Running Tests +Tests must be run from the `python/cuml/` directory or one of its subdirectories. First build the package, then execute tests. 
+ +```bash +./build.sh +cd python/cuml/ +pytest # Run all tests ``` -The test collection will fail if any test uses `@given` without an accompanying `@example`. +Common options: +- `pytest cuml/tests/test_kmeans.py` - Run specific file +- `pytest -k "test_kmeans"` - Run tests matching pattern +- `pytest --run_unit` - Run only unit tests +- `pytest -v` - Verbose output + +Running pytest from outside the `python/cuml/` directory will result in import errors. ## Device and Host memory allocations TODO: talk about enabling RMM here when it is ready