diff --git a/src/frame/headers.rs b/src/frame/headers.rs index e9b163e5..bb36cd4f 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -554,32 +554,38 @@ impl Pseudo { pub fn request(method: Method, uri: Uri, protocol: Option) -> Self { let parts = uri::Parts::from(uri); - let mut path = parts - .path_and_query - .map(|v| BytesStr::from(v.as_str())) - .unwrap_or(BytesStr::from_static("")); - - match method { - Method::OPTIONS | Method::CONNECT => {} - _ if path.is_empty() => { - path = BytesStr::from_static("/"); - } - _ => {} - } + let (scheme, path) = if method == Method::CONNECT && protocol.is_none() { + (None, None) + } else { + let path = parts + .path_and_query + .map(|v| BytesStr::from(v.as_str())) + .unwrap_or(BytesStr::from_static("")); + + let path = if !path.is_empty() { + path + } else { + if method == Method::OPTIONS { + BytesStr::from_static("*") + } else { + BytesStr::from_static("/") + } + }; + + (parts.scheme, Some(path)) + }; let mut pseudo = Pseudo { method: Some(method), scheme: None, authority: None, - path: Some(path).filter(|p| !p.is_empty()), + path, protocol, status: None, }; // If the URI includes a scheme component, add it to the pseudo headers - // - // TODO: Scheme must be set... - if let Some(scheme) = parts.scheme { + if let Some(scheme) = scheme { pseudo.set_scheme(scheme); } @@ -1048,4 +1054,161 @@ mod test { let mut buf = BytesMut::new(); huffman::decode(src, &mut buf).unwrap() } + + #[test] + fn test_connect_request_pseudo_headers_omits_path_and_scheme() { + // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields + // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5 + + assert_eq!( + Pseudo::request( + Method::CONNECT, + Uri::from_static("https://example.com:8443"), + None + ), + Pseudo { + method: Method::CONNECT.into(), + authority: BytesStr::from_static("example.com:8443").into(), + ..Default::default() + } + ); + + assert_eq!( + Pseudo::request( + Method::CONNECT, + Uri::from_static("https://example.com/test"), + None + ), + Pseudo { + method: Method::CONNECT.into(), + authority: BytesStr::from_static("example.com").into(), + ..Default::default() + } + ); + + assert_eq!( + Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None), + Pseudo { + method: Method::CONNECT.into(), + authority: BytesStr::from_static("example.com:8443").into(), + ..Default::default() + } + ); + } + + #[test] + fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() { + // On requests that contain the :protocol pseudo-header field, the + // :scheme and :path pseudo-header fields of the target URI (see + // Section 5) MUST also be included. + // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + + assert_eq!( + Pseudo::request( + Method::CONNECT, + Uri::from_static("https://example.com:8443"), + Protocol::from_static("the-bread-protocol").into() + ), + Pseudo { + method: Method::CONNECT.into(), + authority: BytesStr::from_static("example.com:8443").into(), + scheme: BytesStr::from_static("https").into(), + path: BytesStr::from_static("/").into(), + protocol: Protocol::from_static("the-bread-protocol").into(), + ..Default::default() + } + ); + + assert_eq!( + Pseudo::request( + Method::CONNECT, + Uri::from_static("https://example.com:8443/test"), + Protocol::from_static("the-bread-protocol").into() + ), + Pseudo { + method: Method::CONNECT.into(), + authority: BytesStr::from_static("example.com:8443").into(), + scheme: BytesStr::from_static("https").into(), + path: BytesStr::from_static("/test").into(), + protocol: Protocol::from_static("the-bread-protocol").into(), + ..Default::default() + } + ); + + assert_eq!( + Pseudo::request( + Method::CONNECT, + Uri::from_static("http://example.com/a/b/c"), + Protocol::from_static("the-bread-protocol").into() + ), + Pseudo { + method: Method::CONNECT.into(), + authority: BytesStr::from_static("example.com").into(), + scheme: BytesStr::from_static("http").into(), + path: BytesStr::from_static("/a/b/c").into(), + protocol: Protocol::from_static("the-bread-protocol").into(), + ..Default::default() + } + ); + } + + #[test] + fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() { + // an OPTIONS request for an "http" or "https" URI that does not include a path component; + // these MUST include a ":path" pseudo-header field with a value of '*' (see Section 7.1 of [HTTP]). + // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1 + assert_eq!( + Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,), + Pseudo { + method: Method::OPTIONS.into(), + authority: BytesStr::from_static("example.com:8080").into(), + path: BytesStr::from_static("*").into(), + ..Default::default() + } + ); + } + + #[test] + fn test_non_option_and_non_connect_requests_include_path_and_scheme() { + let methods = [ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::HEAD, + Method::PATCH, + Method::TRACE, + ]; + + for method in methods { + assert_eq!( + Pseudo::request( + method.clone(), + Uri::from_static("http://example.com:8080"), + None, + ), + Pseudo { + method: method.clone().into(), + authority: BytesStr::from_static("example.com:8080").into(), + scheme: BytesStr::from_static("http").into(), + path: BytesStr::from_static("/").into(), + ..Default::default() + } + ); + assert_eq!( + Pseudo::request( + method.clone(), + Uri::from_static("https://example.com/a/b/c"), + None, + ), + Pseudo { + method: method.into(), + authority: BytesStr::from_static("example.com").into(), + scheme: BytesStr::from_static("https").into(), + path: BytesStr::from_static("/a/b/c").into(), + ..Default::default() + } + ); + } + } } diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index 858bf770..ba123c0a 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -4,10 +4,7 @@ use std::fmt; use bytes::Bytes; use http::{HeaderMap, StatusCode}; -use h2::{ - ext::Protocol, - frame::{self, Frame, StreamId}, -}; +use h2::frame::{self, Frame, StreamId}; pub const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; pub const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; @@ -124,19 +121,24 @@ impl Mock { M::Error: fmt::Debug, { let method = method.try_into().unwrap(); - let (id, _, fields) = self.into_parts(); + let (id, pseudo, fields) = self.into_parts(); let frame = frame::Headers::new( id, frame::Pseudo { - scheme: None, method: Some(method), - ..Default::default() + ..pseudo }, fields, ); Mock(frame) } + pub fn pseudo(self, pseudo: frame::Pseudo) -> Self { + let (id, _, fields) = self.into_parts(); + let frame = frame::Headers::new(id, pseudo, fields); + Mock(frame) + } + pub fn response(self, status: S) -> Self where S: TryInto, @@ -184,15 +186,6 @@ impl Mock { Mock(frame::Headers::new(id, pseudo, fields)) } - pub fn protocol(self, value: &str) -> Self { - let (id, mut pseudo, fields) = self.into_parts(); - let value = Protocol::from(value); - - pseudo.set_protocol(value); - - Mock(frame::Headers::new(id, pseudo, fields)) - } - pub fn eos(mut self) -> Self { self.0.set_end_stream(); self diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 7bd223e3..9cc2f91e 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -585,6 +585,45 @@ async fn http_2_request_without_scheme_or_authority() { join(srv, h2).await; } +#[tokio::test] +async fn http_2_connect_request_omit_scheme_and_path_fields() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .pseudo(frame::Pseudo { + method: Method::CONNECT.into(), + authority: util::byte_str("tunnel.example.com:8443").into(), + ..Default::default() + }) + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.expect("handshake"); + + // In HTTP_2 CONNECT request the ":scheme" and ":path" pseudo-header fields MUST be omitted. + let request = Request::builder() + .version(Version::HTTP_2) + .method(Method::CONNECT) + .uri("https://tunnel.example.com:8443/") + .body(()) + .unwrap(); + + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; + + join(srv, h2).await; +} + #[test] #[ignore] fn request_with_h1_version() {} @@ -1521,8 +1560,14 @@ async fn extended_connect_request() { srv.recv_frame( frames::headers(1) - .request("CONNECT", "http://bread/baguette") - .protocol("the-bread-protocol") + .pseudo(frame::Pseudo { + method: Method::CONNECT.into(), + scheme: util::byte_str("http").into(), + authority: util::byte_str("bread").into(), + path: util::byte_str("/baguette").into(), + protocol: Protocol::from_static("the-bread-protocol").into(), + ..Default::default() + }) .eos(), ) .await; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index a4b983a0..7155b586 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1252,11 +1252,11 @@ async fn extended_connect_protocol_disabled_by_default() { assert_eq!(settings.is_extended_connect_protocol_enabled(), None); client - .send_frame( - frames::headers(1) - .request("CONNECT", "http://bread/baguette") - .protocol("the-bread-protocol"), - ) + .send_frame(frames::headers(1).pseudo(frame::Pseudo::request( + Method::CONNECT, + uri::Uri::from_static("http://bread/baguette"), + Protocol::from_static("the-bread-protocol").into(), + ))) .await; client.recv_frame(frames::reset(1).protocol_error()).await; @@ -1285,11 +1285,11 @@ async fn extended_connect_protocol_enabled_during_handshake() { assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); client - .send_frame( - frames::headers(1) - .request("CONNECT", "http://bread/baguette") - .protocol("the-bread-protocol"), - ) + .send_frame(frames::headers(1).pseudo(frame::Pseudo::request( + Method::CONNECT, + uri::Uri::from_static("http://bread/baguette"), + Protocol::from_static("the-bread-protocol").into(), + ))) .await; client.recv_frame(frames::headers(1).response(200)).await; @@ -1332,11 +1332,11 @@ async fn reject_pseudo_protocol_on_non_connect_request() { assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); client - .send_frame( - frames::headers(1) - .request("GET", "http://bread/baguette") - .protocol("the-bread-protocol"), - ) + .send_frame(frames::headers(1).pseudo(frame::Pseudo::request( + Method::GET, + uri::Uri::from_static("http://bread/baguette"), + Some(Protocol::from_static("the-bread-protocol")), + ))) .await; client.recv_frame(frames::reset(1).protocol_error()).await; @@ -1360,7 +1360,7 @@ async fn reject_pseudo_protocol_on_non_connect_request() { } #[tokio::test] -async fn reject_authority_target_on_extended_connect_request() { +async fn reject_extended_connect_request_without_scheme() { h2_support::trace_init!(); let (io, mut client) = mock::new(); @@ -1371,11 +1371,12 @@ async fn reject_authority_target_on_extended_connect_request() { assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); client - .send_frame( - frames::headers(1) - .request("CONNECT", "bread:80") - .protocol("the-bread-protocol"), - ) + .send_frame(frames::headers(1).pseudo(frame::Pseudo { + method: Method::CONNECT.into(), + path: util::byte_str("/").into(), + protocol: Protocol::from("the-bread-protocol").into(), + ..Default::default() + })) .await; client.recv_frame(frames::reset(1).protocol_error()).await; @@ -1399,7 +1400,7 @@ async fn reject_authority_target_on_extended_connect_request() { } #[tokio::test] -async fn reject_non_authority_target_on_connect_request() { +async fn reject_extended_connect_request_without_path() { h2_support::trace_init!(); let (io, mut client) = mock::new(); @@ -1410,7 +1411,12 @@ async fn reject_non_authority_target_on_connect_request() { assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); client - .send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette")) + .send_frame(frames::headers(1).pseudo(frame::Pseudo { + method: Method::CONNECT.into(), + scheme: util::byte_str("https").into(), + protocol: Protocol::from("the-bread-protocol").into(), + ..Default::default() + })) .await; client.recv_frame(frames::reset(1).protocol_error()).await;