diff --git a/CHANGES.md b/CHANGES.md index e1aec45..847f7d8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,7 +10,7 @@ public_key)` - When operations fail (i.e., `OQS_SUCCESS != 0`) in functions returning non-boolean objects, a `RuntimeError` is now raised, instead of returning 0 -- Bugfix on Linux, `c_int` -> `c_size_t` for buffer sizes +- Bugfix on Linux platforms, `c_int` -> `c_size_t` for buffer sizes - Pyright type checking fixes - Updated examples to use `ML-KEM` and `ML-DSA` as the defaults diff --git a/oqs/oqs.py b/oqs/oqs.py index e57ce89..f75aa38 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -110,10 +110,10 @@ def _install_liboqs(target_directory, oqs_version=None): def _load_liboqs(): if "OQS_INSTALL_PATH" in os.environ: oqs_install_dir = os.path.abspath(os.environ["OQS_INSTALL_PATH"]) - else: + else: home_dir = os.path.expanduser("~") oqs_install_dir = os.path.abspath(home_dir + os.path.sep + "_oqs") # $HOME/_oqs - + oqs_lib_dir = ( os.path.abspath(oqs_install_dir + os.path.sep + "bin") # $HOME/_oqs/bin if platform.system() == "Windows" @@ -122,10 +122,14 @@ def _load_liboqs(): oqs_lib64_dir = ( os.path.abspath(oqs_install_dir + os.path.sep + "bin") # $HOME/_oqs/bin if platform.system() == "Windows" - else os.path.abspath(oqs_install_dir + os.path.sep + "lib64") # $HOME/_oqs/lib64 + else os.path.abspath( + oqs_install_dir + os.path.sep + "lib64" + ) # $HOME/_oqs/lib64 ) try: - _liboqs = _load_shared_obj(name="oqs", additional_searching_paths=[oqs_lib_dir, oqs_lib64_dir]) + _liboqs = _load_shared_obj( + name="oqs", additional_searching_paths=[oqs_lib_dir, oqs_lib64_dir] + ) assert _liboqs except RuntimeError: # We don't have liboqs, so we try to install it automatically @@ -462,18 +466,18 @@ def sign(self, message): c_signature = ct.create_string_buffer(self._sig.contents.length_signature) # Initialize to maximum signature size - signature_len = ct.c_size_t(self._sig.contents.length_signature) + c_signature_len = ct.c_size_t(self._sig.contents.length_signature) rv = native().OQS_SIG_sign( self._sig, ct.byref(c_signature), - ct.byref(signature_len), + ct.byref(c_signature_len), c_message, c_message_len, self.secret_key, ) if rv == OQS_SUCCESS: - return bytes(c_signature[: signature_len.value]) + return bytes(c_signature[: c_signature_len.value]) else: raise RuntimeError("Can not sign message") @@ -489,7 +493,7 @@ def verify(self, message, signature, public_key): c_message = ct.create_string_buffer(message, len(message)) c_message_len = ct.c_size_t(len(c_message)) c_signature = ct.create_string_buffer(signature, len(signature)) - signature_len = ct.c_size_t(len(c_signature)) + c_signature_len = ct.c_size_t(len(c_signature)) c_public_key = ct.create_string_buffer( public_key, self._sig.contents.length_public_key ) @@ -499,7 +503,7 @@ def verify(self, message, signature, public_key): c_message, c_message_len, c_signature, - signature_len, + c_signature_len, c_public_key, ) return True if rv == OQS_SUCCESS else False @@ -511,16 +515,22 @@ def sign_with_ctx_str(self, message, context): :param context: the context string. :param message: the message to sign. """ + if context and not self._sig.contents.sig_with_ctx_support: + raise RuntimeError("Signing with context string not supported") + # Provide length to avoid extra null char c_message = ct.create_string_buffer(message, len(message)) c_message_len = ct.c_size_t(len(c_message)) - c_context = ct.create_string_buffer(context, len(context)) - context_len = ct.c_size_t(len(c_context)) + if len(context) == 0: + c_context = None + c_context_len = 0 + else: + c_context = ct.create_string_buffer(context, len(context)) + c_context_len = ct.c_size_t(len(c_context)) c_signature = ct.create_string_buffer(self._sig.contents.length_signature) # Initialize to maximum signature size c_signature_len = ct.c_size_t(self._sig.contents.length_signature) - rv = native().OQS_SIG_sign_with_ctx_str( self._sig, ct.byref(c_signature), @@ -528,7 +538,7 @@ def sign_with_ctx_str(self, message, context): c_message, c_message_len, c_context, - context_len, + c_context_len, self.secret_key, ) if rv == OQS_SUCCESS: @@ -545,13 +555,20 @@ def verify_with_ctx_str(self, message, signature, context, public_key): :param context: the context string. :param public_key: the signer's public key. """ + if context and not self._sig.contents.sig_with_ctx_support: + raise RuntimeError("Verifying with context string not supported") + # Provide length to avoid extra null char c_message = ct.create_string_buffer(message, len(message)) c_message_len = ct.c_size_t(len(c_message)) c_signature = ct.create_string_buffer(signature, len(signature)) c_signature_len = ct.c_size_t(len(c_signature)) - c_context = ct.create_string_buffer(context, len(context)) - c_context_len = ct.c_size_t(len(c_context)) + if len(context) == 0: + c_context = None + c_context_len = 0 + else: + c_context = ct.create_string_buffer(context, len(context)) + c_context_len = ct.c_size_t(len(c_context)) c_public_key = ct.create_string_buffer( public_key, self._sig.contents.length_public_key ) diff --git a/tests/test_kem.py b/tests/test_kem.py index 7733666..ca4fa52 100644 --- a/tests/test_kem.py +++ b/tests/test_kem.py @@ -6,7 +6,7 @@ disabled_KEM_patterns = [] if platform.system() == "Windows": - disabled_KEM_patterns = ["Classic-McEliece"] + disabled_KEM_patterns = [""] def test_correctness(): @@ -47,7 +47,7 @@ def check_wrong_ciphertext(alg_name): def test_not_supported(): try: - with oqs.KeyEncapsulation("bogus"): + with oqs.KeyEncapsulation("unsupported_sig"): raise AssertionError("oqs.MechanismNotSupportedError was not raised.") except oqs.MechanismNotSupportedError: pass @@ -56,7 +56,6 @@ def test_not_supported(): def test_not_enabled(): - # TODO: test broken as the compiled lib determines which algorithms are supported and enabled for alg_name in oqs.get_supported_kem_mechanisms(): if alg_name not in oqs.get_enabled_kem_mechanisms(): # Found a non-enabled but supported alg diff --git a/tests/test_sig.py b/tests/test_sig.py index f84ef4b..57c406a 100644 --- a/tests/test_sig.py +++ b/tests/test_sig.py @@ -92,7 +92,7 @@ def check_wrong_public_key(alg_name): def test_not_supported(): try: - with oqs.Signature("bogus"): + with oqs.Signature("unsupported_sig"): raise AssertionError("oqs.MechanismNotSupportedError was not raised.") except oqs.MechanismNotSupportedError: pass @@ -101,7 +101,6 @@ def test_not_supported(): def test_not_enabled(): - # TODO: test broken as the compiled lib determines which algorithms are supported and enabled for alg_name in oqs.get_supported_sig_mechanisms(): if alg_name not in oqs.get_enabled_sig_mechanisms(): # Found a non-enabled but supported alg