Skip to content

Commit 0ac9fdf

Browse files
authored
RUST-226 Support tlsCertificateKeyFilePassword (#1256)
1 parent e3df089 commit 0ac9fdf

File tree

9 files changed

+109
-14
lines changed

9 files changed

+109
-14
lines changed

.evergreen/MSRV-Cargo.toml.diff

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
141c141
1+
116a117
2+
> url = "=2.5.2"
3+
144c145
24
< version = "1.17.0"
35
---
46
> version = "=1.38.0"
5-
150c150
7+
153c154
68
< version = "0.7.0"
79
---
810
> version = "=0.7.11"

.evergreen/run-tests.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set -o pipefail
66
source .evergreen/env.sh
77
source .evergreen/cargo-test.sh
88

9-
FEATURE_FLAGS+=("tracing-unstable")
9+
FEATURE_FLAGS+=("tracing-unstable" "cert-key-password")
1010

1111
if [ "$OPENSSL" = true ]; then
1212
FEATURE_FLAGS+=("openssl-tls")

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ sync = []
3434
rustls-tls = ["dep:rustls", "dep:rustls-pemfile", "dep:tokio-rustls"]
3535
openssl-tls = ["dep:openssl", "dep:openssl-probe", "dep:tokio-openssl"]
3636
dns-resolver = ["dep:hickory-resolver", "dep:hickory-proto"]
37+
cert-key-password = ["dep:pem", "dep:pkcs8"]
3738

3839
# Enable support for MONGODB-AWS authentication.
3940
# This can only be used with the tokio-runtime feature flag.
@@ -95,7 +96,9 @@ mongodb-internal-macros = { path = "macros", version = "3.1.0" }
9596
num_cpus = { version = "1.13.1", optional = true }
9697
openssl = { version = "0.10.38", optional = true }
9798
openssl-probe = { version = "0.1.5", optional = true }
99+
pem = { version = "3.0.4", optional = true }
98100
percent-encoding = "2.0.0"
101+
pkcs8 = { version = "0.10.2", features = ["encryption", "pkcs5"], optional = true }
99102
rand = { version = "0.8.3", features = ["small_rng"] }
100103
rayon = { version = "1.5.3", optional = true }
101104
rustc_version_runtime = "0.3.0"

src/client/options.rs

+30
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,10 @@ pub struct TlsOptions {
10471047
/// The default value is to error on invalid hostnames.
10481048
#[cfg(feature = "openssl-tls")]
10491049
pub allow_invalid_hostnames: Option<bool>,
1050+
1051+
/// If set, the key in `cert_key_file_path` must be encrypted with this password.
1052+
#[cfg(feature = "cert-key-password")]
1053+
pub tls_certificate_key_file_password: Option<Vec<u8>>,
10501054
}
10511055

10521056
impl TlsOptions {
@@ -1064,6 +1068,8 @@ impl TlsOptions {
10641068
tlscafile: Option<&'a str>,
10651069
tlscertificatekeyfile: Option<&'a str>,
10661070
tlsallowinvalidcertificates: Option<bool>,
1071+
#[cfg(feature = "cert-key-password")]
1072+
tlscertificatekeyfilepassword: Option<&'a str>,
10671073
}
10681074

10691075
let state = TlsOptionsHelper {
@@ -1077,6 +1083,11 @@ impl TlsOptions {
10771083
.as_ref()
10781084
.map(|s| s.to_str().unwrap()),
10791085
tlsallowinvalidcertificates: tls_options.allow_invalid_certificates,
1086+
#[cfg(feature = "cert-key-password")]
1087+
tlscertificatekeyfilepassword: tls_options
1088+
.tls_certificate_key_file_password
1089+
.as_deref()
1090+
.map(|b| std::str::from_utf8(b).unwrap()),
10801091
};
10811092
state.serialize(serializer)
10821093
}
@@ -2126,6 +2137,25 @@ impl ConnectionString {
21262137
))
21272138
}
21282139
},
2140+
#[cfg(feature = "cert-key-password")]
2141+
"tlscertificatekeyfilepassword" => match &mut self.tls {
2142+
Some(Tls::Disabled) => {
2143+
return Err(ErrorKind::InvalidArgument {
2144+
message: "'tlsCertificateKeyFilePassword' can't be set if tls=false".into(),
2145+
}
2146+
.into());
2147+
}
2148+
Some(Tls::Enabled(options)) => {
2149+
options.tls_certificate_key_file_password = Some(value.as_bytes().to_vec());
2150+
}
2151+
None => {
2152+
self.tls = Some(Tls::Enabled(
2153+
TlsOptions::builder()
2154+
.tls_certificate_key_file_password(value.as_bytes().to_vec())
2155+
.build(),
2156+
))
2157+
}
2158+
},
21292159
"uuidrepresentation" => match value.to_lowercase().as_str() {
21302160
"csharplegacy" => self.uuid_representation = Some(UuidRepresentation::CSharpLegacy),
21312161
"javalegacy" => self.uuid_representation = Some(UuidRepresentation::JavaLegacy),

src/client/options/test.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
2020
"tlsInsecure is parsed correctly",
2121
// The driver does not support maxPoolSize=0
2222
"maxPoolSize=0 does not error",
23-
// TODO RUST-226: unskip this test
23+
#[cfg(not(feature = "cert-key-password"))]
2424
"Valid tlsCertificateKeyFilePassword is parsed correctly",
2525
];
2626

