|
13 | 13 | # limitations under the License. |
14 | 14 | # |
15 | 15 |
|
16 | | -import os |
17 | | -import psutil |
| 16 | +import warnings |
| 17 | +from rmm.statistics import statistics, get_statistics |
| 18 | +import pytest |
18 | 19 |
|
19 | | -# Memory threshold in MB for reporting memory usage |
20 | | -MEMORY_REPORT_THRESHOLD_MB = 1024 |
21 | 20 |
|
| 21 | +class HighMemoryUsageWarning(UserWarning): |
| 22 | + """Warning emitted when a test exceeds the memory usage threshold.""" |
22 | 23 |
|
23 | | -def get_process_memory(): |
24 | | - """Get the current process memory usage in MB.""" |
25 | | - process = psutil.Process(os.getpid()) |
26 | | - return process.memory_info().rss / 1024 / 1024 # Convert to MB |
| 24 | + pass |
27 | 25 |
|
28 | 26 |
|
29 | | -class MemoryProfiler: |
30 | | - def __init__(self): |
31 | | - self.start_memory = None |
32 | | - self.max_memory = 0 |
| 27 | +# Memory threshold in MB for reporting memory usage |
| 28 | +MEMORY_REPORT_THRESHOLD_MB = 1024 |
33 | 29 |
|
34 | | - def pytest_runtest_setup(self, item): |
35 | | - """Record memory usage at test setup.""" |
36 | | - self.start_memory = get_process_memory() |
37 | 30 |
|
38 | | - def pytest_runtest_teardown(self, item): |
39 | | - """Record memory usage at test teardown and report if significant.""" |
40 | | - end_memory = get_process_memory() |
41 | | - if self.start_memory is not None: |
42 | | - memory_used = end_memory - self.start_memory |
43 | | - self.max_memory = max(self.max_memory, end_memory) |
44 | | - if memory_used > MEMORY_REPORT_THRESHOLD_MB: |
45 | | - print(f"\nMemory usage for {item.nodeid}:") |
46 | | - print(f" Start: {self.start_memory:.2f} MB") |
47 | | - print(f" End: {end_memory:.2f} MB") |
48 | | - print(f" Delta: {memory_used:.2f} MB") |
49 | | - print(f" Max: {self.max_memory:.2f} MB") |
| 31 | +@pytest.hookimpl(hookwrapper=True) |
| 32 | +def pytest_runtest_call(item): |
| 33 | + """Wrap test execution with GPU memory profiler.""" |
| 34 | + with statistics(): |
| 35 | + yield |
50 | 36 |
|
| 37 | + # Check memory usage after test completion |
| 38 | + stats = get_statistics() |
| 39 | + peak_memory_mb = stats.peak_bytes / (1024 * 1024) |
51 | 40 |
|
52 | | -def pytest_configure(config): |
53 | | - """Register the memory profiler plugin.""" |
54 | | - config.pluginmanager.register(MemoryProfiler()) |
| 41 | + if peak_memory_mb > MEMORY_REPORT_THRESHOLD_MB: |
| 42 | + msg = ( |
| 43 | + f"Test {item.nodeid} used {peak_memory_mb:.2f} MB of GPU memory, " |
| 44 | + f"exceeding threshold of {MEMORY_REPORT_THRESHOLD_MB} MB" |
| 45 | + ) |
| 46 | + warnings.warn(msg, HighMemoryUsageWarning) |
0 commit comments