|
1 |
| -import ctypes |
2 |
| -import os |
3 | 1 | import unittest
|
4 | 2 | import zeroize
|
5 | 3 | import numpy as np
|
6 | 4 |
|
7 | 5 |
|
8 |
| -# Lock memory using ctypes |
9 |
| -def lock_memory(): |
10 |
| - libc = ctypes.CDLL("libc.so.6") |
11 |
| - # Lock all current and future pages from being swapped out |
12 |
| - libc.mlockall(ctypes.c_int(0x02 | 0x04)) # MCL_CURRENT | MCL_FUTURE |
| 6 | +import ctypes |
| 7 | + |
| 8 | + |
| 9 | +# Load the C standard library |
| 10 | +LIBC = ctypes.CDLL("libc.so.6") |
| 11 | +MLOCK = LIBC.mlock |
| 12 | +MUNLOCK = LIBC.munlock |
| 13 | + |
| 14 | +# Define mlock and munlock argument types |
| 15 | +MLOCK.argtypes = [ctypes.c_void_p, ctypes.c_size_t] |
| 16 | +MUNLOCK.argtypes = [ctypes.c_void_p, ctypes.c_size_t] |
| 17 | + |
| 18 | + |
| 19 | +def lock_memory(buffer): |
| 20 | + """Locks the memory of the given buffer.""" |
| 21 | + address = ctypes.addressof(ctypes.c_char.from_buffer(buffer)) |
| 22 | + size = len(buffer) |
| 23 | + if MLOCK(address, size) != 0: |
| 24 | + raise RuntimeError("Failed to lock memory") |
13 | 25 |
|
14 | 26 |
|
15 |
| -def unlock_memory(): |
16 |
| - libc = ctypes.CDLL("libc.so.6") |
17 |
| - # Unlock all locked pages |
18 |
| - libc.munlockall() |
| 27 | +def unlock_memory(buffer): |
| 28 | + """Unlocks the memory of the given buffer.""" |
| 29 | + address = ctypes.addressof(ctypes.c_char.from_buffer(buffer)) |
| 30 | + size = len(buffer) |
| 31 | + if MUNLOCK(address, size) != 0: |
| 32 | + raise RuntimeError("Failed to unlock memory") |
19 | 33 |
|
20 | 34 |
|
21 | 35 | SIZES_MB = [
|
@@ -43,44 +57,46 @@ def unlock_memory():
|
43 | 57 | class TestStringMethods(unittest.TestCase):
|
44 | 58 |
|
45 | 59 | def test_zeroize1(self):
|
46 |
| - lock_memory() |
47 |
| - |
48 |
| - arr = bytearray(b"1234567890") |
49 |
| - zeroize.zeroize1(arr) |
50 |
| - self.assertEqual(arr, bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")) |
| 60 | + try: |
| 61 | + arr = bytearray(b"1234567890") |
| 62 | + lock_memory(arr) |
| 63 | + zeroize.zeroize1(arr) |
| 64 | + self.assertEqual( |
| 65 | + arr, bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") |
| 66 | + ) |
51 | 67 |
|
52 |
| - unlock_memory() |
| 68 | + finally: |
| 69 | + unlock_memory(arr) |
53 | 70 |
|
54 | 71 | def test_zeroize_np(self):
|
55 |
| - lock_memory() |
| 72 | + try: |
| 73 | + arr = np.array([0] * 10, dtype=np.uint8) |
| 74 | + arr[:] = bytes(b"1234567890") |
| 75 | + zeroize.zeroize_np(arr) |
| 76 | + self.assertEqual(True, all(arr == 0)) |
56 | 77 |
|
57 |
| - arr = np.array([0] * 10, dtype=np.uint8) |
58 |
| - arr[:] = bytes(b"1234567890") |
59 |
| - zeroize.zeroize_np(arr) |
60 |
| - self.assertEqual(True, all(arr == 0)) |
61 |
| - |
62 |
| - unlock_memory() |
| 78 | + finally: |
| 79 | + unlock_memory(arr) |
63 | 80 |
|
64 | 81 | def test_zeroize1_sizes(self):
|
65 |
| - # lock_memory() |
66 |
| - |
67 | 82 | for size in SIZES_MB:
|
68 |
| - arr = bytearray(int(size * 1024 * 1024)) |
69 |
| - zeroize.zeroize1(arr) |
70 |
| - self.assertEqual(arr, bytearray(int(size * 1024 * 1024))) |
71 |
| - |
72 |
| - # unlock_memory() |
| 83 | + try: |
| 84 | + arr = bytearray(int(size * 1024 * 1024)) |
| 85 | + zeroize.zeroize1(arr) |
| 86 | + self.assertEqual(arr, bytearray(int(size * 1024 * 1024))) |
73 | 87 |
|
74 |
| - def test_zeroize_np_sizes(self): |
75 |
| - # lock_memory() |
| 88 | + finally: |
| 89 | + unlock_memory(arr) |
76 | 90 |
|
| 91 | + def test_zeroize_np_sizes(self): |
77 | 92 | for size in [size for size in SIZES_MB if size < 4]:
|
78 |
| - array_size = int(size * 1024 * 1024) |
79 |
| - random_array = np.random.randint(0, 256, array_size, dtype=np.uint8) |
80 |
| - zeroize.zeroize_np(random_array) |
81 |
| - self.assertEqual(True, all(random_array == 0)) |
82 |
| - |
83 |
| - # unlock_memory() |
| 93 | + try: |
| 94 | + array_size = int(size * 1024 * 1024) |
| 95 | + random_array = np.random.randint(0, 256, array_size, dtype=np.uint8) |
| 96 | + zeroize.zeroize_np(random_array) |
| 97 | + self.assertEqual(True, all(random_array == 0)) |
| 98 | + finally: |
| 99 | + unlock_memory(random_array) |
84 | 100 |
|
85 | 101 |
|
86 | 102 | if __name__ == "__main__":
|
|
0 commit comments