src/runtime.rs

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ mod acknowledged_message;
88
))]
99
mod http;
1010
mod join_handle;
11+
#[cfg(feature = "cert-key-password")]
12+
mod pem;
1113
#[cfg(any(feature = "in-use-encryption", test))]
1214
pub(crate) mod process;
1315
#[cfg(feature = "dns-resolver")]

src/runtime/pem.rs

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use crate::error::{ErrorKind, Result};
2+
3+
pub(crate) fn decrypt_private_key(pem_data: &[u8], password: &[u8]) -> Result<Vec<u8>> {
4+
let pems = pem::parse_many(pem_data).map_err(|error| ErrorKind::InvalidTlsConfig {
5+
message: format!("Could not parse pemfile: {}", error),
6+
})?;
7+
let mut iter = pems
8+
.into_iter()
9+
.filter(|pem| pem.tag() == "ENCRYPTED PRIVATE KEY");
10+
let encrypted_bytes = match iter.next() {
11+
Some(pem) => pem.into_contents(),
12+
None => {
13+
return Err(ErrorKind::InvalidTlsConfig {
14+
message: "No encrypted private keys found".into(),
15+
}
16+
.into())
17+
}
18+
};
19+
let encrypted_key = pkcs8::EncryptedPrivateKeyInfo::try_from(encrypted_bytes.as_slice())
20+
.map_err(|error| ErrorKind::InvalidTlsConfig {
21+
message: format!("Invalid encrypted private key: {}", error),
22+
})?;
23+
let decrypted_key =
24+
encrypted_key
25+
.decrypt(password)
26+
.map_err(|error| ErrorKind::InvalidTlsConfig {
27+
message: format!("Failed to decrypt private key: {}", error),
28+
})?;
29+
Ok(decrypted_key.as_bytes().to_vec())
30+
}

src/runtime/tls_openssl.rs

+31-10
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,7 @@ impl TlsConfig {
3131
None => true,
3232
};
3333

34-
let connector = make_openssl_connector(options).map_err(|e| {
35-
Error::from(ErrorKind::InvalidTlsConfig {
36-
message: e.to_string(),
37-
})
38-
})?;
34+
let connector = make_openssl_connector(options)?;
3935

4036
Ok(TlsConfig {
4137
connector,
@@ -66,25 +62,50 @@ pub(super) async fn tls_connect(
6662
Ok(stream)
6763
}
6864

69-
fn make_openssl_connector(cfg: TlsOptions) -> std::result::Result<SslConnector, ErrorStack> {
70-
let mut builder = SslConnector::builder(SslMethod::tls_client())?;
65+
fn make_openssl_connector(cfg: TlsOptions) -> Result<SslConnector> {
66+
let openssl_err = |e: ErrorStack| {
67+
Error::from(ErrorKind::InvalidTlsConfig {
68+
message: e.to_string(),
69+
})
70+
};
71+
72+
let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(openssl_err)?;
7173

7274
let TlsOptions {
7375
allow_invalid_certificates,
7476
ca_file_path,
7577
cert_key_file_path,
7678
allow_invalid_hostnames: _,
79+
#[cfg(feature = "cert-key-password")]
80+
tls_certificate_key_file_password,
7781
} = cfg;
7882

7983
if let Some(true) = allow_invalid_certificates {
8084
builder.set_verify(SslVerifyMode::NONE);
8185
}
8286
if let Some(path) = ca_file_path {
83-
builder.set_ca_file(path)?;
87+
builder.set_ca_file(path).map_err(openssl_err)?;
8488
}
8589
if let Some(path) = cert_key_file_path {
86-
builder.set_certificate_file(path.clone(), SslFiletype::PEM)?;
87-
builder.set_private_key_file(path, SslFiletype::PEM)?;
90+
builder
91+
.set_certificate_file(path.clone(), SslFiletype::PEM)
92+
.map_err(openssl_err)?;
93+
// Inner fn so the cert-key-password path can early return
94+
let handle_private_key = || -> Result<()> {
95+
#[cfg(feature = "cert-key-password")]
96+
if let Some(key_pw) = tls_certificate_key_file_password {
97+
let contents = std::fs::read(&path)?;
98+
let key_bytes = super::pem::decrypt_private_key(&contents, &key_pw)?;
99+
let key =
100+
openssl::pkey::PKey::private_key_from_der(&key_bytes).map_err(openssl_err)?;
101+
builder.set_private_key(&key).map_err(openssl_err)?;
102+
return Ok(());
103+
}
104+
builder
105+
.set_private_key_file(path, SslFiletype::PEM)
106+
.map_err(openssl_err)
107+
};
108+
handle_private_key()?;
88109
}
89110

90111
Ok(builder.build())

src/runtime/tls_rustls.rs

+7
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ fn make_rustls_config(cfg: TlsOptions) -> Result<rustls::ClientConfig> {
104104

105105
file.rewind()?;
106106
let key = loop {
107+
#[cfg(feature = "cert-key-password")]
108+
if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() {
109+
use std::io::Read;
110+
let mut contents = vec![];
111+
file.read_to_end(&mut contents)?;
112+
break rustls::PrivateKey(super::pem::decrypt_private_key(&contents, key_pw)?);
113+
}
107114
match read_one(&mut file) {
108115
Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => {
109116
break rustls::PrivateKey(bytes)

0 commit comments

Comments
 (0)