15 changes: 0 additions & 15 deletions python/cuml/cuml/testing/plugins/__init__.py

This file was deleted.

47 changes: 47 additions & 0 deletions python/cuml/cuml/testing/plugins/memory_profiler.py
@@ -0,0 +1,47 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import warnings

import pytest
from rmm.statistics import get_statistics, statistics


class HighMemoryUsageWarning(UserWarning):
    """Warning emitted when a test exceeds the memory usage threshold."""

    pass


# Memory threshold in MB for reporting memory usage
MEMORY_REPORT_THRESHOLD_MB = 1024


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Wrap test execution with GPU memory profiler."""
    with statistics():
        yield

        # Check memory usage after test completion
        stats = get_statistics()
        peak_memory_mb = stats.peak_bytes / (1024 * 1024)

        if peak_memory_mb > MEMORY_REPORT_THRESHOLD_MB:
            msg = (
                f"Test {item.nodeid} used {peak_memory_mb:.2f} MB of GPU memory, "
                f"exceeding threshold of {MEMORY_REPORT_THRESHOLD_MB} MB"
            )
            warnings.warn(msg, HighMemoryUsageWarning)
5 changes: 4 additions & 1 deletion python/cuml/cuml/tests/conftest.py
@@ -42,7 +42,10 @@
# =============================================================================

# Add the import here for any plugins that should be loaded EVERY TIME
pytest_plugins = "cuml.testing.plugins.quick_run_plugin"
pytest_plugins = [
    "cuml.testing.plugins.quick_run_plugin",
    "cuml.testing.plugins.memory_profiler",
]


def pytest_sessionstart(session):
93 changes: 22 additions & 71 deletions python/cuml/cuml/tests/test_linear_model.py
@@ -67,12 +67,10 @@
csr_matrix = cpu_only_import_from("scipy.sparse", "csr_matrix")


_ALGORITHMS = ["svd", "eig", "qr", "svd-qr", "svd-jacobi"]
ALGORITHMS = ["svd", "eig", "qr", "svd-qr", "svd-jacobi"]

algorithms = st.sampled_from(_ALGORITHMS)


# TODO(24.08): remove this test
# TODO(25.08): remove this test
def test_logreg_penalty_deprecation():
with pytest.warns(
FutureWarning,
@@ -523,13 +521,13 @@ def test_logistic_regression(
assert np.array_equal(culog.intercept_, sklog.intercept_)


@pytest.mark.parametrize("penalty", [None, "l1", "l2", "elasticnet"])
@given(
dtype=floating_dtypes(sizes=(32, 64)),
penalty=st.sampled_from((None, "l1", "l2", "elasticnet")),
l1_ratio=st.one_of(st.none(), st.floats(min_value=0.0, max_value=1.0)),
)
@example(dtype=np.float32, penalty=None, l1_ratio=None)
@example(dtype=np.float64, penalty=None, l1_ratio=None)
@example(dtype=np.float32, l1_ratio=None)
@example(dtype=np.float64, l1_ratio=None)
def test_logistic_regression_unscaled(dtype, penalty, l1_ratio):
if penalty == "elasticnet":
assume(l1_ratio is not None)
@@ -579,41 +577,13 @@ def test_logistic_regression_model_default(dtype):
assert culog.score(X_test, y_test) >= sklog.score(X_test, y_test) - 0.022


@given(
dtype=st.sampled_from((np.float32, np.float64)),
order=st.sampled_from(("C", "F")),
sparse_input=st.booleans(),
fit_intercept=st.booleans(),
penalty=st.sampled_from((None, "l1", "l2")),
)
@example(
dtype=np.float32,
order="C",
sparse_input=False,
fit_intercept=True,
penalty=None,
)
@example(
dtype=np.float64,
order="C",
sparse_input=False,
fit_intercept=True,
penalty=None,
)
@example(
dtype=np.float32,
order="F",
sparse_input=False,
fit_intercept=True,
penalty=None,
)
@example(
dtype=np.float64,
order="F",
sparse_input=False,
fit_intercept=True,
penalty=None,
)
@pytest.mark.parametrize("order", ["C", "F"])
@pytest.mark.parametrize("sparse_input", [True, False])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("penalty", [None, "l1", "l2"])
@given(dtype=floating_dtypes(sizes=(32, 64)))
@example(dtype=np.float32)
@example(dtype=np.float64)
def test_logistic_regression_model_digits(
dtype, order, sparse_input, fit_intercept, penalty
):
@@ -640,8 +610,7 @@ def test_logistic_regression_model_digits(
assert score >= acceptable_score


@given(dtype=st.sampled_from((np.float32, np.float64)))
@example(dtype=np.float32)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_logistic_regression_sparse_only(dtype, nlp_20news):

# sklearn score with max_iter = 10000
@@ -662,6 +631,8 @@ def test_logistic_regression_sparse_only(dtype, nlp_20news):
assert score >= acceptable_score


@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("sparse_input", [True, False])
@given(
dataset=split_datasets(
standard_classification_datasets(
@@ -671,28 +642,12 @@ def test_logistic_regression_sparse_only(dtype, nlp_20news):
n_informative=st.just(10),
)
),
fit_intercept=st.booleans(),
sparse_input=st.booleans(),
)
@example(
dataset=small_classification_dataset(np.float32),
fit_intercept=True,
sparse_input=False,
)
@example(
dataset=small_classification_dataset(np.float32),
fit_intercept=False,
sparse_input=True,
)
@example(
dataset=small_classification_dataset(np.float64),
fit_intercept=True,
sparse_input=False,
)
@example(
dataset=small_classification_dataset(np.float64),
fit_intercept=False,
sparse_input=True,
)
def test_logistic_regression_decision_function(
dataset, fit_intercept, sparse_input
@@ -1105,21 +1060,17 @@ def test_elasticnet_solvers_eq(datatype, alpha, l1_ratio, nrows, column_info):
assert np.corrcoef(cd.coef_, qn.coef_)[0, 1] > 0.98


@pytest.mark.parametrize("algorithm", ALGORITHMS)
@pytest.mark.parametrize("xp", [np, cp])
@pytest.mark.parametrize("copy", [True, False, ...])
@given(
dataset=standard_regression_datasets(
n_features=st.integers(min_value=1, max_value=10),
dtypes=floating_dtypes(sizes=(32, 64)),
),
algorithm=algorithms,
xp=st.sampled_from([np, cp]),
copy=st.sampled_from((True, False, ...)),
dtypes=st.sampled_from((np.float32, np.float64)),
)
)
@example(make_regression(n_features=1), "svd", cp, True)
@example(make_regression(n_features=1), "svd", cp, False)
@example(make_regression(n_features=1), "svd", cp, ...)
@example(make_regression(n_features=1), "svd", np, False)
@example(make_regression(n_features=2), "svd", cp, False)
@example(make_regression(n_features=2), "eig", np, False)
@example(dataset=make_regression(n_features=1))
@example(dataset=make_regression(n_features=2))
def test_linear_regression_input_copy(dataset, algorithm, xp, copy):
X, y = dataset
X, y = xp.asarray(X), xp.asarray(y)
125 changes: 102 additions & 23 deletions wiki/python/DEVELOPER_GUIDE.md
@@ -47,31 +47,110 @@ The examples in the documentation are checked through doctest. To skip the check
Examples subject to numerical imprecision, or that can't be reproduced consistently, should be skipped.

## Testing and Unit Testing
We use [https://docs.pytest.org/en/latest/]() for writing and running tests. To see existing examples, refer to any of the `test_*.py` files in the folder `cuml/tests`.

Some tests are run against inputs generated with [hypothesis](https://hypothesis.works/). See the `cuml/testing/strategies.py` module for custom strategies that can be used to test cuml estimators with diverse inputs. For example, use the `regression_datasets()` strategy to test random regression problems.

When using hypothesis for testing, you must include at least one explicit example using the `@example` decorator alongside any `@given` strategies. This ensures that:
1. Every test has at least one deterministic test case that always runs
2. Critical edge cases are documented and tested consistently
3. Test failures can be reproduced reliably

Note: While the explicit examples will always run in CI, the hypothesis-generated test cases (from `@given` strategies) only run during nightly testing by default. This ensures fast CI runs while still maintaining thorough testing coverage.

Example of a valid hypothesis test:
```python
@example(dtype=np.float32, sparse_input=False) # baseline case, runs as part of PR CI
@example(dtype=np.float64, sparse_input=True) # edge case, runs as part of PR CI
@given(
dtype=st.sampled_from((np.float32, np.float64)),
sparse_input=st.booleans()
) # strategy-based cases, only runs during nightly tests
def test_my_estimator(dtype, sparse_input):
# Test implementation
pass
We use [pytest](https://docs.pytest.org/en/latest/) for writing and running tests. To see existing examples, refer to any of the `test_*.py` files in the folder `cuml/tests`.

### Test Organization
- Keep all tests for a single estimator in one file, with exceptions for:
  - Performance testing/benchmarking
  - Generic estimator checks (e.g., `test_base.py`)
- Use small, focused datasets for correctness testing
- Only parametrize scale when it triggers alternate code paths

### Test Input Generation
We support three main approaches for test input generation:

1. **Fixtures** (`@pytest.fixture`):
   - For shared setup/teardown code and resources
   - Examples: random seeds, clients, loading test datasets
   ```python
   @pytest.fixture(scope="module")
   def random_state():
       return 42
   ```

2. **Parametrization** (`@pytest.mark.parametrize`):
   - For testing specific input combinations
   - Good for hyperparameters and configurations
   ```python
   @pytest.mark.parametrize("solver", ["svd", "eig"])
   def test_estimator(solver):
       pass
   ```

3. **Hypothesis** (`@given`):
   - For property-based testing with random inputs
   - Must include at least one `@example` for deterministic testing
   - Preferred for dataset generation
   ```python
   @example(dataset=small_regression_dataset(np.float32))
   @given(dataset=standard_regression_datasets())
   def test_estimator(dataset):
       pass
   ```
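
These approaches can be combined on a single test: `@pytest.mark.parametrize` expands configuration flags while `@given`/`@example` drive the data-dependent inputs. A minimal sketch with illustrative names, assuming `floating_dtypes` is imported from `hypothesis.extra.numpy`:

```python
import numpy as np
import pytest
from hypothesis import example, given
from hypothesis.extra.numpy import floating_dtypes


@pytest.mark.parametrize("fit_intercept", [True, False])
@given(dtype=floating_dtypes(sizes=(32, 64)))
@example(dtype=np.float32)
@example(dtype=np.float64)
def test_estimator_combined(dtype, fit_intercept):
    # pytest expands fit_intercept into separate tests; hypothesis
    # draws dtype, always including the two explicit examples
    pass
```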

### Test Parameter Levels

You can mark test parameters for different scales with `unit_param`, `quality_param`, and `stress_param`.

_Note: For dataset scaling, prefer using hypothesis, e.g. with `standard_regression_datasets()`._

We provide three test parameter levels:

1. **Unit Tests** (`unit_param`): Small values for quick, basic functionality testing
   ```python
   unit_param(2)  # For number of components
   ```

2. **Quality Tests** (`quality_param`): Medium values for thorough testing
   ```python
   quality_param(10)  # For number of components
   ```

3. **Stress Tests** (`stress_param`): Large values for performance testing
   ```python
   stress_param(100)  # For number of components
   ```

Control which levels run via these pytest options:
- `--run_unit`: Unit tests (default)
- `--run_quality`: Quality tests
- `--run_stress`: Stress tests
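
For example, a single parametrization can carry all three levels. A minimal sketch, assuming the helpers are importable from `cuml.testing.utils`:

```python
import pytest

# Assumption: these helpers live in cuml.testing.utils
from cuml.testing.utils import quality_param, stress_param, unit_param


@pytest.mark.parametrize(
    "n_components",
    [unit_param(2), quality_param(10), stress_param(100)],
)
def test_components(n_components):
    # The unit value runs by default; the quality and stress values
    # require --run_quality and --run_stress respectively
    pass
```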

### Testing Guidelines

1. **Accuracy Testing**
   - Compare against recorded reference values when possible
   - Document the origin of reference values so they can be regenerated
   - When results are equivalent but not identical, compare with an appropriate quality metric rather than exact equality
   - Make tests reproducible (e.g., by fixing seeds) rather than relying on retry logic

2. **Minimize resources**
   - Use minimal dataset sizes
   - Only test different scales if they would actually hit different code paths

3. **Best Practices**
   - Write small, focused tests
   - Avoid duplication between test files
   - Choose the appropriate input generation method
   - Make tests reproducible
   - Document test assumptions and requirements
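
As a sketch of the accuracy guidance above (the reference value, its origin, and the fixture names are hypothetical placeholders):

```python
# Hypothetical reference: accuracy recorded from a CPU baseline on the
# same train/test split; document the origin so it can be regenerated.
REFERENCE_SCORE = 0.93


def test_model_accuracy(trained_model, test_split):
    # trained_model and test_split are placeholder fixtures
    X_test, y_test = test_split
    score = trained_model.score(X_test, y_test)
    # Fixed tolerance against a recorded value, not retry logic
    assert score >= REFERENCE_SCORE - 0.01
```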

### Running Tests
Tests must be run from the `python/cuml/` directory or one of its subdirectories. First build the package, then execute tests.

```bash
./build.sh
cd python/cuml/
pytest # Run all tests
```

Note that test collection will fail if any test uses `@given` without an accompanying `@example`.

Common options:
- `pytest cuml/tests/test_kmeans.py` - Run specific file
- `pytest -k "test_kmeans"` - Run tests matching pattern
- `pytest --run_unit` - Run only unit tests
- `pytest -v` - Verbose output

Running pytest from outside the `python/cuml/` directory will result in import errors.

## Device and Host memory allocations
TODO: talk about enabling RMM here when it is ready