Skip to content

Commit 9b400b1

Browse files
authored
Extending partner to accept key encryption algo and pass that down.
* Fix github action build fail due to: https://stackoverflow.com/questions/71673404/importerror-cannot-import-name-unicodefun-from-click * Added partner setting to force canonicalize binary. * Formatted with black * #62 * Asserting error messages and _encrypted_data_with_faulty_key_algo
1 parent 499fd03 commit 9b400b1

File tree

5 files changed

+171
-45
lines changed

5 files changed

+171
-45
lines changed

pyas2lib/as2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DIGEST_ALGORITHMS,
2626
EDIINT_FEATURES,
2727
ENCRYPTION_ALGORITHMS,
28+
KEY_ENCRYPTION_ALGORITHMS,
2829
MDN_CONFIRM_TEXT,
2930
MDN_FAILED_TEXT,
3031
MDN_MODES,
@@ -183,6 +184,9 @@ class Partner:
183184
:param sign_alg: The signing algorithm to be used for generating the
184185
signature. (default `rsassa_pkcs1v15`)
185186
187+
:param key_enc_alg: The key encryption algorithm to be used.
188+
(default `rsaes_pkcs1v15`)
189+
186190
"""
187191

188192
as2_name: str
@@ -202,6 +206,7 @@ class Partner:
202206
ignore_self_signed: bool = True
203207
canonicalize_as_binary: bool = False
204208
sign_alg: str = "rsassa_pkcs1v15"
209+
key_enc_alg: str = "rsaes_pkcs1v15"
205210

206211
def __post_init__(self):
207212
"""Run the post initialisation checks for this class."""
@@ -236,6 +241,12 @@ def __post_init__(self):
236241
f"must be one of {SIGNATUR_ALGORITHMS}"
237242
)
238243

244+
if self.key_enc_alg and self.key_enc_alg not in KEY_ENCRYPTION_ALGORITHMS:
245+
raise ImproperlyConfigured(
246+
f"Unsupported Key Encryption Algorithm {self.key_enc_alg}, "
247+
f"must be one of {KEY_ENCRYPTION_ALGORITHMS}"
248+
)
249+
239250
def load_verify_cert(self):
240251
"""Load the verification certificate of the partner and returned the parsed cert."""
241252
if self.validate_certs:

pyas2lib/cms.py

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,15 @@ def decompress_message(compressed_data):
6565
raise DecompressionError("Decompression failed with cause: {}".format(e)) from e
6666

6767

68-
def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
68+
def encrypt_message(
69+
data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15"
70+
):
6971
"""Function encrypts data and returns the generated ASN.1
7072
7173
:param data_to_encrypt: A byte string of the data to be encrypted
7274
:param enc_alg: The algorithm to be used for encrypting the data
7375
:param encryption_cert: The certificate to be used for encrypting the data
76+
:param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep
7477
7578
:return: A CMS ASN.1 byte string of the encrypted data.
7679
"""
@@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
136139
raise AS2Exception("Unsupported Encryption Algorithm")
137140

138141
# Encrypt the key and build the ASN.1 message
139-
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
142+
if key_enc_alg == "rsaes_pkcs1v15":
143+
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
144+
elif key_enc_alg == "rsaes_oaep":
145+
encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key)
146+
else:
147+
raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}")
140148

141149
return cms.ContentInfo(
142150
{
@@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
163171
}
164172
),
165173
"key_encryption_algorithm": cms.KeyEncryptionAlgorithm(
166-
{"algorithm": cms.KeyEncryptionAlgorithmId("rsa")}
174+
{
175+
"algorithm": cms.KeyEncryptionAlgorithmId(
176+
key_enc_alg
177+
)
178+
}
167179
),
168180
"encrypted_key": cms.OctetString(encrypted_key),
169181
}
@@ -199,47 +211,52 @@ def decrypt_message(encrypted_data, decryption_key):
199211
key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native
200212
encrypted_key = recipient_info["encrypted_key"].native
201213

202-
if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
203-
"rsa"
204-
):
205-
try:
214+
try:
215+
if cms.KeyEncryptionAlgorithmId(
216+
key_enc_alg
217+
) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"):
206218
key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key)
207-
except Exception as e:
208-
raise DecryptionError(
209-
"Failed to decrypt the payload: Could not extract decryption key."
210-
) from e
211-
212-
alg = cms_content["content"]["encrypted_content_info"][
213-
"content_encryption_algorithm"
214-
]
215-
encapsulated_data = cms_content["content"]["encrypted_content_info"][
216-
"encrypted_content"
217-
].native
218219

219-
try:
220-
if alg["algorithm"].native == "rc4":
221-
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
222-
elif alg.encryption_cipher == "tripledes":
223-
cipher = "tripledes_192_cbc"
224-
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
225-
key, encapsulated_data, alg.encryption_iv
226-
)
227-
elif alg.encryption_cipher == "aes":
228-
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
229-
key, encapsulated_data, alg.encryption_iv
230-
)
231-
elif alg.encryption_cipher == "rc2":
232-
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
233-
key, encapsulated_data, alg["parameters"]["iv"].native
234-
)
235-
else:
236-
raise AS2Exception("Unsupported Encryption Algorithm")
237-
except Exception as e:
238-
raise DecryptionError(
239-
"Failed to decrypt the payload: {}".format(e)
240-
) from e
241-
else:
242-
raise AS2Exception("Unsupported Encryption Algorithm")
220+
elif cms.KeyEncryptionAlgorithmId(
221+
key_enc_alg
222+
) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"):
223+
key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key)
224+
else:
225+
raise AS2Exception(
226+
f"Unsupported Key Encryption Algorithm {key_enc_alg}"
227+
)
228+
except Exception as e:
229+
raise DecryptionError(
230+
"Failed to decrypt the payload: Could not extract decryption key."
231+
) from e
232+
233+
alg = cms_content["content"]["encrypted_content_info"][
234+
"content_encryption_algorithm"
235+
]
236+
encapsulated_data = cms_content["content"]["encrypted_content_info"][
237+
"encrypted_content"
238+
].native
239+
240+
try:
241+
if alg["algorithm"].native == "rc4":
242+
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
243+
elif alg.encryption_cipher == "tripledes":
244+
cipher = "tripledes_192_cbc"
245+
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
246+
key, encapsulated_data, alg.encryption_iv
247+
)
248+
elif alg.encryption_cipher == "aes":
249+
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
250+
key, encapsulated_data, alg.encryption_iv
251+
)
252+
elif alg.encryption_cipher == "rc2":
253+
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
254+
key, encapsulated_data, alg["parameters"]["iv"].native
255+
)
256+
else:
257+
raise AS2Exception("Unsupported Encryption Algorithm")
258+
except Exception as e:
259+
raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e
243260
else:
244261
raise DecryptionError("Encrypted data not found in ASN.1 ")
245262

pyas2lib/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,7 @@
3232
"rsassa_pkcs1v15",
3333
"rsassa_pss",
3434
)
35+
KEY_ENCRYPTION_ALGORITHMS = (
36+
"rsaes_pkcs1v15",
37+
"rsaes_oaep",
38+
)

pyas2lib/tests/test_advanced.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ def test_partner_checks(self):
337337
with self.assertRaises(ImproperlyConfigured):
338338
as2.Partner("a partner", sign_alg="xyz")
339339

340+
with self.assertRaises(ImproperlyConfigured):
341+
as2.Partner("a partner", key_enc_alg="xyz")
342+
340343
def test_message_checks(self):
341344
"""Test the checks and other features of Message."""
342345
msg = as2.Message()

pyas2lib/tests/test_cms.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import os
33

44
import pytest
5-
from oscrypto import asymmetric
5+
from oscrypto import asymmetric, symmetric, util
6+
7+
from asn1crypto import algos, cms as crypto_cms, core
68

79
from pyas2lib.as2 import Organization
810
from pyas2lib import cms
@@ -22,6 +24,68 @@
2224
).dump()
2325

2426

27+
def _encrypted_data_with_faulty_key_algo():
28+
with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp:
29+
encrypt_cert = asymmetric.load_certificate(fp.read())
30+
enc_alg_list = "rc4_128_cbc".split("_")
31+
cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2]
32+
key = util.rand_bytes(int(key_length) // 8)
33+
algorithm_id = "1.2.840.113549.3.4"
34+
encrypted_content = symmetric.rc4_encrypt(key, b"data")
35+
enc_alg_asn1 = algos.EncryptionAlgorithm(
36+
{
37+
"algorithm": algorithm_id,
38+
}
39+
)
40+
encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key)
41+
return crypto_cms.ContentInfo(
42+
{
43+
"content_type": crypto_cms.ContentType("enveloped_data"),
44+
"content": crypto_cms.EnvelopedData(
45+
{
46+
"version": crypto_cms.CMSVersion("v0"),
47+
"recipient_infos": [
48+
crypto_cms.KeyTransRecipientInfo(
49+
{
50+
"version": crypto_cms.CMSVersion("v0"),
51+
"rid": crypto_cms.RecipientIdentifier(
52+
{
53+
"issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber(
54+
{
55+
"issuer": encrypt_cert.asn1[
56+
"tbs_certificate"
57+
]["issuer"],
58+
"serial_number": encrypt_cert.asn1[
59+
"tbs_certificate"
60+
]["serial_number"],
61+
}
62+
)
63+
}
64+
),
65+
"key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm(
66+
{
67+
"algorithm": crypto_cms.KeyEncryptionAlgorithmId(
68+
"aes128_wrap"
69+
)
70+
}
71+
),
72+
"encrypted_key": crypto_cms.OctetString(encrypted_key),
73+
}
74+
)
75+
],
76+
"encrypted_content_info": crypto_cms.EncryptedContentInfo(
77+
{
78+
"content_type": crypto_cms.ContentType("data"),
79+
"content_encryption_algorithm": enc_alg_asn1,
80+
"encrypted_content": encrypted_content,
81+
}
82+
),
83+
}
84+
),
85+
}
86+
).dump()
87+
88+
2589
def test_compress():
2690
"""Test the compression and decompression functions."""
2791
compressed_data = cms.compress_message(b"data")
@@ -87,9 +151,22 @@ def test_encryption():
87151
"aes_128_cbc",
88152
"aes_192_cbc",
89153
"aes_256_cbc",
154+
"tripledes_192_cbc",
155+
]
156+
157+
key_enc_algos = [
158+
"rsaes_oaep",
159+
"rsaes_pkcs1v15",
160+
]
161+
162+
encryption_algos = [
163+
(alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos
90164
]
91-
for enc_algorithm in enc_algorithms:
92-
encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert)
165+
166+
for enc_algorithm, encryption_scheme in encryption_algos:
167+
encrypted_data = cms.encrypt_message(
168+
b"data", enc_algorithm, encrypt_cert, encryption_scheme
169+
)
93170
_, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key)
94171
assert decrypted_data == b"data"
95172

@@ -101,3 +178,17 @@ def test_encryption():
101178
encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert)
102179
with pytest.raises(AS2Exception):
103180
cms.decrypt_message(encrypted_data, decrypt_key)
181+
182+
# Test faulty key encryption algorithm
183+
with pytest.raises(
184+
AS2Exception, match="Unsupported Key Encryption Scheme: des_64_cbc"
185+
):
186+
cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc")
187+
188+
# Test unsupported key encryption algorithm
189+
encrypted_data = _encrypted_data_with_faulty_key_algo()
190+
with pytest.raises(
191+
AS2Exception,
192+
match="Failed to decrypt the payload: Could not extract decryption key.",
193+
):
194+
cms.decrypt_message(encrypted_data, decrypt_key)

0 commit comments

Comments
 (0)