diff --git a/examples/lock_and_zeroize.py b/examples/lock_and_zeroize.py index 6302ad2..b224440 100644 --- a/examples/lock_and_zeroize.py +++ b/examples/lock_and_zeroize.py @@ -30,34 +30,35 @@ def unlock_memory(buffer): if MUNLOCK(address, size) != 0: raise RuntimeError("Failed to unlock memory") - -try: - print("allocate memory") - - # regular array - arr = bytearray(b"1234567890") - - # numpy array - arr_np = np.array([0] * 10, dtype=np.uint8) - arr_np[:] = arr - assert arr_np.tobytes() == b"1234567890" - - print("locking memory") - - lock_memory(arr) - lock_memory(arr_np) - - print("zeroize'ing...: ") - zeroize1(arr) - zeroize_np(arr_np) - - print("checking if is zeroized") - assert arr == bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - assert all(arr_np == 0) - - print("all good, bye!") -finally: - # Unlock the memory - print("unlocking memory") - unlock_memory(arr) - unlock_memory(arr_np) +if __name__ == "__main__": + try: + print("allocate memory") + + # regular array + arr = bytearray(b"1234567890") + + # numpy array + arr_np = np.array([0] * 10, dtype=np.uint8) + arr_np[:] = arr + assert arr_np.tobytes() == b"1234567890" + + print("locking memory") + + lock_memory(arr) + lock_memory(arr_np) + + print("zeroize'ing...: ") + zeroize1(arr) + zeroize_np(arr_np) + + print("checking if is zeroized") + assert arr == bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + assert all(arr_np == 0) + + print("all good, bye!") + + finally: + # Unlock the memory + print("unlocking memory") + unlock_memory(arr) + unlock_memory(arr_np) diff --git a/examples/zeroize_before_fork.py b/examples/zeroize_before_fork.py index ccfcdb4..cd9fb9d 100644 --- a/examples/zeroize_before_fork.py +++ b/examples/zeroize_before_fork.py @@ -31,24 +31,26 @@ def unlock_memory(buffer): raise RuntimeError("Failed to unlock memory") -try: - sensitive_data = bytearray(b"Sensitive Information") - lock_memory(sensitive_data) - - print("Before zeroization:", sensitive_data) - - zeroize1(sensitive_data) - print("After zeroization:", sensitive_data) - - # Forking after zeroization to ensure no sensitive data is copied - pid = os.fork() - if pid == 0: - # This is the child process - print("Child process memory after fork:", sensitive_data) - else: - # This is the parent process - os.wait() # Wait for the child process to exit -finally: - # Unlock the memory - print("unlocking memory") - unlock_memory(sensitive_data) +if __name__ == "__main__": + try: + sensitive_data = bytearray(b"Sensitive Information") + lock_memory(sensitive_data) + + print("Before zeroization:", sensitive_data) + + zeroize1(sensitive_data) + print("After zeroization:", sensitive_data) + + # Forking after zeroization to ensure no sensitive data is copied + pid = os.fork() + if pid == 0: + # This is the child process + print("Child process memory after fork:", sensitive_data) + else: + # This is the parent process + os.wait() # Wait for the child process to exit + + finally: + # Unlock the memory + print("unlocking memory") + unlock_memory(sensitive_data) diff --git a/tests/test_zeroize.py b/tests/test_zeroize.py index 59b5f26..e49137f 100644 --- a/tests/test_zeroize.py +++ b/tests/test_zeroize.py @@ -1,21 +1,35 @@ -import ctypes -import os import unittest import zeroize import numpy as np -# Lock memory using ctypes -def lock_memory(): - libc = ctypes.CDLL("libc.so.6") - # Lock all current and future pages from being swapped out - libc.mlockall(ctypes.c_int(0x02 | 0x04)) # MCL_CURRENT | MCL_FUTURE +import ctypes + + +# Load the C standard library +LIBC = ctypes.CDLL("libc.so.6") +MLOCK = LIBC.mlock +MUNLOCK = LIBC.munlock + +# Define mlock and munlock argument types +MLOCK.argtypes = [ctypes.c_void_p, ctypes.c_size_t] +MUNLOCK.argtypes = [ctypes.c_void_p, ctypes.c_size_t] + + +def lock_memory(buffer): + """Locks the memory of the given buffer.""" + address = ctypes.addressof(ctypes.c_char.from_buffer(buffer)) + size = len(buffer) + if MLOCK(address, size) != 0: + raise RuntimeError("Failed to lock memory") -def unlock_memory(): - libc = ctypes.CDLL("libc.so.6") - # Unlock all locked pages - libc.munlockall() +def unlock_memory(buffer): + """Unlocks the memory of the given buffer.""" + address = ctypes.addressof(ctypes.c_char.from_buffer(buffer)) + size = len(buffer) + if MUNLOCK(address, size) != 0: + raise RuntimeError("Failed to unlock memory") SIZES_MB = [ @@ -43,44 +57,46 @@ def unlock_memory(): class TestStringMethods(unittest.TestCase): def test_zeroize1(self): - lock_memory() - - arr = bytearray(b"1234567890") - zeroize.zeroize1(arr) - self.assertEqual(arr, bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")) + try: + arr = bytearray(b"1234567890") + lock_memory(arr) + zeroize.zeroize1(arr) + self.assertEqual( + arr, bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + ) - unlock_memory() + finally: + unlock_memory(arr) def test_zeroize_np(self): - lock_memory() + try: + arr = np.array([0] * 10, dtype=np.uint8) + arr[:] = bytes(b"1234567890") + zeroize.zeroize_np(arr) + self.assertEqual(True, all(arr == 0)) - arr = np.array([0] * 10, dtype=np.uint8) - arr[:] = bytes(b"1234567890") - zeroize.zeroize_np(arr) - self.assertEqual(True, all(arr == 0)) - - unlock_memory() + finally: + unlock_memory(arr) def test_zeroize1_sizes(self): - # lock_memory() - for size in SIZES_MB: - arr = bytearray(int(size * 1024 * 1024)) - zeroize.zeroize1(arr) - self.assertEqual(arr, bytearray(int(size * 1024 * 1024))) - - # unlock_memory() + try: + arr = bytearray(int(size * 1024 * 1024)) + zeroize.zeroize1(arr) + self.assertEqual(arr, bytearray(int(size * 1024 * 1024))) - def test_zeroize_np_sizes(self): - # lock_memory() + finally: + unlock_memory(arr) + def test_zeroize_np_sizes(self): for size in [size for size in SIZES_MB if size < 4]: - array_size = int(size * 1024 * 1024) - random_array = np.random.randint(0, 256, array_size, dtype=np.uint8) - zeroize.zeroize_np(random_array) - self.assertEqual(True, all(random_array == 0)) - - # unlock_memory() + try: + array_size = int(size * 1024 * 1024) + random_array = np.random.randint(0, 256, array_size, dtype=np.uint8) + zeroize.zeroize_np(random_array) + self.assertEqual(True, all(random_array == 0)) + finally: + unlock_memory(random_array) if __name__ == "__main__":