@@ -8,20 +8,33 @@ use hyper_util::rt::{TokioExecutor, TokioIo};
88use std:: task:: { Context , Poll } ;
99use tokio:: { net:: TcpStream , task:: JoinHandle } ;
1010
11+ #[ cfg( all( feature = "native-tls-client" , feature = "rustls-client" ) ) ]
12+ compile_error ! ( "feature \" native-tls-client\" and feature \" rustls-client\" cannot be enabled at the same time" ) ;
13+
14+ #[ cfg( all( not( feature = "native-tls-client" ) , not( feature = "rustls-client" ) ) ) ]
15+ compile_error ! ( "feature \" native-tls-client\" or feature \" rustls-client\" must be enabled" ) ;
16+
1117#[ derive( thiserror:: Error , Debug ) ]
1218pub enum Error {
1319 #[ error( "{0} doesn't have an valid host" ) ]
1420 InvalidHost ( Uri ) ,
1521 #[ error( transparent) ]
1622 IoError ( #[ from] std:: io:: Error ) ,
1723 #[ error( transparent) ]
18- NativeTlsError ( #[ from] tokio_native_tls:: native_tls:: Error ) ,
19- #[ error( transparent) ]
2024 HyperError ( #[ from] hyper:: Error ) ,
2125 #[ error( "Failed to connect to {0}, {1}" ) ]
2226 ConnectError ( Uri , hyper:: Error ) ,
27+
28+ #[ cfg( feature = "native-tls-client" ) ]
2329 #[ error( "Failed to connect with TLS to {0}, {1}" ) ]
2430 TlsConnectError ( Uri , native_tls:: Error ) ,
31+ #[ cfg( feature = "native-tls-client" ) ]
32+ #[ error( transparent) ]
33+ NativeTlsError ( #[ from] tokio_native_tls:: native_tls:: Error ) ,
34+
35+ #[ cfg( feature = "rustls-client" ) ]
36+ #[ error( "Failed to connect with TLS to {0}, {1}" ) ]
37+ TlsConnectError ( Uri , std:: io:: Error ) ,
2538}
2639
2740/// Upgraded connections
@@ -34,24 +47,60 @@ pub struct Upgraded {
3447#[ derive( Clone ) ]
3548/// Default HTTP client for this crate
3649pub struct DefaultClient {
50+ #[ cfg( feature = "native-tls-client" ) ]
3751 tls_connector_no_alpn : tokio_native_tls:: TlsConnector ,
52+ #[ cfg( feature = "native-tls-client" ) ]
3853 tls_connector_alpn_h2 : tokio_native_tls:: TlsConnector ,
54+
55+ #[ cfg( feature = "rustls-client" ) ]
56+ tls_connector_no_alpn : tokio_rustls:: TlsConnector ,
57+ #[ cfg( feature = "rustls-client" ) ]
58+ tls_connector_alpn_h2 : tokio_rustls:: TlsConnector ,
59+
3960 /// If true, send_request will returns an Upgraded struct when the response is an upgrade
4061 /// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
4162 pub with_upgrades : bool ,
4263}
4364impl DefaultClient {
44- pub fn new ( ) -> native_tls:: Result < Self > {
45- let tls_connector_no_alpn = native_tls:: TlsConnector :: builder ( ) . build ( ) ?;
65+ #[ cfg( feature = "native-tls-client" ) ]
66+ pub fn new ( ) -> Self {
67+ let tls_connector_no_alpn = native_tls:: TlsConnector :: builder ( ) . build ( ) . unwrap ( ) ;
4668 let tls_connector_alpn_h2 = native_tls:: TlsConnector :: builder ( )
4769 . request_alpns ( & [ "h2" , "http/1.1" ] )
48- . build ( ) ?;
70+ . build ( )
71+ . unwrap ( ) ;
4972
50- Ok ( Self {
73+ Self {
5174 tls_connector_no_alpn : tokio_native_tls:: TlsConnector :: from ( tls_connector_no_alpn) ,
5275 tls_connector_alpn_h2 : tokio_native_tls:: TlsConnector :: from ( tls_connector_alpn_h2) ,
5376 with_upgrades : false ,
54- } )
77+ }
78+ }
79+
80+ #[ cfg( feature = "rustls-client" ) ]
81+ pub fn new ( ) -> Self {
82+ use std:: sync:: Arc ;
83+
84+ let mut root_cert_store = tokio_rustls:: rustls:: RootCertStore :: empty ( ) ;
85+ root_cert_store. extend ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
86+
87+ let tls_connector_no_alpn = tokio_rustls:: rustls:: ClientConfig :: builder ( )
88+ . with_root_certificates ( root_cert_store. clone ( ) )
89+ . with_no_client_auth ( ) ;
90+ let mut tls_connector_alpn_h2 = tokio_rustls:: rustls:: ClientConfig :: builder ( )
91+ . with_root_certificates ( root_cert_store. clone ( ) )
92+ . with_no_client_auth ( ) ;
93+ tls_connector_alpn_h2. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
94+
95+ Self {
96+ tls_connector_no_alpn : tokio_rustls:: TlsConnector :: from ( Arc :: new (
97+ tls_connector_no_alpn,
98+ ) ) ,
99+ tls_connector_alpn_h2 : tokio_rustls:: TlsConnector :: from ( Arc :: new (
100+ tls_connector_alpn_h2,
101+ ) ) ,
102+ with_upgrades : false ,
103+ }
55104 }
56105
57106 /// Enable HTTP upgrades
@@ -61,13 +110,22 @@ impl DefaultClient {
61110 self
62111 }
63112
113+ #[ cfg( feature = "native-tls-client" ) ]
64114 fn tls_connector ( & self , http_version : Version ) -> & tokio_native_tls:: TlsConnector {
65115 match http_version {
66116 Version :: HTTP_2 => & self . tls_connector_alpn_h2 ,
67117 _ => & self . tls_connector_no_alpn ,
68118 }
69119 }
70120
121+ #[ cfg( feature = "rustls-client" ) ]
122+ fn tls_connector ( & self , http_version : Version ) -> & tokio_rustls:: TlsConnector {
123+ match http_version {
124+ Version :: HTTP_2 => & self . tls_connector_alpn_h2 ,
125+ _ => & self . tls_connector_no_alpn ,
126+ }
127+ }
128+
71129 /// Send a request and return a response.
72130 /// If the response is an upgrade (= if status code is 101 Switching Protocols), it will return a response and an Upgrade struct.
73131 /// Request should have a full URL including scheme.
@@ -152,17 +210,31 @@ impl DefaultClient {
152210 let _ = tcp. set_nodelay ( true ) ;
153211
154212 if uri. scheme ( ) == Some ( & hyper:: http:: uri:: Scheme :: HTTPS ) {
213+ #[ cfg( feature = "native-tls-client" ) ]
155214 let tls = self
156215 . tls_connector ( http_version)
157216 . connect ( host, tcp)
158217 . await
159218 . map_err ( |err| Error :: TlsConnectError ( uri. clone ( ) , err) ) ?;
219+ #[ cfg( feature = "rustls-client" ) ]
220+ let tls = self
221+ . tls_connector ( http_version)
222+ . connect ( host. to_string ( ) . try_into ( ) . expect ( "Invalid host" ) , tcp)
223+ . await
224+ . map_err ( |err| Error :: TlsConnectError ( uri. clone ( ) , err) ) ?;
225+
226+ #[ cfg( feature = "native-tls-client" ) ]
227+ let is_h2 = matches ! (
228+ tls. get_ref( )
229+ . negotiated_alpn( )
230+ . map( |a| a. map( |b| b == b"h2" ) ) ,
231+ Ok ( Some ( true ) )
232+ ) ;
233+
234+ #[ cfg( feature = "rustls-client" ) ]
235+ let is_h2 = tls. get_ref ( ) . 1 . alpn_protocol ( ) == Some ( b"h2" ) ;
160236
161- if let Ok ( Some ( true ) ) = tls
162- . get_ref ( )
163- . negotiated_alpn ( )
164- . map ( |a| a. map ( |b| b == b"h2" ) )
165- {
237+ if is_h2 {
166238 let ( sender, conn) = client:: conn:: http2:: Builder :: new ( TokioExecutor :: new ( ) )
167239 . handshake ( TokioIo :: new ( tls) )
168240 . await
0 commit comments