Skip to content

Commit 56aeac7

Browse files
committed
Refactor memory profiler plugin for pytest
Updated the memory profiler plugin to utilize GPU memory statistics and issue warnings for high memory usage during test execution. Removed the previous memory tracking methods and streamlined the reporting mechanism to focus on peak GPU memory usage.
1 parent 55cd7ad commit 56aeac7

File tree

2 files changed

+22
-45
lines changed

2 files changed

+22
-45
lines changed

python/cuml/cuml/testing/plugins/__init__.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

python/cuml/cuml/testing/plugins/memory_profiler.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,34 @@
1313
# limitations under the License.
1414
#
1515

16-
import os
17-
import psutil
16+
import warnings
17+
from rmm.statistics import statistics, get_statistics
18+
import pytest
1819

19-
# Memory threshold in MB for reporting memory usage
20-
MEMORY_REPORT_THRESHOLD_MB = 1024
2120

21+
class HighMemoryUsageWarning(UserWarning):
22+
"""Warning emitted when a test exceeds the memory usage threshold."""
2223

23-
def get_process_memory():
24-
"""Get the current process memory usage in MB."""
25-
process = psutil.Process(os.getpid())
26-
return process.memory_info().rss / 1024 / 1024 # Convert to MB
24+
pass
2725

2826

29-
class MemoryProfiler:
30-
def __init__(self):
31-
self.start_memory = None
32-
self.max_memory = 0
27+
# Memory threshold in MB for reporting memory usage
28+
MEMORY_REPORT_THRESHOLD_MB = 1024
3329

34-
def pytest_runtest_setup(self, item):
35-
"""Record memory usage at test setup."""
36-
self.start_memory = get_process_memory()
3730

38-
def pytest_runtest_teardown(self, item):
39-
"""Record memory usage at test teardown and report if significant."""
40-
end_memory = get_process_memory()
41-
if self.start_memory is not None:
42-
memory_used = end_memory - self.start_memory
43-
self.max_memory = max(self.max_memory, end_memory)
44-
if memory_used > MEMORY_REPORT_THRESHOLD_MB:
45-
print(f"\nMemory usage for {item.nodeid}:")
46-
print(f" Start: {self.start_memory:.2f} MB")
47-
print(f" End: {end_memory:.2f} MB")
48-
print(f" Delta: {memory_used:.2f} MB")
49-
print(f" Max: {self.max_memory:.2f} MB")
31+
@pytest.hookimpl(hookwrapper=True)
32+
def pytest_runtest_call(item):
33+
"""Wrap test execution with GPU memory profiler."""
34+
with statistics():
35+
yield
5036

37+
# Check memory usage after test completion
38+
stats = get_statistics()
39+
peak_memory_mb = stats.peak_bytes / (1024 * 1024)
5140

52-
def pytest_configure(config):
53-
"""Register the memory profiler plugin."""
54-
config.pluginmanager.register(MemoryProfiler())
41+
if peak_memory_mb > MEMORY_REPORT_THRESHOLD_MB:
42+
msg = (
43+
f"Test {item.nodeid} used {peak_memory_mb:.2f} MB of GPU memory, "
44+
f"exceeding threshold of {MEMORY_REPORT_THRESHOLD_MB} MB"
45+
)
46+
warnings.warn(msg, HighMemoryUsageWarning)

0 commit comments

Comments
 (0)