Skip to content

Commit ce78967

Browse files
authored
Merge pull request #83 from hatoo/rustls-client
Rustls client
2 parents af80824 + 01b291c commit ce78967

File tree

11 files changed

+116
-20
lines changed

11 files changed

+116
-20
lines changed

.github/workflows/rust.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ jobs:
2020
run: cargo build --verbose
2121
- name: Run tests
2222
run: cargo test --verbose
23+
- name: Run tests --no-default-features --features rustls-client
24+
run: cargo test --verbose --no-default-features --features rustls-client

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ resolver = "2"
1414

1515
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1616

17+
[features]
18+
default = ["native-tls-client"]
19+
native-tls-client = ["dep:native-tls", "dep:tokio-native-tls"]
20+
21+
# You can use --no-default-feature --features "rustls-client" to remove native-tls from dependencies
22+
rustls-client = ["dep:webpki-roots"]
23+
1724
[dependencies]
1825
tokio = { version = "1.39.3", features = [
1926
"macros",
@@ -29,13 +36,16 @@ bytes = "1.7.1"
2936
http-body-util = "0.1.0"
3037
rcgen = "0.13.1"
3138
tokio-rustls = "0.26.1"
32-
tokio-native-tls = "0.3.1"
3339
tracing = "0.1.40"
3440
hyper-util = { version = "0.1.7", features = ["tokio"] }
35-
native-tls = { version = "0.2.12", features = ["alpn"] }
3641
thiserror = "2.0.11"
3742
moka = { version = "0.12.8", features = ["sync"] }
3843

44+
native-tls = { version = "0.2.12", features = ["alpn"], optional = true }
45+
tokio-native-tls = { version = "0.3.1", optional = true }
46+
47+
webpki-roots = { version = "0.26.8", optional = true }
48+
3949
[dev-dependencies]
4050
axum = { version = "0.8.1", features = ["http2"] }
4151
clap = { version = "4.5.16", features = ["derive"] }

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async fn main() {
8585
Some(Cache::new(128)),
8686
);
8787
88-
let client = DefaultClient::new().unwrap();
88+
let client = DefaultClient::new();
8989
let server = proxy
9090
.bind(
9191
("127.0.0.1", 3003),

examples/dev_proxy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async fn main() {
8282
Some(Cache::new(128)),
8383
);
8484

85-
let client = DefaultClient::new().unwrap();
85+
let client = DefaultClient::new();
8686
let proxy = proxy
8787
.bind(
8888
("127.0.0.1", 3003),

examples/https.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async fn main() {
9292
);
9393
let proxy = Arc::new(proxy);
9494

95-
let client = DefaultClient::new().unwrap();
95+
let client = DefaultClient::new();
9696

9797
let listener = TcpListener::bind(("127.0.0.1", 3003)).await.unwrap();
9898

examples/proxy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async fn main() {
7272
Some(Cache::new(128)),
7373
);
7474

75-
let client = DefaultClient::new().unwrap();
75+
let client = DefaultClient::new();
7676
let server = proxy
7777
.bind(
7878
("127.0.0.1", 3003),

examples/websocket.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async fn main() {
7878
Some(Cache::new(128)),
7979
);
8080

81-
let client = DefaultClient::new().unwrap().with_upgrades();
81+
let client = DefaultClient::new().with_upgrades();
8282
let server = proxy
8383
.bind(
8484
("127.0.0.1", 3003),

src/default_client.rs

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,33 @@ use hyper_util::rt::{TokioExecutor, TokioIo};
88
use std::task::{Context, Poll};
99
use tokio::{net::TcpStream, task::JoinHandle};
1010

11+
#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
12+
compile_error!("feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time");
13+
14+
#[cfg(all(not(feature = "native-tls-client"), not(feature = "rustls-client")))]
15+
compile_error!("feature \"native-tls-client\" or feature \"rustls-client\" must be enabled");
16+
1117
#[derive(thiserror::Error, Debug)]
1218
pub enum Error {
1319
#[error("{0} doesn't have an valid host")]
1420
InvalidHost(Uri),
1521
#[error(transparent)]
1622
IoError(#[from] std::io::Error),
1723
#[error(transparent)]
18-
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
19-
#[error(transparent)]
2024
HyperError(#[from] hyper::Error),
2125
#[error("Failed to connect to {0}, {1}")]
2226
ConnectError(Uri, hyper::Error),
27+
28+
#[cfg(feature = "native-tls-client")]
2329
#[error("Failed to connect with TLS to {0}, {1}")]
2430
TlsConnectError(Uri, native_tls::Error),
31+
#[cfg(feature = "native-tls-client")]
32+
#[error(transparent)]
33+
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
34+
35+
#[cfg(feature = "rustls-client")]
36+
#[error("Failed to connect with TLS to {0}, {1}")]
37+
TlsConnectError(Uri, std::io::Error),
2538
}
2639

2740
/// Upgraded connections
@@ -34,24 +47,60 @@ pub struct Upgraded {
3447
#[derive(Clone)]
3548
/// Default HTTP client for this crate
3649
pub struct DefaultClient {
50+
#[cfg(feature = "native-tls-client")]
3751
tls_connector_no_alpn: tokio_native_tls::TlsConnector,
52+
#[cfg(feature = "native-tls-client")]
3853
tls_connector_alpn_h2: tokio_native_tls::TlsConnector,
54+
55+
#[cfg(feature = "rustls-client")]
56+
tls_connector_no_alpn: tokio_rustls::TlsConnector,
57+
#[cfg(feature = "rustls-client")]
58+
tls_connector_alpn_h2: tokio_rustls::TlsConnector,
59+
3960
/// If true, send_request will returns an Upgraded struct when the response is an upgrade
4061
/// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
4162
pub with_upgrades: bool,
4263
}
4364
impl DefaultClient {
44-
pub fn new() -> native_tls::Result<Self> {
45-
let tls_connector_no_alpn = native_tls::TlsConnector::builder().build()?;
65+
#[cfg(feature = "native-tls-client")]
66+
pub fn new() -> Self {
67+
let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().unwrap();
4668
let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
4769
.request_alpns(&["h2", "http/1.1"])
48-
.build()?;
70+
.build()
71+
.unwrap();
4972

50-
Ok(Self {
73+
Self {
5174
tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
5275
tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
5376
with_upgrades: false,
54-
})
77+
}
78+
}
79+
80+
#[cfg(feature = "rustls-client")]
81+
pub fn new() -> Self {
82+
use std::sync::Arc;
83+
84+
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
85+
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
86+
87+
let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
88+
.with_root_certificates(root_cert_store.clone())
89+
.with_no_client_auth();
90+
let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
91+
.with_root_certificates(root_cert_store.clone())
92+
.with_no_client_auth();
93+
tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
94+
95+
Self {
96+
tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
97+
tls_connector_no_alpn,
98+
)),
99+
tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
100+
tls_connector_alpn_h2,
101+
)),
102+
with_upgrades: false,
103+
}
55104
}
56105

57106
/// Enable HTTP upgrades
@@ -61,13 +110,22 @@ impl DefaultClient {
61110
self
62111
}
63112

113+
#[cfg(feature = "native-tls-client")]
64114
fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
65115
match http_version {
66116
Version::HTTP_2 => &self.tls_connector_alpn_h2,
67117
_ => &self.tls_connector_no_alpn,
68118
}
69119
}
70120

121+
#[cfg(feature = "rustls-client")]
122+
fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
123+
match http_version {
124+
Version::HTTP_2 => &self.tls_connector_alpn_h2,
125+
_ => &self.tls_connector_no_alpn,
126+
}
127+
}
128+
71129
/// Send a request and return a response.
72130
/// If the response is an upgrade (= if status code is 101 Switching Protocols), it will return a response and an Upgrade struct.
73131
/// Request should have a full URL including scheme.
@@ -152,17 +210,31 @@ impl DefaultClient {
152210
let _ = tcp.set_nodelay(true);
153211

154212
if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
213+
#[cfg(feature = "native-tls-client")]
155214
let tls = self
156215
.tls_connector(http_version)
157216
.connect(host, tcp)
158217
.await
159218
.map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
219+
#[cfg(feature = "rustls-client")]
220+
let tls = self
221+
.tls_connector(http_version)
222+
.connect(host.to_string().try_into().expect("Invalid host"), tcp)
223+
.await
224+
.map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
225+
226+
#[cfg(feature = "native-tls-client")]
227+
let is_h2 = matches!(
228+
tls.get_ref()
229+
.negotiated_alpn()
230+
.map(|a| a.map(|b| b == b"h2")),
231+
Ok(Some(true))
232+
);
233+
234+
#[cfg(feature = "rustls-client")]
235+
let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
160236

161-
if let Ok(Some(true)) = tls
162-
.get_ref()
163-
.negotiated_alpn()
164-
.map(|a| a.map(|b| b == b"h2"))
165-
{
237+
if is_h2 {
166238
let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
167239
.handshake(TokioIo::new(tls))
168240
.await

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use tokio_rustls::rustls;
1717
pub use futures;
1818
pub use hyper;
1919
pub use moka;
20+
21+
#[cfg(feature = "native-tls-client")]
2022
pub use tokio_native_tls;
2123

2224
pub mod default_client;

0 commit comments

Comments
 (0)