Skip to content

Commit d529169

Browse files
authored
Merge pull request #96 from hatoo/fix-error-handling
Improve error handling and remove panic-prone code
2 parents 3a6fe92 + dc62b78 commit d529169

File tree

3 files changed

+68
-33
lines changed

3 files changed

+68
-33
lines changed

src/default_client.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,12 @@ impl DefaultClient {
227227
#[cfg(feature = "rustls-client")]
228228
let tls = self
229229
.tls_connector(http_version)
230-
.connect(host.to_string().try_into().expect("Invalid host"), tcp)
230+
.connect(
231+
host.to_string()
232+
.try_into()
233+
.map_err(|_| Error::InvalidHost(uri.clone()))?,
234+
tcp,
235+
)
231236
.await
232237
.map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
233238

@@ -293,10 +298,18 @@ where
293298
SendRequest::Http1(sender) => {
294299
if req.version() == hyper::Version::HTTP_2 {
295300
if let Some(authority) = req.uri().authority().cloned() {
296-
req.headers_mut().insert(
297-
header::HOST,
298-
authority.as_str().parse().expect("Invalid authority"),
299-
);
301+
match authority.as_str().parse::<header::HeaderValue>() {
302+
Ok(host_value) => {
303+
req.headers_mut().insert(header::HOST, host_value);
304+
}
305+
Err(err) => {
306+
tracing::warn!(
307+
"Failed to parse authority '{}' as HOST header: {}",
308+
authority,
309+
err
310+
);
311+
}
312+
}
300313
}
301314
}
302315
remove_authority(&mut req);

src/lib.rs

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,12 @@ where
7777

7878
Ok(async move {
7979
loop {
80-
let Ok((stream, _)) = listener.accept().await else {
81-
continue;
80+
let (stream, _) = match listener.accept().await {
81+
Ok(conn) => conn,
82+
Err(err) => {
83+
tracing::warn!("Failed to accept connection: {}", err);
84+
continue;
85+
}
8286
};
8387

8488
let service = service.clone();
@@ -140,12 +144,16 @@ where
140144
};
141145

142146
tokio::spawn(async move {
143-
let Ok(client) = hyper::upgrade::on(req).await else {
144-
tracing::error!(
145-
"Bad CONNECT request: {}, Reason: Invalid Upgrade",
146-
connect_authority
147-
);
148-
return;
147+
let client = match hyper::upgrade::on(req).await {
148+
Ok(client) => client,
149+
Err(err) => {
150+
tracing::error!(
151+
"Failed to upgrade CONNECT request for {}: {}",
152+
connect_authority,
153+
err
154+
);
155+
return;
156+
}
149157
};
150158
if let Some(server_config) =
151159
proxy.server_config(connect_authority.host().to_string(), true)
@@ -196,17 +204,22 @@ where
196204
.await
197205
};
198206

199-
if let Err(_err) = res {
200-
// Suppress error because if we serving HTTPS proxy server and forward to HTTPS server, it will always error when closing connection.
201-
// tracing::error!("Error in proxy: {}", err);
207+
if let Err(err) = res {
208+
tracing::debug!("Connection closed: {}", err);
202209
}
203210
} else {
204-
let Ok(mut server) =
205-
TcpStream::connect(connect_authority.as_str()).await
206-
else {
207-
tracing::error!("Failed to connect to {}", connect_authority);
208-
return;
209-
};
211+
let mut server =
212+
match TcpStream::connect(connect_authority.as_str()).await {
213+
Ok(server) => server,
214+
Err(err) => {
215+
tracing::error!(
216+
"Failed to connect to {}: {}",
217+
connect_authority,
218+
err
219+
);
220+
return;
221+
}
222+
};
210223
let _ = tokio::io::copy_bidirectional(
211224
&mut TokioIo::new(client),
212225
&mut server,
@@ -229,13 +242,21 @@ where
229242
}
230243

231244
fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
232-
self.root_cert.as_ref().map(|root_cert| {
245+
self.root_cert.as_ref().and_then(|root_cert| {
233246
if let Some(cache) = self.cert_cache.as_ref() {
234-
cache.get_with(host.clone(), move || {
247+
Some(cache.get_with(host.clone(), move || {
235248
generate_cert(host, root_cert.borrow())
236-
})
249+
.map_err(|err| {
250+
tracing::error!("Failed to generate certificate for host: {}", err);
251+
})
252+
.unwrap()
253+
}))
237254
} else {
238255
generate_cert(host, root_cert.borrow())
256+
.map_err(|err| {
257+
tracing::error!("Failed to generate certificate: {}", err);
258+
})
259+
.ok()
239260
}
240261
})
241262
}

src/tls.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ pub struct CertifiedKeyDer {
55
pub key_der: Vec<u8>,
66
}
77

8-
pub fn generate_cert(host: String, root_cert: &rcgen::CertifiedKey) -> CertifiedKeyDer {
9-
let mut cert_params = rcgen::CertificateParams::new(vec![host.clone()]).unwrap();
8+
pub fn generate_cert(
9+
host: String,
10+
root_cert: &rcgen::CertifiedKey,
11+
) -> Result<CertifiedKeyDer, rcgen::Error> {
12+
let mut cert_params = rcgen::CertificateParams::new(vec![host.clone()])?;
1013
cert_params
1114
.key_usages
1215
.push(rcgen::KeyUsagePurpose::DigitalSignature);
@@ -22,14 +25,12 @@ pub fn generate_cert(host: String, root_cert: &rcgen::CertifiedKey) -> Certified
2225
dn
2326
};
2427

25-
let key_pair = rcgen::KeyPair::generate().unwrap();
28+
let key_pair = rcgen::KeyPair::generate()?;
2629

27-
let cert = cert_params
28-
.signed_by(&key_pair, &root_cert.cert, &root_cert.key_pair)
29-
.unwrap();
30+
let cert = cert_params.signed_by(&key_pair, &root_cert.cert, &root_cert.key_pair)?;
3031

31-
CertifiedKeyDer {
32+
Ok(CertifiedKeyDer {
3233
cert_der: cert.der().to_vec(),
3334
key_der: key_pair.serialize_der(),
34-
}
35+
})
3536
}

0 commit comments

Comments
 (0)