diff --git a/main.py b/main.py index 682ca4b..83e99a6 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,43 @@ import zeroize import numpy as np +import ctypes +# Lock memory using ctypes +def lock_memory(): + libc = ctypes.CDLL("libc.so.6") + # Lock all current and future pages from being swapped out + libc.mlockall(ctypes.c_int(0x02 | 0x04)) # MCL_CURRENT | MCL_FUTURE + + +def unlock_memory(): + libc = ctypes.CDLL("libc.so.6") + # Unlock all locked pages + libc.munlockall() + + +print("locking memory") +lock_memory() + +print("allocate memory") + # regular array -arr = bytearray(b'1234567890') +arr = bytearray(b"1234567890") # numpy array arr_np = np.array([0] * 10, dtype=np.uint8) arr_np[:] = arr -assert arr_np.tobytes() == b'1234567890' +assert arr_np.tobytes() == b"1234567890" print("zeroize'ing...: ") zeroize.zeroize1(arr) zeroize.zeroize_np(arr_np) -print("checking if is zeroized...") -assert arr == bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') +print("checking if is zeroized") +assert arr == bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") assert all(arr_np == 0) +print("unlocking memory") +unlock_memory() + print("all good, bye!") diff --git a/src/lib.rs b/src/lib.rs index 0a922e0..8083c37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ use numpy::{PyArray1, PyArrayMethods}; +use pyo3::buffer::PyBuffer; use pyo3::prelude::*; -use pyo3::types::PyByteArray; +use pyo3::types::{PyByteArray, PyMemoryView}; use zeroize_rs::Zeroize; /// A Python module implemented in Rust. @@ -8,6 +9,7 @@ use zeroize_rs::Zeroize; fn zeroize(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(zeroize1, m)?)?; m.add_function(wrap_pyfunction!(zeroize_np, m)?)?; + // m.add_function(wrap_pyfunction!(zeroize_mv, m)?)?; Ok(()) } @@ -23,3 +25,19 @@ fn zeroize_np<'py>(arr: &Bound<'py, PyArray1>) -> PyResult<()> { unsafe { arr.as_slice_mut().unwrap().zeroize(); } Ok(()) } + +// #[pyfunction] +// fn zeroize_mv<'py>(arr: &PyMemoryView, len: usize) -> PyResult<()> { +// // Get the buffer information +// let buffer = PyBuffer::::get(arr)?; +// +// // Get the raw mutable pointer and length of the memory view +// let ptr = arr.as_ptr() as *mut u8; +// +// // Create a mutable slice from the raw pointer and length +// let arr: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(ptr, len) }; +// +// arr.zeroize(); +// +// Ok(()) +// }