From 62c835090ed16be790044c832616725a8e6c79d7 Mon Sep 17 00:00:00 2001 From: Aber Date: Wed, 22 Nov 2023 15:19:04 +0800 Subject: [PATCH] Initial version (#1) --- .github/workflows/ci.yml | 21 + .gitignore | 1 + Cargo.lock | 61 + Cargo.toml | 13 + LICENSE | 2 +- src/_abnf.rs | 45 + src/_connection.rs | 657 ++++++++ src/_events.rs | 141 ++ src/_headers.rs | 594 ++++++++ src/_readers.rs | 692 +++++++++ src/_receivebuffer.rs | 230 +++ src/_state.rs | 701 +++++++++ src/_util.rs | 110 ++ src/_writers.rs | 351 +++++ src/lib.rs | 15 + tests/connections.rs | 3133 ++++++++++++++++++++++++++++++++++++++ tests/helper.rs | 201 +++ 17 files changed, 6967 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 src/_abnf.rs create mode 100644 src/_connection.rs create mode 100644 src/_events.rs create mode 100644 src/_headers.rs create mode 100644 src/_readers.rs create mode 100644 src/_receivebuffer.rs create mode 100644 src/_state.rs create mode 100644 src/_util.rs create mode 100644 src/_writers.rs create mode 100644 src/lib.rs create mode 100644 tests/connections.rs create mode 100644 tests/helper.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..19a899c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + tests: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..8782846 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,61 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "h11" +version = "0.1.0" +dependencies = [ + "lazy_static", + "regex", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + +[[package]] +name = "regex" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..75f77f9 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "h11" +version = "0.1.0" +edition = "2021" +description = "A pure-Rust, bring-your-own-I/O implementation of HTTP/1.1" +license = "Apache-2.0" +license-file = "LICENSE" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +lazy_static = "1.4.0" +regex = "1.10.2" diff --git a/LICENSE b/LICENSE index 261eeb9..bd7067c 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2023 abersheeran Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/src/_abnf.rs b/src/_abnf.rs new file mode 100644 index 0000000..93847da --- /dev/null +++ b/src/_abnf.rs @@ -0,0 +1,45 @@ +use lazy_static::lazy_static; + +pub static OWS: &str = r"[ \t]*"; +pub static TOKEN: &str = r"[-!#$%&'*+.^_`|~0-9a-zA-Z]+"; +pub static FIELD_NAME: &str = TOKEN; +pub static VCHAR: &str = r"[\x21-\x7e]"; +pub static VCHAR_OR_OBS_TEXT: &str = r"[^\x00\s]"; +pub static FIELD_VCHAR: &str = VCHAR_OR_OBS_TEXT; + +lazy_static! { + pub static ref FIELD_CONTENT: String = format!(r"{}+(?:[ \t]+{}+)*", FIELD_VCHAR, FIELD_VCHAR); + pub static ref FIELD_VALUE: String = format!(r"({})?", *FIELD_CONTENT); + pub static ref HEADER_FIELD: String = format!( + r"(?P{field_name}):{OWS}(?P{field_value}){OWS}", + field_name = FIELD_NAME, + field_value = *FIELD_VALUE, + OWS = OWS + ); + pub static ref METHOD: String = TOKEN.to_string(); + pub static ref REQUEST_TARGET: String = format!("{}+", VCHAR); + pub static ref HTTP_VERSION: String = r"HTTP/(?P[0-9]\.[0-9])".to_string(); + pub static ref REQUEST_LINE: String = format!( + r"(?P{method}) (?P{request_target}) {http_version}", + method = *METHOD, + request_target = *REQUEST_TARGET, + http_version = *HTTP_VERSION + ); + pub static ref STATUS_CODE: String = r"[0-9]{3}".to_string(); + pub static ref REASON_PHRASE: String = format!(r"([ \t]|{})*", VCHAR_OR_OBS_TEXT); + pub static ref STATUS_LINE: String = format!( + r"{http_version} (?P{status_code})(?: (?P{reason_phrase}))?", + http_version = *HTTP_VERSION, + status_code = *STATUS_CODE, + reason_phrase = *REASON_PHRASE + ); + pub static ref HEXDIG: String = r"[0-9A-Fa-f]".to_string(); + pub static ref CHUNK_SIZE: String = format!(r"({}){{1,20}}", *HEXDIG); + pub static ref CHUNK_EXT: String = ";.*".to_string(); + pub static ref CHUNK_HEADER: String = format!( + r"(?P{chunk_size})(?P{chunk_ext})?{OWS}\r\n", + chunk_size = *CHUNK_SIZE, + chunk_ext = *CHUNK_EXT, + OWS = OWS + ); +} diff --git a/src/_connection.rs b/src/_connection.rs new file mode 100644 index 0000000..bba5360 --- /dev/null +++ b/src/_connection.rs @@ -0,0 +1,657 @@ +use crate::_events::*; +use crate::_headers::*; +use crate::_readers::*; +use crate::_receivebuffer::*; +use crate::_state::*; +use crate::_util::*; +use crate::_writers::*; +use std::collections::HashMap; +use std::collections::HashSet; + +static DEFAULT_MAX_INCOMPLETE_EVENT_SIZE: usize = 16 * 1024; + +enum RequestOrResponse { + Request(Request), + Response(Response), +} + +impl RequestOrResponse { + pub fn headers(&self) -> &Headers { + match self { + Self::Request(request) => &request.headers, + Self::Response(response) => &response.headers, + } + } + + pub fn http_version(&self) -> &Vec { + match self { + Self::Request(request) => &request.http_version, + Self::Response(response) => &response.http_version, + } + } +} + +impl From for RequestOrResponse { + fn from(value: Request) -> Self { + Self::Request(value) + } +} + +impl From for RequestOrResponse { + fn from(value: Response) -> Self { + Self::Response(value) + } +} + +impl From for RequestOrResponse { + fn from(value: Event) -> Self { + match value { + Event::Request(request) => Self::Request(request), + Event::NormalResponse(response) => Self::Response(response), + _ => panic!("Invalid event type"), + } + } +} + +fn _keep_alive>(event: T) -> bool { + let event: RequestOrResponse = event.into(); + let connection = get_comma_header(event.headers(), b"connection"); + if connection.contains(&b"close".to_vec()) { + return false; + } + if event.http_version() < &b"1.1".to_vec() { + return false; + } + return true; +} + +fn _body_framing>(request_method: &[u8], event: T) -> (&str, isize) { + let event: RequestOrResponse = event.into(); + if let RequestOrResponse::Response(response) = &event { + if response.status_code == 204 + || response.status_code == 304 + || request_method == b"HEAD" + || (request_method == b"CONNECT" + && 200 <= response.status_code + && response.status_code < 300) + { + return ("content-length", 0); + } + assert!(response.status_code >= 200); + } + + let trasfer_encodings = get_comma_header(event.headers(), b"transfer-encoding"); + if !trasfer_encodings.is_empty() { + assert!(trasfer_encodings == vec![b"chunked".to_vec()]); + return ("chunked", 0); + } + + let content_lengths = get_comma_header(event.headers(), b"content-length"); + if !content_lengths.is_empty() { + return ( + "content-length", + std::str::from_utf8(&content_lengths[0]) + .unwrap() + .parse() + .unwrap(), + ); + } + + if let RequestOrResponse::Request(_) = event { + return ("content-length", 0); + } else { + return ("http/1.0", 0); + } +} + +pub struct Connection { + pub our_role: Role, + pub their_role: Role, + _cstate: ConnectionState, + _writer: Option>, + _reader: Option>, + _max_incomplete_event_size: usize, + _receive_buffer: ReceiveBuffer, + _receive_buffer_closed: bool, + pub their_http_version: Option>, + _request_method: Option>, + client_is_waiting_for_100_continue: bool, +} + +impl Connection { + pub fn new(our_role: Role, max_incomplete_event_size: Option) -> Self { + Self { + our_role, + their_role: if our_role == Role::Client { + Role::Server + } else { + Role::Client + }, + _cstate: ConnectionState::new(), + _writer: match our_role { + Role::Client => Some(Box::new(write_request)), + Role::Server => Some(Box::new(write_response)), + }, + _reader: match our_role { + Role::Server => Some(Box::new(IdleClientReader {})), + Role::Client => Some(Box::new(SendResponseServerReader {})), + }, + _max_incomplete_event_size: max_incomplete_event_size + .unwrap_or(DEFAULT_MAX_INCOMPLETE_EVENT_SIZE), + _receive_buffer: ReceiveBuffer::new(), + _receive_buffer_closed: false, + their_http_version: None, + _request_method: None, + client_is_waiting_for_100_continue: false, + } + } + + pub fn get_states(&self) -> HashMap { + self._cstate.states.clone() + } + + pub fn get_our_state(&self) -> State { + self._cstate.states[&self.our_role] + } + + pub fn get_their_state(&self) -> State { + self._cstate.states[&self.their_role] + } + + pub fn get_client_is_waiting_for_100_continue(&self) -> bool { + self.client_is_waiting_for_100_continue + } + + pub fn get_they_are_waiting_for_100_continue(&self) -> bool { + self.their_role == Role::Client && self.client_is_waiting_for_100_continue + } + + pub fn start_next_cycle(&mut self) -> Result<(), ProtocolError> { + let old_states = self._cstate.states.clone(); + self._cstate.start_next_cycle()?; + self._request_method = None; + self.their_http_version = None; + self.client_is_waiting_for_100_continue = false; + self._respond_to_state_changes(old_states, None); + Ok(()) + } + + fn _process_error(&mut self, role: Role) { + let old_states = self._cstate.states.clone(); + self._cstate.process_error(role); + self._respond_to_state_changes(old_states, None); + } + + fn _server_switch_event(&self, event: Event) -> Option { + if let Event::InformationalResponse(informational_response) = &event { + if informational_response.status_code == 101 { + return Some(Switch::SwitchUpgrade); + } + } + if let Event::NormalResponse(response) = &event { + if self + ._cstate + .pending_switch_proposals + .contains(&Switch::SwitchConnect) + && 200 <= response.status_code + && response.status_code < 300 + { + return Some(Switch::SwitchConnect); + } + } + return None; + } + + fn _process_event(&mut self, role: Role, event: Event) -> Result<(), ProtocolError> { + let old_states = self._cstate.states.clone(); + if role == Role::Client { + if let Event::Request(request) = event.clone() { + if request.method == b"CONNECT" { + self._cstate + .process_client_switch_proposal(Switch::SwitchConnect); + } + if get_comma_header(&request.headers, b"upgrade").len() > 0 { + self._cstate + .process_client_switch_proposal(Switch::SwitchUpgrade); + } + } + } + let server_switch_event = if role == Role::Server { + self._server_switch_event(event.clone()) + } else { + None + }; + self._cstate + .process_event(role, (&event).into(), server_switch_event)?; + + if let Event::Request(request) = event.clone() { + self._request_method = Some(request.method); + } + + if role == self.their_role { + if let Event::Request(request) = event.clone() { + self.their_http_version = Some(request.http_version); + } + if let Event::NormalResponse(response) = event.clone() { + self.their_http_version = Some(response.http_version); + } + if let Event::InformationalResponse(informational_response) = event.clone() { + self.their_http_version = Some(informational_response.http_version); + } + } + + if let Event::Request(request) = event.clone() { + if !_keep_alive(RequestOrResponse::from(request)) { + self._cstate.process_keep_alive_disabled(); + } + } + if let Event::NormalResponse(response) = event.clone() { + if !_keep_alive(RequestOrResponse::from(response)) { + self._cstate.process_keep_alive_disabled(); + } + } + + if let Event::Request(request) = event.clone() { + if has_expect_100_continue(&request) { + self.client_is_waiting_for_100_continue = true; + } + } + match (&event).into() { + EventType::InformationalResponse => { + self.client_is_waiting_for_100_continue = false; + } + EventType::NormalResponse => { + self.client_is_waiting_for_100_continue = false; + } + EventType::Data => { + if role == Role::Client { + self.client_is_waiting_for_100_continue = false; + } + } + EventType::EndOfMessage => { + if role == Role::Client { + self.client_is_waiting_for_100_continue = false; + } + } + _ => {} + } + + self._respond_to_state_changes(old_states, Some(event)); + Ok(()) + } + + fn _respond_to_state_changes( + &mut self, + old_states: HashMap, + event: Option, + ) { + if self.get_our_state() != old_states[&self.our_role] { + let state = self._cstate.states[&self.our_role]; + self._writer = match state { + State::SendBody => { + let request_method = self._request_method.clone().unwrap_or(vec![]); + let (framing_type, length) = _body_framing( + &request_method, + RequestOrResponse::from(event.clone().unwrap()), + ); + + match framing_type { + "content-length" => Some(Box::new(content_length_writer(length))), + "chunked" => Some(Box::new(chunked_writer())), + "http/1.0" => Some(Box::new(http10_writer())), + _ => { + panic!("Invalid role and framing type combination"); + } + } + } + _ => match (&self.our_role, state) { + (Role::Client, State::Idle) => Some(Box::new(write_request)), + (Role::Server, State::Idle) => Some(Box::new(write_response)), + (Role::Server, State::SendResponse) => Some(Box::new(write_response)), + _ => None, + }, + }; + } + if self.get_their_state() != old_states[&self.their_role] { + self._reader = match self._cstate.states[&self.their_role] { + State::SendBody => { + let request_method = self._request_method.clone().unwrap_or(vec![]); + let (framing_type, length) = _body_framing( + &request_method, + RequestOrResponse::from(event.clone().unwrap()), + ); + match framing_type { + "content-length" => { + Some(Box::new(ContentLengthReader::new(length as usize))) + } + "chunked" => Some(Box::new(ChunkedReader::new())), + "http/1.0" => Some(Box::new(Http10Reader {})), + _ => { + panic!("Invalid role and framing type combination"); + } + } + } + _ => match (&self.their_role, self._cstate.states[&self.their_role]) { + (Role::Client, State::Idle) => Some(Box::new(IdleClientReader {})), + (Role::Server, State::Idle) => Some(Box::new(SendResponseServerReader {})), + (Role::Server, State::SendResponse) => { + Some(Box::new(SendResponseServerReader {})) + } + (Role::Client, State::Done) => Some(Box::new(ClosedReader {})), + (Role::Client, State::MustClose) => Some(Box::new(ClosedReader {})), + (Role::Client, State::Closed) => Some(Box::new(ClosedReader {})), + (Role::Server, State::Done) => Some(Box::new(ClosedReader {})), + (Role::Server, State::MustClose) => Some(Box::new(ClosedReader {})), + (Role::Server, State::Closed) => Some(Box::new(ClosedReader {})), + _ => None, + }, + }; + } + } + + pub fn get_trailing_data(&self) -> (Vec, bool) { + ( + self._receive_buffer.bytes().to_vec(), + self._receive_buffer_closed, + ) + } + + pub fn receive_data(&mut self, data: &[u8]) -> Result<(), String> { + Ok(if data.len() > 0 { + if self._receive_buffer_closed { + return Err("received close, then received more data?".to_string()); + } + self._receive_buffer.add(data); + } else { + self._receive_buffer_closed = true; + }) + } + + fn _extract_next_receive_event(&mut self) -> Result { + let state = self.get_their_state(); + if state == State::Done && self._receive_buffer.len() > 0 { + return Ok(Event::Paused()); + } + if state == State::MightSwitchProtocol || state == State::SwitchedProtocol { + return Ok(Event::Paused()); + } + let event = self + ._reader + .as_mut() + .unwrap() + .call(&mut self._receive_buffer)?; + if event.is_none() { + if self._receive_buffer.len() == 0 && self._receive_buffer_closed { + return self._reader.as_mut().unwrap().read_eof(); + } + } + Ok(event.unwrap_or(Event::NeedData())) + } + + pub fn next_event(&mut self) -> Result { + if self.get_their_state() == State::Error { + return Err(ProtocolError::RemoteProtocolError( + "Can't receive data when peer state is ERROR".into(), + )); + } + match (|| { + let event = self._extract_next_receive_event()?; + match event { + Event::NeedData() | Event::Paused() => {} + _ => { + self._process_event(self.their_role, event.clone())?; + } + }; + + if let Event::NeedData() = event.clone() { + if self._receive_buffer.len() > self._max_incomplete_event_size { + return Err(ProtocolError::RemoteProtocolError( + ("Receive buffer too long".to_string(), 431).into(), + )); + } + if self._receive_buffer_closed { + return Err(ProtocolError::RemoteProtocolError( + "peer unexpectedly closed connection".to_string().into(), + )); + } + } + + Ok(event) + })() { + Err(error) => { + self._process_error(self.their_role); + match error { + ProtocolError::LocalProtocolError(error) => { + Err(error._reraise_as_remote_protocol_error().into()) + } + _ => Err(error), + } + } + Ok(any) => Ok(any), + } + } + + pub fn send(&mut self, mut event: Event) -> Result>, ProtocolError> { + if self.get_our_state() == State::Error { + return Err(ProtocolError::LocalProtocolError( + "Can't send data when our state is ERROR".to_string().into(), + )); + } + event = if let Event::NormalResponse(response) = &event { + Event::NormalResponse(self._clean_up_response_headers_for_sending(response.clone())?) + } else { + event + }; + let event_type: EventType = (&event).into(); + let res: Result, ProtocolError> = match self._writer.as_mut() { + Some(_) if event_type == EventType::ConnectionClosed => Ok(vec![]), + Some(writer) => writer(event.clone()), + None => Err(ProtocolError::LocalProtocolError( + "Can't send data when our state is not SEND_BODY" + .to_string() + .into(), + )), + }; + self._process_event(self.our_role, event.clone())?; + if event_type == EventType::ConnectionClosed { + return Ok(None); + } else { + match res { + Ok(data_list) => Ok(Some(data_list)), + Err(error) => { + self._process_error(self.our_role); + Err(error) + } + } + } + } + + pub fn send_failed(&mut self) { + self._process_error(self.our_role); + } + + fn _clean_up_response_headers_for_sending( + &self, + response: Response, + ) -> Result { + let mut headers = response.clone().headers; + let mut need_close = false; + let mut method_for_choosing_headers = self._request_method.clone().unwrap_or(vec![]); + if method_for_choosing_headers == b"HEAD".to_vec() { + method_for_choosing_headers = b"GET".to_vec(); + } + let (framing_type, _) = _body_framing(&method_for_choosing_headers, response.clone()); + if framing_type == "chunked" || framing_type == "http/1.0" { + headers = set_comma_header(&headers, b"content-length", vec![])?; + if self + .their_http_version + .clone() + .map(|v| v < b"1.1".to_vec()) + .unwrap_or(true) + { + headers = set_comma_header(&headers, b"transfer-encoding", vec![])?; + if self._request_method.clone().unwrap_or(vec![]) != b"HEAD".to_vec() { + need_close = true; + } + } else { + headers = + set_comma_header(&headers, b"transfer-encoding", vec![b"chunked".to_vec()])?; + } + } + if !self._cstate.keep_alive || need_close { + let mut connection: HashSet> = get_comma_header(&headers, b"connection") + .into_iter() + .collect(); + connection.retain(|x| x != &b"keep-alive".to_vec()); + connection.insert(b"close".to_vec()); + headers = set_comma_header(&headers, b"connection", connection.into_iter().collect())?; + } + return Ok(Response { + headers, + status_code: response.status_code, + http_version: response.http_version, + reason: response.reason, + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_keep_alive() { + assert!(_keep_alive(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![(b"Host".to_vec(), b"Example.com".to_vec())].into(), + http_version: b"1.1".to_vec(), + })); + assert!(!_keep_alive(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"Example.com".to_vec()), + (b"Connection".to_vec(), b"close".to_vec()), + ] + .into(), + http_version: b"1.1".to_vec(), + })); + assert!(!_keep_alive(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"Example.com".to_vec()), + (b"Connection".to_vec(), b"a, b, cLOse, foo".to_vec()), + ] + .into(), + http_version: b"1.1".to_vec(), + })); + assert!(!_keep_alive(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![].into(), + http_version: b"1.0".to_vec(), + })); + + assert!(_keep_alive(Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"OK".to_vec(), + })); + assert!(!_keep_alive(Response { + status_code: 200, + headers: vec![(b"Connection".to_vec(), b"close".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"OK".to_vec(), + })); + assert!(!_keep_alive(Response { + status_code: 200, + headers: vec![(b"Connection".to_vec(), b"a, b, cLOse, foo".to_vec()),].into(), + http_version: b"1.1".to_vec(), + reason: b"OK".to_vec(), + })); + assert!(!_keep_alive(Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.0".to_vec(), + reason: b"OK".to_vec(), + })); + } + + #[test] + fn test_body_framing() { + fn headers(cl: Option, te: bool) -> Headers { + let mut headers = vec![]; + if let Some(cl) = cl { + headers.push(( + b"Content-Length".to_vec(), + cl.to_string().as_bytes().to_vec(), + )); + } + if te { + headers.push((b"Transfer-Encoding".to_vec(), b"chunked".to_vec())); + } + headers.push((b"Host".to_vec(), b"example.com".to_vec())); + return headers.into(); + } + + fn resp(status_code: u16, cl: Option, te: bool) -> Response { + Response { + status_code, + headers: headers(cl, te), + http_version: b"1.1".to_vec(), + reason: b"OK".to_vec(), + } + } + + fn req(cl: Option, te: bool) -> Request { + Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: headers(cl, te), + http_version: b"1.1".to_vec(), + } + } + + // Special cases where the headers are ignored: + for (cl, te) in vec![(Some(100), false), (None, true), (Some(100), true)] { + for (meth, r) in vec![ + (b"HEAD".to_vec(), resp(200, cl, te)), + (b"GET".to_vec(), resp(204, cl, te)), + (b"GET".to_vec(), resp(304, cl, te)), + ] { + assert_eq!(_body_framing(&meth, r), ("content-length", 0)); + } + } + + // Transfer-encoding + for (cl, te) in vec![(None, true), (Some(100), true)] { + for (meth, r) in vec![ + (b"".to_vec(), RequestOrResponse::from(req(cl, te))), + (b"GET".to_vec(), RequestOrResponse::from(resp(200, cl, te))), + ] { + assert_eq!(_body_framing(&meth, r), ("chunked", 0)); + } + } + + // Content-Length + for (meth, r) in vec![ + (b"".to_vec(), RequestOrResponse::from(req(Some(100), false))), + ( + b"GET".to_vec(), + RequestOrResponse::from(resp(200, Some(100), false)), + ), + ] { + assert_eq!(_body_framing(&meth, r), ("content-length", 100)); + } + + // No headers + assert_eq!(_body_framing(b"", req(None, false)), ("content-length", 0)); + assert_eq!( + _body_framing(b"GET", resp(200, None, false)), + ("http/1.0", 0) + ); + } +} diff --git a/src/_events.rs b/src/_events.rs new file mode 100644 index 0000000..879d903 --- /dev/null +++ b/src/_events.rs @@ -0,0 +1,141 @@ +use crate::_abnf::{METHOD, REQUEST_TARGET}; +use crate::{_headers::Headers, _util::ProtocolError}; +use lazy_static::lazy_static; +use regex::bytes::Regex; +use std::fmt::{self, Formatter}; + +lazy_static! { + static ref METHOD_RE: Regex = Regex::new(&format!(r"^{}$", *METHOD)).unwrap(); + static ref REQUEST_TARGET_RE: Regex = Regex::new(&format!(r"^{}$", *REQUEST_TARGET)).unwrap(); +} + +#[derive(Clone, PartialEq, Eq, Default)] +pub struct Request { + pub method: Vec, + pub headers: Headers, + pub target: Vec, + pub http_version: Vec, +} + +impl Request { + pub fn new( + method: Vec, + headers: Headers, + target: Vec, + http_version: Vec, + ) -> Result { + let mut host_count = 0; + for (name, _) in headers.iter() { + if name == b"host" { + host_count += 1; + } + } + if http_version == b"1.1" && host_count == 0 { + return Err(ProtocolError::LocalProtocolError( + ("Missing mandatory Host: header".to_string(), 400).into(), + )); + } + if host_count > 1 { + return Err(ProtocolError::LocalProtocolError( + ("Found multiple Host: headers".to_string(), 400).into(), + )); + } + + if !METHOD_RE.is_match(&method) { + return Err(ProtocolError::LocalProtocolError( + ("Illegal method characters".to_string(), 400).into(), + )); + } + if !REQUEST_TARGET_RE.is_match(&target) { + return Err(ProtocolError::LocalProtocolError( + ("Illegal target characters".to_string(), 400).into(), + )); + } + + Ok(Self { + method, + headers, + target, + http_version, + }) + } +} + +impl std::fmt::Debug for Request { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Request") + .field("method", &String::from_utf8_lossy(&self.method)) + .field("headers", &self.headers) + .field("target", &String::from_utf8_lossy(&self.target)) + .field("http_version", &String::from_utf8_lossy(&self.http_version)) + .finish() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Response { + pub headers: Headers, + pub http_version: Vec, + pub reason: Vec, + pub status_code: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Data { + pub data: Vec, + pub chunk_start: bool, + pub chunk_end: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct EndOfMessage { + pub headers: Headers, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ConnectionClosed {} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Event { + Request(Request), + NormalResponse(Response), + InformationalResponse(Response), + Data(Data), + EndOfMessage(EndOfMessage), + ConnectionClosed(ConnectionClosed), + NeedData(), + Paused(), +} + +impl From for Event { + fn from(request: Request) -> Self { + Self::Request(request) + } +} + +impl From for Event { + fn from(response: Response) -> Self { + match response.status_code { + 100..=199 => Self::InformationalResponse(response), + _ => Self::NormalResponse(response), + } + } +} + +impl From for Event { + fn from(data: Data) -> Self { + Self::Data(data) + } +} + +impl From for Event { + fn from(end_of_message: EndOfMessage) -> Self { + Self::EndOfMessage(end_of_message) + } +} + +impl From for Event { + fn from(connection_closed: ConnectionClosed) -> Self { + Self::ConnectionClosed(connection_closed) + } +} diff --git a/src/_headers.rs b/src/_headers.rs new file mode 100644 index 0000000..e850863 --- /dev/null +++ b/src/_headers.rs @@ -0,0 +1,594 @@ +use std::collections::HashSet; + +use crate::{ + _abnf::{FIELD_NAME, FIELD_VALUE}, + _events::Request, + _util::ProtocolError, +}; +use lazy_static::lazy_static; +use regex::bytes::Regex; + +lazy_static! { + static ref CONTENT_LENGTH_RE: Regex = Regex::new(r"^[0-9]+$").unwrap(); + static ref FIELD_NAME_RE: Regex = Regex::new(&format!(r"^{}$", FIELD_NAME)).unwrap(); + static ref FIELD_VALUE_RE: Regex = Regex::new(&format!(r"^{}$", *FIELD_VALUE)).unwrap(); +} + +#[derive(Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] +pub struct Headers(Vec<(Vec, Vec, Vec)>); + +impl std::fmt::Debug for Headers { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut debug_struct = f.debug_struct("Headers"); + self.0.iter().for_each(|(raw_name, _, value)| { + debug_struct.field( + std::str::from_utf8(raw_name).unwrap(), + &std::str::from_utf8(value).unwrap(), + ); + }); + debug_struct.finish() + } +} + +impl Headers { + pub fn iter(&self) -> impl Iterator, Vec)> + '_ { + self.0 + .iter() + .map(|(_, name, value)| ((*name).clone(), (*value).clone())) + } + + pub fn raw_items(&self) -> Vec<&(Vec, Vec, Vec)> { + self.0.iter().collect() + } + + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl From, Vec)>> for Headers { + fn from(value: Vec<(Vec, Vec)>) -> Self { + normalize_and_validate(value, false).unwrap() + } +} + +pub fn normalize_and_validate( + headers: Vec<(Vec, Vec)>, + _parsed: bool, +) -> Result { + let mut new_headers = vec![]; + let mut seen_content_length = None; + let mut saw_transfer_encoding = false; + for (name, value) in headers { + if !_parsed { + if !FIELD_NAME_RE.is_match(&name) { + return Err(ProtocolError::LocalProtocolError( + format!("Illegal header name {:?}", &name).into(), + )); + } + if !FIELD_VALUE_RE.is_match(&value) { + return Err(ProtocolError::LocalProtocolError( + format!("Illegal header value {:?}", &value).into(), + )); + } + } + let raw_name = name.clone(); + let name = name.to_ascii_lowercase(); + if name == b"content-length" { + let lengths: HashSet> = value + .split(|&b| b == b',') + .map(|length| { + std::str::from_utf8(length) + .unwrap() + .trim() + .as_bytes() + .to_vec() + }) + .collect(); + if lengths.len() != 1 { + return Err(ProtocolError::LocalProtocolError( + "conflicting Content-Length headers".into(), + )); + } + let value = lengths.iter().next().unwrap(); + if !CONTENT_LENGTH_RE.is_match(value) { + return Err(ProtocolError::LocalProtocolError( + "bad Content-Length".into(), + )); + } + if seen_content_length.is_none() { + seen_content_length = Some(value.clone()); + new_headers.push((raw_name, name, value.clone())); + } else if seen_content_length != Some(value.clone()) { + return Err(ProtocolError::LocalProtocolError( + "conflicting Content-Length headers".into(), + )); + } + } else if name == b"transfer-encoding" { + // "A server that receives a request message with a transfer coding + // it does not understand SHOULD respond with 501 (Not + // Implemented)." + // https://tools.ietf.org/html/rfc7230#section-3.3.1 + if saw_transfer_encoding { + return Err(ProtocolError::LocalProtocolError( + ("multiple Transfer-Encoding headers", 501).into(), + )); + } + // "All transfer-coding names are case-insensitive" + // -- https://tools.ietf.org/html/rfc7230#section-4 + let value = value.to_ascii_lowercase(); + if value != b"chunked" { + return Err(ProtocolError::LocalProtocolError( + ("Only Transfer-Encoding: chunked is supported", 501).into(), + )); + } + saw_transfer_encoding = true; + new_headers.push((raw_name, name, value)); + } else { + new_headers.push((raw_name, name, value.to_vec())); + } + } + + Ok(Headers(new_headers)) +} + +pub fn get_comma_header(headers: &Headers, name: &[u8]) -> Vec> { + let mut out: Vec> = vec![]; + let name = name.to_ascii_lowercase(); + for (found_name, found_value) in headers.iter() { + if found_name == name { + for found_split_value in found_value.to_ascii_lowercase().split(|&b| b == b',') { + let found_split_value = std::str::from_utf8(found_split_value).unwrap().trim(); + if !found_split_value.is_empty() { + out.push(found_split_value.as_bytes().to_vec()); + } + } + } + } + out +} + +pub fn set_comma_header( + headers: &Headers, + name: &[u8], + new_values: Vec>, +) -> Result { + let mut new_headers = vec![]; + for (found_name, found_value) in headers.iter() { + if found_name != name { + new_headers.push((found_name, found_value)); + } + } + for new_value in new_values { + new_headers.push((name.to_vec(), new_value)); + } + normalize_and_validate(new_headers, false) +} + +pub fn has_expect_100_continue(request: &Request) -> bool { + // https://tools.ietf.org/html/rfc7231#section-5.1.1 + // "A server that receives a 100-continue expectation in an HTTP/1.0 request + // MUST ignore that expectation." + if request.http_version < b"1.1".to_vec() { + return false; + } + let expect = get_comma_header(&request.headers, b"expect"); + expect.contains(&b"100-continue".to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normalize_and_validate() { + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"bar".to_vec())], false).unwrap(), + Headers(vec![(b"foo".to_vec(), b"foo".to_vec(), b"bar".to_vec())]) + ); + + // no leading/trailing whitespace in names + assert_eq!( + normalize_and_validate(vec![(b"foo ".to_vec(), b"bar".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ("Illegal header name [102, 111, 111, 32]".to_string(), 400).into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b" foo".to_vec(), b"bar".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ("Illegal header name [32, 102, 111, 111]".to_string(), 400).into() + ) + ); + + // no weird characters in names + assert_eq!( + normalize_and_validate(vec![(b"foo bar".to_vec(), b"baz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header name [102, 111, 111, 32, 98, 97, 114]".to_string(), + 400 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b"foo\x00bar".to_vec(), b"baz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header name [102, 111, 111, 0, 98, 97, 114]".to_string(), + 400 + ) + .into() + ) + ); + // Not even 8-bit characters: + assert_eq!( + normalize_and_validate(vec![(b"foo\xffbar".to_vec(), b"baz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header name [102, 111, 111, 255, 98, 97, 114]".to_string(), + 400 + ) + .into() + ) + ); + // And not even the control characters we allow in values: + assert_eq!( + normalize_and_validate(vec![(b"foo\x01bar".to_vec(), b"baz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header name [102, 111, 111, 1, 98, 97, 114]".to_string(), + 400 + ) + .into() + ) + ); + + // no return or NUL characters in values + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"bar\rbaz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [98, 97, 114, 13, 98, 97, 122]".to_string(), + 400 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"bar\nbaz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [98, 97, 114, 10, 98, 97, 122]".to_string(), + 400 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"bar\x00baz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [98, 97, 114, 0, 98, 97, 122]".to_string(), + 400 + ) + .into() + ) + ); + // no leading/trailing whitespace + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz ".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [98, 97, 114, 98, 97, 122, 32, 32]".to_string(), + 400 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b" barbaz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [32, 32, 98, 97, 114, 98, 97, 122]".to_string(), + 400 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz\t".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [98, 97, 114, 98, 97, 122, 9]".to_string(), + 400 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate(vec![(b"foo".to_vec(), b"\tbarbaz".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Illegal header value [9, 98, 97, 114, 98, 97, 122]".to_string(), + 400 + ) + .into() + ) + ); + + // content-length + assert_eq!( + normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1".to_vec())], false) + .unwrap(), + Headers(vec![( + b"Content-Length".to_vec(), + b"content-length".to_vec(), + b"1".to_vec() + )]) + ); + assert_eq!( + normalize_and_validate(vec![(b"Content-Length".to_vec(), b"asdf".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into()) + ); + assert_eq!( + normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1x".to_vec())], false) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into()) + ); + assert_eq!( + normalize_and_validate( + vec![ + (b"Content-Length".to_vec(), b"1".to_vec()), + (b"Content-Length".to_vec(), b"2".to_vec()) + ], + false + ) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ("conflicting Content-Length headers".to_string(), 400).into() + ) + ); + assert_eq!( + normalize_and_validate( + vec![ + (b"Content-Length".to_vec(), b"0".to_vec()), + (b"Content-Length".to_vec(), b"0".to_vec()) + ], + false + ) + .unwrap(), + Headers(vec![( + b"Content-Length".to_vec(), + b"content-length".to_vec(), + b"0".to_vec() + )]) + ); + assert_eq!( + normalize_and_validate(vec![(b"Content-Length".to_vec(), b"0 , 0".to_vec())], false) + .unwrap(), + Headers(vec![( + b"Content-Length".to_vec(), + b"content-length".to_vec(), + b"0".to_vec() + )]) + ); + assert_eq!( + normalize_and_validate( + vec![ + (b"Content-Length".to_vec(), b"1".to_vec()), + (b"Content-Length".to_vec(), b"1".to_vec()), + (b"Content-Length".to_vec(), b"2".to_vec()) + ], + false + ) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ("conflicting Content-Length headers".to_string(), 400).into() + ) + ); + assert_eq!( + normalize_and_validate( + vec![(b"Content-Length".to_vec(), b"1 , 1,2".to_vec())], + false + ) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ("conflicting Content-Length headers".to_string(), 400).into() + ) + ); + + // transfer-encoding + assert_eq!( + normalize_and_validate( + vec![(b"Transfer-Encoding".to_vec(), b"chunked".to_vec())], + false + ) + .unwrap(), + Headers(vec![( + b"Transfer-Encoding".to_vec(), + b"transfer-encoding".to_vec(), + b"chunked".to_vec() + )]) + ); + assert_eq!( + normalize_and_validate( + vec![(b"Transfer-Encoding".to_vec(), b"cHuNkEd".to_vec())], + false + ) + .unwrap(), + Headers(vec![( + b"Transfer-Encoding".to_vec(), + b"transfer-encoding".to_vec(), + b"chunked".to_vec() + )]) + ); + assert_eq!( + normalize_and_validate( + vec![(b"Transfer-Encoding".to_vec(), b"gzip".to_vec())], + false + ) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ( + "Only Transfer-Encoding: chunked is supported".to_string(), + 501 + ) + .into() + ) + ); + assert_eq!( + normalize_and_validate( + vec![ + (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()), + (b"Transfer-Encoding".to_vec(), b"gzip".to_vec()) + ], + false + ) + .expect_err("Expect ProtocolError::LocalProtocolError"), + ProtocolError::LocalProtocolError( + ("multiple Transfer-Encoding headers".to_string(), 501).into() + ) + ); + } + + #[test] + fn test_get_set_comma_header() { + let headers = normalize_and_validate( + vec![ + (b"Connection".to_vec(), b"close".to_vec()), + (b"whatever".to_vec(), b"something".to_vec()), + (b"connectiON".to_vec(), b"fOo,, , BAR".to_vec()), + ], + false, + ) + .unwrap(); + + assert_eq!( + get_comma_header(&headers, b"connection"), + vec![b"close".to_vec(), b"foo".to_vec(), b"bar".to_vec()] + ); + + let headers = + set_comma_header(&headers, b"newthing", vec![b"a".to_vec(), b"b".to_vec()]).unwrap(); + + assert_eq!( + headers, + Headers(vec![ + ( + b"connection".to_vec(), + b"connection".to_vec(), + b"close".to_vec() + ), + ( + b"whatever".to_vec(), + b"whatever".to_vec(), + b"something".to_vec() + ), + ( + b"connection".to_vec(), + b"connection".to_vec(), + b"fOo,, , BAR".to_vec() + ), + (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()), + (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()), + ]) + ); + + let headers = + set_comma_header(&headers, b"whatever", vec![b"different thing".to_vec()]).unwrap(); + + assert_eq!( + headers, + Headers(vec![ + ( + b"connection".to_vec(), + b"connection".to_vec(), + b"close".to_vec() + ), + ( + b"connection".to_vec(), + b"connection".to_vec(), + b"fOo,, , BAR".to_vec() + ), + (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()), + (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()), + ( + b"whatever".to_vec(), + b"whatever".to_vec(), + b"different thing".to_vec() + ), + ]) + ); + } + + #[test] + fn test_has_100_continue() { + assert!(has_expect_100_continue(&Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: normalize_and_validate( + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Expect".to_vec(), b"100-continue".to_vec()) + ], + false + ) + .unwrap(), + http_version: b"1.1".to_vec(), + })); + assert!(!has_expect_100_continue(&Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: normalize_and_validate( + vec![(b"Host".to_vec(), b"example.com".to_vec())], + false + ) + .unwrap(), + http_version: b"1.1".to_vec(), + })); + // Case insensitive + assert!(has_expect_100_continue(&Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: normalize_and_validate( + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Expect".to_vec(), b"100-Continue".to_vec()) + ], + false + ) + .unwrap(), + http_version: b"1.1".to_vec(), + })); + // Doesn't work in HTTP/1.0 + assert!(!has_expect_100_continue(&Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: normalize_and_validate( + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Expect".to_vec(), b"100-continue".to_vec()) + ], + false + ) + .unwrap(), + http_version: b"1.0".to_vec(), + })); + } +} diff --git a/src/_readers.rs b/src/_readers.rs new file mode 100644 index 0000000..46fab4b --- /dev/null +++ b/src/_readers.rs @@ -0,0 +1,692 @@ +use crate::{ + _abnf::{CHUNK_HEADER, HEADER_FIELD, REQUEST_LINE, STATUS_LINE}, + _events::{ConnectionClosed, Data, EndOfMessage, Event, Request, Response}, + _headers::normalize_and_validate, + _receivebuffer::ReceiveBuffer, + _util::ProtocolError, +}; +use lazy_static::lazy_static; +use regex::bytes::Regex; + +lazy_static! { + static ref HEADER_FIELD_RE: Regex = Regex::new(&format!(r"^{}$", *HEADER_FIELD)).unwrap(); + static ref OBS_FOLD_RE: Regex = Regex::new(r"^[ \t]+").unwrap(); +} + +fn _obsolete_line_fold(lines: Vec<&[u8]>) -> Result>, ProtocolError> { + let mut out = vec![]; + let mut it = lines.iter(); + let mut last: Option> = None; + while let Some(line) = it.next() { + let match_ = OBS_FOLD_RE.find(line); + if let Some(match_) = match_ { + if last.is_none() { + return Err(ProtocolError::LocalProtocolError( + "continuation line at start of headers".into(), + )); + } + if let Some(last) = last.as_mut() { + last.extend_from_slice(b" "); + last.extend_from_slice(&line[match_.end()..]); + } + } else { + if let Some(last) = last.take() { + out.push(last); + } + last = Some(line.to_vec()); + } + } + if let Some(last) = last.take() { + out.push(last); + } + Ok(out) +} + +fn _decode_header_lines(lines: Vec>) -> Result, Vec)>, ProtocolError> { + let lines = _obsolete_line_fold(lines.iter().map(|line| line.as_slice()).collect())?; + let mut out = vec![]; + for line in lines { + let matches = match HEADER_FIELD_RE.captures(&line) { + Some(matches) => matches, + None => { + return Err(ProtocolError::LocalProtocolError( + format!("illegal header line {:?}", &line).into(), + )) + } + }; + out.push(( + matches["field_name"].to_vec(), + matches["field_value"].to_vec(), + )); + } + Ok(out) +} + +lazy_static! { + static ref REQUEST_LINE_RE: Regex = Regex::new(&format!(r"^{}$", *REQUEST_LINE)).unwrap(); +} + +pub trait Reader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError>; + fn read_eof(&self) -> Result { + Ok(ConnectionClosed::default().into()) + } +} + +#[derive(Clone)] +pub struct IdleClientReader {} + +impl Reader for IdleClientReader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError> { + let lines = buf.maybe_extract_lines(); + if lines.is_none() { + if buf.is_next_line_obviously_invalid_request_line() { + return Err(ProtocolError::LocalProtocolError( + ("illegal request line".to_string(), 400).into(), + )); + } + return Ok(None); + } + let lines = lines.unwrap(); + if lines.is_empty() { + return Err(ProtocolError::LocalProtocolError( + ("no request line received".to_string(), 400).into(), + )); + } + let matches = match REQUEST_LINE_RE.captures(&lines[0]) { + Some(matches) => matches, + None => { + return Err(ProtocolError::LocalProtocolError( + format!("illegal request line {:?}", std::str::from_utf8(&lines[0])).into(), + )) + } + }; + + let headers = normalize_and_validate(_decode_header_lines(lines[1..].to_vec())?, true)?; + + Ok(Some( + Request::new( + matches["method"].to_vec(), + headers, + matches["target"].to_vec(), + matches["http_version"].to_vec(), + )? + .into(), + )) + } +} + +lazy_static! { + static ref STATUS_LINE_RE: Regex = Regex::new(&STATUS_LINE).unwrap(); +} + +#[derive(Clone)] +pub struct SendResponseServerReader {} + +impl Reader for SendResponseServerReader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError> { + let lines = buf.maybe_extract_lines(); + if lines.is_none() { + if buf.is_next_line_obviously_invalid_request_line() { + return Err(ProtocolError::LocalProtocolError( + ("illegal request line".to_string(), 400).into(), + )); + } + return Ok(None); + } + let lines = lines.unwrap(); + if lines.is_empty() { + return Err(ProtocolError::LocalProtocolError( + ("no response line received".to_string(), 400).into(), + )); + } + let matches = match STATUS_LINE_RE.captures(&lines[0]) { + Some(matches) => matches, + None => { + return Err(ProtocolError::LocalProtocolError( + format!("illegal response line {:?}", &lines[0]).into(), + )) + } + }; + let http_version = matches["http_version"].to_vec(); + let reason = matches["reason"].to_vec(); + let status_code: u16 = match std::str::from_utf8(&matches["status_code"].to_vec()) { + Ok(status_code) => match status_code.parse() { + Ok(status_code) => status_code, + Err(_) => { + return Err(ProtocolError::LocalProtocolError( + ("illegal status code".to_string(), 400).into(), + )) + } + }, + Err(_) => { + return Err(ProtocolError::LocalProtocolError( + ("illegal status code".to_string(), 400).into(), + )) + } + }; + let headers = normalize_and_validate(_decode_header_lines(lines[1..].to_vec())?, true)?; + + return Ok(Some(Event::from(Response { + headers, + http_version, + reason, + status_code, + }))); + } +} + +#[derive(Clone)] +pub struct ContentLengthReader { + length: usize, + remaining: usize, +} + +impl ContentLengthReader { + pub fn new(length: usize) -> Self { + Self { + length, + remaining: length, + } + } +} + +impl Reader for ContentLengthReader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError> { + if self.remaining == 0 { + return Ok(Some(EndOfMessage::default().into())); + } + match buf.maybe_extract_at_most(self.remaining) { + Some(data) => { + self.remaining -= data.len(); + Ok(Some( + Data { + data, + chunk_start: false, + chunk_end: false, + } + .into(), + )) + } + None => Ok(None), + } + } + + fn read_eof(&self) -> Result { + Err(ProtocolError::RemoteProtocolError( + ( + format!( + "peer closed connection without sending complete message body \ + (received {} bytes, expected {})", + self.length - self.remaining, + self.length + ), + 400, + ) + .into(), + )) + } +} + +lazy_static! { + static ref CHUNK_HEADER_RE: Regex = Regex::new(&CHUNK_HEADER).unwrap(); +} + +#[derive(Clone)] +pub struct ChunkedReader { + bytes_in_chunk: usize, + bytes_to_discard: usize, + reading_trailer: bool, +} + +impl ChunkedReader { + pub fn new() -> Self { + Self { + bytes_in_chunk: 0, + bytes_to_discard: 0, + reading_trailer: false, + } + } +} + +impl Reader for ChunkedReader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError> { + if self.reading_trailer { + match buf.maybe_extract_lines() { + Some(lines) => { + return Ok(Some( + EndOfMessage { + headers: normalize_and_validate(_decode_header_lines(lines)?, true)?, + } + .into(), + )) + } + None => return Ok(None), + } + } + if self.bytes_to_discard > 0 { + let data = buf.maybe_extract_at_most(self.bytes_to_discard); + if data.is_none() { + return Ok(None); + } + self.bytes_to_discard -= data.unwrap().len(); + if self.bytes_to_discard > 0 { + return Ok(None); + } + } + assert_eq!(self.bytes_to_discard, 0); + let chunk_start: bool; + if self.bytes_in_chunk == 0 { + if let Some(chunk_header) = buf.maybe_extract_next_line() { + let matches = match CHUNK_HEADER_RE.captures(&chunk_header) { + Some(matches) => matches, + None => { + return Err(ProtocolError::LocalProtocolError( + format!("illegal chunk header: {:?}", &chunk_header).into(), + )) + } + }; + self.bytes_in_chunk = match usize::from_str_radix( + std::str::from_utf8(&matches["chunk_size"].to_vec()).unwrap(), + 16, + ) { + Ok(bytes_in_chunk) => bytes_in_chunk, + Err(_) => { + return Err(ProtocolError::LocalProtocolError( + format!("illegal chunk size: {:?}", &matches["chunk_size"]).into(), + )) + } + }; + if self.bytes_in_chunk == 0 { + self.reading_trailer = true; + return self.call(buf); + } + } else { + return Ok(None); + } + chunk_start = true; + } else { + chunk_start = false; + } + assert!(self.bytes_in_chunk > 0); + + if let Some(data) = buf.maybe_extract_at_most(self.bytes_in_chunk) { + self.bytes_in_chunk -= data.len(); + if self.bytes_in_chunk == 0 { + self.bytes_to_discard = 2; + } + let chunk_end = self.bytes_in_chunk == 0; + Ok(Some( + Data { + data, + chunk_start, + chunk_end, + } + .into(), + )) + } else { + Ok(None) + } + } + + fn read_eof(&self) -> Result { + Err(ProtocolError::RemoteProtocolError( + ( + "peer closed connection without sending complete message body \ + (incomplete chunked read)" + .to_string(), + 400, + ) + .into(), + )) + } +} + +#[derive(Clone)] +pub struct Http10Reader {} + +impl Reader for Http10Reader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError> { + let data = buf.maybe_extract_at_most(999999999); + match data { + Some(data) => Ok(Some( + Data { + data, + chunk_start: false, + chunk_end: false, + } + .into(), + )), + None => Ok(None), + } + } + + fn read_eof(&self) -> Result { + Ok(EndOfMessage::default().into()) + } +} + +#[derive(Clone)] +pub struct ClosedReader {} + +impl Reader for ClosedReader { + fn call(&mut self, buf: &mut ReceiveBuffer) -> Result, ProtocolError> { + if buf.len() > 0 { + return Err(ProtocolError::LocalProtocolError( + ("unexpected data".to_string(), 400).into(), + )); + } + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_obsolete_line_fold_bytes() { + assert_eq!( + _obsolete_line_fold(vec![ + b"aaa".as_ref(), + b"bbb".as_ref(), + b" ccc".as_ref(), + b"ddd".as_ref() + ]) + .unwrap(), + vec![b"aaa".to_vec(), b"bbb ccc".to_vec(), b"ddd".to_vec()] + ); + } + + fn normalize_data_events(in_events: Vec) -> Vec { + let mut out_events = Vec::new(); + for in_event in in_events { + let event = match in_event { + Event::Data(data) => Event::Data(Data { + data: data.data.clone(), + chunk_start: false, + chunk_end: false, + }), + _ => in_event.clone(), + }; + if !out_events.is_empty() { + if let Event::Data(data_event) = &event { + if let Event::Data(last_data_event) = out_events.last().unwrap() { + let mut x = last_data_event.clone(); + x.data.extend_from_slice(&data_event.data); + let l = out_events.len(); + out_events[l - 1] = Event::Data(x); + continue; + } + } + } + out_events.push(event); + } + return out_events; + } + + fn _run_reader( + reader: &mut impl Reader, + buf: &mut ReceiveBuffer, + do_eof: bool, + ) -> Result, ProtocolError> { + let mut events = vec![]; + { + loop { + match reader.call(buf)? { + Some(event) => { + events.push(event.clone()); + if let Event::EndOfMessage(_) = event { + break; + } + } + None => break, + } + } + if do_eof { + assert!(buf.len() == 0); + events.push(reader.read_eof().unwrap()); + } + } + return Ok(normalize_data_events(events)); + } + + fn t_body_reader( + reader: impl Reader + Clone, + data: &[u8], + expected: Vec, + do_eof: bool, + ) -> Result<(), ProtocolError> { + assert_eq!( + _run_reader( + &mut reader.clone(), + &mut ReceiveBuffer::from(data.to_vec()), + do_eof + )?, + expected + ); + + let mut buf = ReceiveBuffer::new(); + let mut events = vec![]; + let mut r1 = reader.clone(); + for i in 0..data.len() { + events.append(&mut _run_reader(&mut r1, &mut buf, false)?); + buf.add(&mut data[i..i + 1].to_vec()); + } + events.append(&mut _run_reader(&mut r1, &mut buf, do_eof)?); + assert_eq!(normalize_data_events(events.clone()), expected); + + if expected.iter().any(|event| match event { + Event::EndOfMessage(_) => true, + _ => false, + }) && !do_eof + { + assert_eq!( + _run_reader( + &mut reader.clone(), + &mut ReceiveBuffer::from(data.to_vec()), + false + )?, + expected.clone() + ); + } + Ok(()) + } + + #[test] + fn test_content_length_reader() { + t_body_reader( + ContentLengthReader::new(0), + b"", + vec![EndOfMessage::default().into()], + false, + ) + .unwrap(); + + t_body_reader( + ContentLengthReader::new(10), + b"0123456789", + vec![ + Data { + data: b"0123456789".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + false, + ) + .unwrap(); + } + + #[test] + fn test_http10_reader() { + t_body_reader( + Http10Reader {}, + b"", + vec![EndOfMessage::default().into()], + true, + ) + .unwrap(); + + t_body_reader( + Http10Reader {}, + b"asdf", + vec![Data { + data: b"asdf".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + false, + ) + .unwrap(); + + t_body_reader( + Http10Reader {}, + b"asdf", + vec![ + Data { + data: b"asdf".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + true, + ) + .unwrap(); + } + + #[test] + fn test_chunked_reader() { + t_body_reader( + ChunkedReader::new(), + b"0\r\n\r\n", + vec![EndOfMessage::default().into()], + false, + ) + .unwrap(); + + t_body_reader( + ChunkedReader::new(), + b"0\r\nSome: header\r\n\r\n", + vec![EndOfMessage { + headers: vec![(b"Some".to_vec(), b"header".to_vec())].into(), + } + .into()], + false, + ) + .unwrap(); + + t_body_reader( + ChunkedReader::new(), + b"5\r\n01234\r\n10\r\n0123456789abcdef\r\n0\r\nSome: header\r\n\r\n", + vec![ + Data { + data: b"012340123456789abcdef".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage { + headers: vec![(b"Some".to_vec(), b"header".to_vec())].into(), + } + .into(), + ], + false, + ) + .unwrap(); + + t_body_reader( + ChunkedReader::new(), + b"5\r\n01234\r\n10\r\n0123456789abcdef\r\n0\r\n\r\n", + vec![ + Data { + data: b"012340123456789abcdef".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + false, + ) + .unwrap(); + + // handles upper and lowercase hex + t_body_reader( + ChunkedReader::new(), + &[ + b"aA\r\n".to_vec(), + vec![120; 0xAA], + b"\r\n".to_vec(), + b"0\r\n\r\n".to_vec(), + ] + .concat(), + vec![ + Data { + data: vec![120; 0xAA], + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + false, + ) + .unwrap(); + + // refuses arbitrarily long chunk integers + assert!(t_body_reader( + ChunkedReader::new(), + &[vec![57; 100], b"\r\nxxx".to_vec()].concat(), + vec![Data { + data: b"xxx".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + false, + ) + .is_err()); + + // refuses garbage in the chunk count + assert!(t_body_reader(ChunkedReader::new(), b"10\x00\r\nxxx", vec![], false,).is_err()); + + // handles (and discards) "chunk extensions" omg wtf + t_body_reader( + ChunkedReader::new(), + b"5; hello=there\r\nxxxxx\r\n0; random=\"junk\"; some=more; canbe=lonnnnngg\r\n\r\n", + vec![ + Data { + data: b"xxxxx".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + false, + ) + .unwrap(); + + t_body_reader( + ChunkedReader::new(), + b"5 \r\n01234\r\n0\r\n\r\n", + vec![ + Data { + data: b"01234".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + false, + ) + .unwrap(); + } +} diff --git a/src/_receivebuffer.rs b/src/_receivebuffer.rs new file mode 100644 index 0000000..a6213e7 --- /dev/null +++ b/src/_receivebuffer.rs @@ -0,0 +1,230 @@ +use lazy_static::lazy_static; +use regex::bytes::Regex; +use std::cmp::min; + +lazy_static! { + static ref BLANK_LINE_REGEX: Regex = Regex::new(r"\n\r?\n").unwrap(); +} + +pub struct ReceiveBuffer { + data: Vec, + next_line_search: usize, + multiple_lines_search: usize, +} + +impl ReceiveBuffer { + pub fn new() -> Self { + Self { + data: vec![], + next_line_search: 0, + multiple_lines_search: 0, + } + } + + pub fn add(&mut self, byteslike: &[u8]) { + self.data.extend(byteslike); + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn bytes(&self) -> &[u8] { + &self.data + } + + fn extract(&mut self, count: usize) -> Vec { + let out = self.data.drain(..min(count, self.data.len())).collect(); + self.next_line_search = 0; + self.multiple_lines_search = 0; + out + } + + pub fn maybe_extract_at_most(&mut self, count: usize) -> Option> { + if count == 0 || self.data.is_empty() { + None + } else { + Some(self.extract(count)) + } + } + + pub fn maybe_extract_next_line(&mut self) -> Option> { + let search_start_index = self.next_line_search.saturating_sub(1); + let needle = b"\r\n"; + let partial_idx = self.data[search_start_index..] + .windows(needle.len()) + .position(|window| window == needle); + match partial_idx { + Some(idx) => Some(self.extract(search_start_index + idx + needle.len())), + None => { + self.next_line_search = self.data.len(); + None + } + } + } + + pub fn maybe_extract_lines(&mut self) -> Option>> { + let lf: &[u8] = b"\n"; + if &self.data[..min(1, self.data.len())] == lf { + self.extract(1); + return Some(vec![]); + } + let crlf: &[u8] = b"\r\n"; + if &self.data[..min(2, self.data.len())] == crlf { + self.extract(2); + return Some(vec![]); + } + let match_ = BLANK_LINE_REGEX.find(&self.data[self.multiple_lines_search..]); + if match_.is_none() { + self.multiple_lines_search = self.data.len().saturating_sub(2); + None + } else { + let idx = match_.unwrap().end(); + let out = self.extract(self.multiple_lines_search + idx); + let mut lines = out + .split(|&b| b == b'\n') + .map(|line| { + let mut line = line.to_vec(); + if line.ends_with(&[b'\r']) { + line.pop(); + } + line + }) + .collect::>(); + assert_eq!(lines[lines.len() - 2], b"", "lines: {:?}", lines); + assert_eq!(lines[lines.len() - 1], b"", "lines: {:?}", lines); + lines.pop(); + lines.pop(); + Some(lines) + } + } + + pub fn is_next_line_obviously_invalid_request_line(&self) -> bool { + if self.data.is_empty() { + return false; + } + self.data[0] < 0x21 + } +} + +impl From> for ReceiveBuffer { + fn from(data: Vec) -> Self { + Self { + data, + next_line_search: 0, + multiple_lines_search: 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_receivebuffer() { + let mut b = ReceiveBuffer::new(); + assert_eq!(b.len(), 0); + assert_eq!(b.extract(0), b""); + assert_eq!(b.maybe_extract_at_most(10), None); + assert_eq!(b.maybe_extract_next_line(), None); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + b.add(b"123"); + assert_eq!(b.len(), 3); + assert_eq!(b.extract(2), b"12"); + assert_eq!(b.len(), 1); + assert_eq!(b.extract(1), b"3"); + assert_eq!(b.len(), 0); + assert_eq!(b.maybe_extract_at_most(10), None); + assert_eq!(b.maybe_extract_next_line(), None); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + } + + #[test] + fn test_receivebuffer_maybe_extract_until_next() { + let mut b = ReceiveBuffer::new(); + b.add(b"123\n456\r\n789\r\n"); + assert_eq!(b.maybe_extract_next_line(), Some(b"123\n456\r\n".to_vec())); + assert_eq!(b.maybe_extract_next_line(), Some(b"789\r\n".to_vec())); + assert_eq!(b.maybe_extract_next_line(), None); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + b.add(b"12\r"); + assert_eq!(b.maybe_extract_next_line(), None); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + b.add(b"345\n\r"); + assert_eq!(b.maybe_extract_next_line(), None); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + b.add(b"\n6789aaa123\r\n"); + assert_eq!(b.maybe_extract_next_line(), Some(b"12\r345\n\r\n".to_vec())); + assert_eq!( + b.maybe_extract_next_line(), + Some(b"6789aaa123\r\n".to_vec()) + ); + assert_eq!(b.maybe_extract_next_line(), None); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + } + + #[test] + fn test_receivebuffer_maybe_extract_lines() { + let mut b = ReceiveBuffer::new(); + b.add(b"123\r\na: b\r\nfoo:bar\r\n\r\ntrailing"); + let lines = b.maybe_extract_lines(); + assert_eq!( + lines, + Some(vec![b"123".to_vec(), b"a: b".to_vec(), b"foo:bar".to_vec()]) + ); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + b.add(b"\r\n\r"); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + assert_eq!( + b.maybe_extract_at_most(100), + Some(b"trailing\r\n\r".to_vec()) + ); + assert_eq!(b.maybe_extract_at_most(100), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + + // Empty body case (as happens at the end of chunked encoding if there are + // no trailing headers, e.g.) + b.add(b"\r\ntrailing"); + assert_eq!(b.maybe_extract_lines(), Some(vec![])); + assert_eq!(b.maybe_extract_lines(), None); + assert!(!b.is_next_line_obviously_invalid_request_line()); + } + + #[test] + fn test_receivebuffer_for_invalid_delimiter() { + let mut b = ReceiveBuffer::new(); + + b.add(b"HTTP/1.1 200 OK\r\n"); + b.add(b"Content-type: text/plain\r\n"); + b.add(b"Connection: close\r\n"); + b.add(b"\r\n"); + b.add(b"Some body"); + + let lines = b.maybe_extract_lines(); + + assert_eq!( + lines, + Some(vec![ + b"HTTP/1.1 200 OK".to_vec(), + b"Content-type: text/plain".to_vec(), + b"Connection: close".to_vec(), + ]) + ); + assert_eq!(b.data, b"Some body"); + } +} diff --git a/src/_state.rs b/src/_state.rs new file mode 100644 index 0000000..5694dca --- /dev/null +++ b/src/_state.rs @@ -0,0 +1,701 @@ +use crate::{Event, ProtocolError}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub enum Role { + Client, + Server, +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub enum State { + Idle, + SendResponse, + SendBody, + Done, + MustClose, + Closed, + Error, + MightSwitchProtocol, + SwitchedProtocol, +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub enum Switch { + SwitchUpgrade, + SwitchConnect, + Client, +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub enum EventType { + Request, + InformationalResponse, + NormalResponse, + Data, + EndOfMessage, + ConnectionClosed, + NeedData, + Paused, + // Combination of EventType and Sentinel + RequestClient, // (Request, Switch::Client) + InformationalResponseSwitchUpgrade, // (InformationalResponse, Switch::SwitchUpgrade) + NormalResponseSwitchConnect, // (NormalResponse, Switch::SwitchConnect) +} + +impl From<&Event> for EventType { + fn from(value: &Event) -> Self { + match value { + Event::Request(_) => EventType::Request, + Event::NormalResponse(_) => EventType::NormalResponse, + Event::InformationalResponse(_) => EventType::InformationalResponse, + Event::Data(_) => EventType::Data, + Event::EndOfMessage(_) => EventType::EndOfMessage, + Event::ConnectionClosed(_) => EventType::ConnectionClosed, + Event::NeedData() => EventType::NeedData, + Event::Paused() => EventType::Paused, + } + } +} + +pub struct ConnectionState { + pub keep_alive: bool, + pub pending_switch_proposals: HashSet, + pub states: HashMap, +} + +impl ConnectionState { + pub fn new() -> Self { + ConnectionState { + keep_alive: true, + pending_switch_proposals: HashSet::new(), + states: HashMap::from([(Role::Client, State::Idle), (Role::Server, State::Idle)]), + } + } + + pub fn process_error(&mut self, role: Role) { + self.states.insert(role, State::Error); + self._fire_state_triggered_transitions(); + } + + pub fn process_keep_alive_disabled(&mut self) { + self.keep_alive = false; + self._fire_state_triggered_transitions(); + } + + pub fn process_client_switch_proposal(&mut self, switch_event: Switch) { + self.pending_switch_proposals.insert(switch_event); + self._fire_state_triggered_transitions(); + } + + pub fn process_event( + &mut self, + role: Role, + event_type: EventType, + server_switch_event: Option, + ) -> Result<(), ProtocolError> { + let mut _event_type = event_type; + if let Some(server_switch_event) = server_switch_event { + assert_eq!(role, Role::Server); + if !self.pending_switch_proposals.contains(&server_switch_event) { + return Err(ProtocolError::LocalProtocolError( + format!( + "Received server {:?} event without a pending proposal", + server_switch_event + ) + .into(), + )); + } + _event_type = match (event_type, server_switch_event) { + (EventType::Request, Switch::Client) => EventType::RequestClient, + (EventType::NormalResponse, Switch::SwitchConnect) => { + EventType::NormalResponseSwitchConnect + } + (EventType::InformationalResponse, Switch::SwitchUpgrade) => { + EventType::InformationalResponseSwitchUpgrade + } + _ => panic!( + "Can't handle event type {:?} when role={:?} and state={:?}", + _event_type, role, self.states[&role] + ), + }; + } + if server_switch_event.is_none() && _event_type == EventType::NormalResponse { + self.pending_switch_proposals.clear(); + } + self._fire_event_triggered_transitions(role, _event_type)?; + if _event_type == EventType::Request { + assert_eq!(role, Role::Client); + self._fire_event_triggered_transitions(Role::Server, EventType::RequestClient)? + } + self._fire_state_triggered_transitions(); + Ok(()) + } + + fn _fire_event_triggered_transitions( + &mut self, + role: Role, + event_type: EventType, + ) -> Result<(), ProtocolError> { + let state = self.states[&role]; + let new_state = match (role, state, event_type) { + (Role::Client, State::Idle, EventType::Request) => State::SendBody, + (Role::Client, State::Idle, EventType::ConnectionClosed) => State::Closed, + (Role::Client, State::SendBody, EventType::Data) => State::SendBody, + (Role::Client, State::SendBody, EventType::EndOfMessage) => State::Done, + (Role::Client, State::Done, EventType::ConnectionClosed) => State::Closed, + (Role::Client, State::MustClose, EventType::ConnectionClosed) => State::Closed, + (Role::Client, State::Closed, EventType::ConnectionClosed) => State::Closed, + + (Role::Server, State::Idle, EventType::ConnectionClosed) => State::Closed, + (Role::Server, State::Idle, EventType::NormalResponse) => State::SendBody, + (Role::Server, State::Idle, EventType::RequestClient) => State::SendResponse, + (Role::Server, State::SendResponse, EventType::InformationalResponse) => { + State::SendResponse + } + (Role::Server, State::SendResponse, EventType::NormalResponse) => State::SendBody, + (Role::Server, State::SendResponse, EventType::InformationalResponseSwitchUpgrade) => { + State::SwitchedProtocol + } + (Role::Server, State::SendResponse, EventType::NormalResponseSwitchConnect) => { + State::SwitchedProtocol + } + (Role::Server, State::SendBody, EventType::Data) => State::SendBody, + (Role::Server, State::SendBody, EventType::EndOfMessage) => State::Done, + (Role::Server, State::Done, EventType::ConnectionClosed) => State::Closed, + (Role::Server, State::MustClose, EventType::ConnectionClosed) => State::Closed, + (Role::Server, State::Closed, EventType::ConnectionClosed) => State::Closed, + _ => { + return Err(ProtocolError::LocalProtocolError( + format!( + "Can't handle event type {:?} when role={:?} and state={:?}", + event_type, role, state + ) + .into(), + )) + } + }; + self.states.insert(role, new_state); + Ok(()) + } + + fn _fire_state_triggered_transitions(&mut self) { + loop { + let start_states = self.states.clone(); + + if self.pending_switch_proposals.len() > 0 { + if self.states[&Role::Client] == State::Done { + self.states.insert(Role::Client, State::MightSwitchProtocol); + } + } + + if self.pending_switch_proposals.is_empty() { + if self.states[&Role::Client] == State::MightSwitchProtocol { + self.states.insert(Role::Client, State::Done); + } + } + + if !self.keep_alive { + for role in &[Role::Client, Role::Server] { + if self.states[role] == State::Done { + self.states.insert(*role, State::MustClose); + } + } + } + + let joint_state = (self.states[&Role::Client], self.states[&Role::Server]); + let changes = match joint_state { + (State::MightSwitchProtocol, State::SwitchedProtocol) => { + vec![(Role::Client, State::SwitchedProtocol)] + } + (State::Closed, State::Done) => { + vec![(Role::Server, State::MustClose)] + } + (State::Closed, State::Idle) => { + vec![(Role::Server, State::MustClose)] + } + (State::Error, State::Done) => vec![(Role::Server, State::MustClose)], + (State::Done, State::Closed) => { + vec![(Role::Client, State::MustClose)] + } + (State::Idle, State::Closed) => { + vec![(Role::Client, State::MustClose)] + } + (State::Done, State::Error) => vec![(Role::Client, State::MustClose)], + _ => vec![], + }; + for (role, new_state) in changes { + self.states.insert(role, new_state); + } + + if self.states == start_states { + return; + } + } + } + + pub fn start_next_cycle(&mut self) -> Result<(), ProtocolError> { + if self.states != HashMap::from([(Role::Client, State::Done), (Role::Server, State::Done)]) + { + return Err(ProtocolError::LocalProtocolError( + format!("Not in a reusable state. self.states={:?}", self.states).into(), + )); + } + assert!(self.keep_alive); + assert!(self.pending_switch_proposals.is_empty()); + self.states.clear(); + self.states.insert(Role::Client, State::Idle); + self.states.insert(Role::Server, State::Idle); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connection_state() { + let mut cs = ConnectionState::new(); + + // Basic event-triggered transitions + + assert_eq!( + cs.states, + HashMap::from([(Role::Client, State::Idle), (Role::Server, State::Idle)]) + ); + + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + // The SERVER-Request special case: + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendResponse) + ]) + ); + + // Illegal transitions raise an error and nothing happens + cs.process_event(Role::Client, EventType::Request, None) + .expect_err("Expected LocalProtocolError"); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Server, EventType::InformationalResponse, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendBody) + ]) + ); + + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + cs.process_event(Role::Server, EventType::EndOfMessage, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([(Role::Client, State::Done), (Role::Server, State::Done)]) + ); + + // State-triggered transition + + cs.process_event(Role::Server, EventType::ConnectionClosed, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MustClose), + (Role::Server, State::Closed) + ]) + ); + } + + #[test] + fn test_connection_state_keep_alive() { + // keep_alive = False + let mut cs = ConnectionState::new(); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_keep_alive_disabled(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MustClose), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + cs.process_event(Role::Server, EventType::EndOfMessage, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MustClose), + (Role::Server, State::MustClose) + ]) + ); + } + + #[test] + fn test_connection_state_keep_alive_in_done() { + // Check that if keep_alive is disabled when the CLIENT is already in DONE, + // then this is sufficient to immediately trigger the DONE -> MUST_CLOSE + // transition + let mut cs = ConnectionState::new(); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + assert_eq!(cs.states[&Role::Client], State::Done); + cs.process_keep_alive_disabled(); + assert_eq!(cs.states[&Role::Client], State::MustClose); + } + + #[test] + fn test_connection_state_switch_denied() { + for switch_type in [Switch::SwitchConnect, Switch::SwitchUpgrade] { + for deny_early in [true, false] { + let mut cs = ConnectionState::new(); + cs.process_client_switch_proposal(switch_type); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::Data, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendResponse) + ]) + ); + + assert!(cs.pending_switch_proposals.contains(&switch_type)); + + if deny_early { + // before client reaches DONE + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + assert!(cs.pending_switch_proposals.is_empty()); + } + + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + + if deny_early { + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::Done), + (Role::Server, State::SendBody) + ]) + ); + } else { + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MightSwitchProtocol), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Server, EventType::InformationalResponse, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MightSwitchProtocol), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::Done), + (Role::Server, State::SendBody) + ]) + ); + assert!(cs.pending_switch_proposals.is_empty()); + } + } + } + } + + #[test] + fn test_connection_state_protocol_switch_accepted() { + for switch_event in [Switch::SwitchUpgrade, Switch::SwitchConnect] { + let mut cs = ConnectionState::new(); + cs.process_client_switch_proposal(switch_event); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::Data, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MightSwitchProtocol), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event(Role::Server, EventType::InformationalResponse, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MightSwitchProtocol), + (Role::Server, State::SendResponse) + ]) + ); + + cs.process_event( + Role::Server, + match switch_event { + Switch::SwitchUpgrade => EventType::InformationalResponse, + Switch::SwitchConnect => EventType::NormalResponse, + _ => panic!(), + }, + Some(switch_event), + ) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SwitchedProtocol), + (Role::Server, State::SwitchedProtocol) + ]) + ); + } + } + + #[test] + fn test_connection_state_double_protocol_switch() { + // CONNECT + Upgrade is legal! Very silly, but legal. So we support + // it. Because sometimes doing the silly thing is easier than not. + for server_switch in [ + None, + Some(Switch::SwitchUpgrade), + Some(Switch::SwitchConnect), + ] { + let mut cs = ConnectionState::new(); + cs.process_client_switch_proposal(Switch::SwitchUpgrade); + cs.process_client_switch_proposal(Switch::SwitchConnect); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::MightSwitchProtocol), + (Role::Server, State::SendResponse) + ]) + ); + cs.process_event( + Role::Server, + match server_switch { + Some(Switch::SwitchUpgrade) => EventType::InformationalResponse, + Some(Switch::SwitchConnect) => EventType::NormalResponse, + None => EventType::NormalResponse, + _ => panic!(), + }, + server_switch, + ) + .unwrap(); + if server_switch.is_none() { + assert_eq!( + cs.states, + HashMap::from([(Role::Client, State::Done), (Role::Server, State::SendBody)]) + ); + } else { + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SwitchedProtocol), + (Role::Server, State::SwitchedProtocol) + ]) + ); + } + } + } + + #[test] + fn test_connection_state_inconsistent_protocol_switch() { + for (client_switches, server_switch) in [ + (vec![], Switch::SwitchUpgrade), + (vec![], Switch::SwitchConnect), + (vec![Switch::SwitchUpgrade], Switch::SwitchConnect), + (vec![Switch::SwitchConnect], Switch::SwitchUpgrade), + ] { + let mut cs = ConnectionState::new(); + for client_switch in client_switches.clone() { + cs.process_client_switch_proposal(client_switch); + } + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Server, EventType::NormalResponse, Some(server_switch)) + .expect_err("Expected LocalProtocolError"); + } + } + + #[test] + fn test_connection_state_keepalive_protocol_switch_interaction() { + // keep_alive=False + pending_switch_proposals + let mut cs = ConnectionState::new(); + cs.process_client_switch_proposal(Switch::SwitchUpgrade); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_keep_alive_disabled(); + cs.process_event(Role::Client, EventType::Data, None) + .unwrap(); + assert_eq!( + cs.states, + HashMap::from([ + (Role::Client, State::SendBody), + (Role::Server, State::SendResponse) + ]) + ); + } + + #[test] + fn test_connection_state_reuse() { + let mut cs = ConnectionState::new(); + + cs.start_next_cycle() + .expect_err("Expected LocalProtocolError"); + + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + + cs.start_next_cycle() + .expect_err("Expected LocalProtocolError"); + + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + cs.process_event(Role::Server, EventType::EndOfMessage, None) + .unwrap(); + + cs.start_next_cycle().unwrap(); + assert_eq!( + cs.states, + HashMap::from([(Role::Client, State::Idle), (Role::Server, State::Idle)]) + ); + + // No keepalive + + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_keep_alive_disabled(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + cs.process_event(Role::Server, EventType::EndOfMessage, None) + .unwrap(); + + cs.start_next_cycle() + .expect_err("Expected LocalProtocolError"); + + // One side closed + + cs = ConnectionState::new(); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + cs.process_event(Role::Client, EventType::ConnectionClosed, None) + .unwrap(); + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + cs.process_event(Role::Server, EventType::EndOfMessage, None) + .unwrap(); + + cs.start_next_cycle() + .expect_err("Expected LocalProtocolError"); + + // Succesful protocol switch + + cs = ConnectionState::new(); + cs.process_client_switch_proposal(Switch::SwitchUpgrade); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + cs.process_event( + Role::Server, + EventType::InformationalResponse, + Some(Switch::SwitchUpgrade), + ) + .unwrap(); + + cs.start_next_cycle() + .expect_err("Expected LocalProtocolError"); + + // Failed protocol switch + + cs = ConnectionState::new(); + cs.process_client_switch_proposal(Switch::SwitchUpgrade); + cs.process_event(Role::Client, EventType::Request, None) + .unwrap(); + cs.process_event(Role::Client, EventType::EndOfMessage, None) + .unwrap(); + cs.process_event(Role::Server, EventType::NormalResponse, None) + .unwrap(); + cs.process_event(Role::Server, EventType::EndOfMessage, None) + .unwrap(); + + cs.start_next_cycle().unwrap(); + assert_eq!( + cs.states, + HashMap::from([(Role::Client, State::Idle), (Role::Server, State::Idle)]) + ); + } + + #[test] + fn test_server_request_is_illegal() { + // There used to be a bug in how we handled the Request special case that + // made this allowed... + let mut cs = ConnectionState::new(); + cs.process_event(Role::Server, EventType::Request, None) + .expect_err("Expected LocalProtocolError"); + } +} diff --git a/src/_util.rs b/src/_util.rs new file mode 100644 index 0000000..77438b1 --- /dev/null +++ b/src/_util.rs @@ -0,0 +1,110 @@ +#[derive(Debug, PartialEq, Eq)] +pub struct LocalProtocolError { + pub message: String, + pub code: u16, +} + +impl From<(String, u16)> for LocalProtocolError { + fn from(value: (String, u16)) -> Self { + LocalProtocolError { + message: value.0, + code: value.1, + } + } +} + +impl From<(&str, u16)> for LocalProtocolError { + fn from(value: (&str, u16)) -> Self { + LocalProtocolError { + message: value.0.to_string(), + code: value.1, + } + } +} + +impl From for LocalProtocolError { + fn from(value: String) -> Self { + LocalProtocolError { + message: value, + code: 400, + } + } +} + +impl From<&str> for LocalProtocolError { + fn from(value: &str) -> Self { + LocalProtocolError { + message: value.to_string(), + code: 400, + } + } +} + +impl LocalProtocolError { + pub(crate) fn _reraise_as_remote_protocol_error(self) -> RemoteProtocolError { + RemoteProtocolError { + message: self.message, + code: self.code, + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct RemoteProtocolError { + pub message: String, + pub code: u16, +} + +impl From<(String, u16)> for RemoteProtocolError { + fn from(value: (String, u16)) -> Self { + RemoteProtocolError { + message: value.0, + code: value.1, + } + } +} + +impl From<(&str, u16)> for RemoteProtocolError { + fn from(value: (&str, u16)) -> Self { + RemoteProtocolError { + message: value.0.to_string(), + code: value.1, + } + } +} + +impl From for RemoteProtocolError { + fn from(value: String) -> Self { + RemoteProtocolError { + message: value, + code: 400, + } + } +} + +impl From<&str> for RemoteProtocolError { + fn from(value: &str) -> Self { + RemoteProtocolError { + message: value.to_string(), + code: 400, + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum ProtocolError { + LocalProtocolError(LocalProtocolError), + RemoteProtocolError(RemoteProtocolError), +} + +impl From for ProtocolError { + fn from(value: LocalProtocolError) -> Self { + ProtocolError::LocalProtocolError(value) + } +} + +impl From for ProtocolError { + fn from(value: RemoteProtocolError) -> Self { + ProtocolError::RemoteProtocolError(value) + } +} diff --git a/src/_writers.rs b/src/_writers.rs new file mode 100644 index 0000000..be28650 --- /dev/null +++ b/src/_writers.rs @@ -0,0 +1,351 @@ +use crate::{ + Event, + _events::{Request, Response}, + _headers::Headers, + _util::ProtocolError, +}; + +pub type WriterFnMut = dyn FnMut(Event) -> Result, ProtocolError>; + +fn _write_headers(headers: &Headers) -> Result, ProtocolError> { + let mut data_list = Vec::new(); + for (raw_name, name, value) in headers.raw_items() { + if name == b"host" { + data_list.append(&mut raw_name.clone()); + data_list.append(&mut b": ".to_vec()); + data_list.append(&mut value.clone()); + data_list.append(&mut b"\r\n".to_vec()); + } + } + for (raw_name, name, value) in headers.raw_items() { + if name != b"host" { + data_list.append(&mut raw_name.clone()); + data_list.append(&mut b": ".to_vec()); + data_list.append(&mut value.clone()); + data_list.append(&mut b"\r\n".to_vec()); + } + } + data_list.append(&mut b"\r\n".to_vec()); + Ok(data_list) +} + +fn _write_request(request: &Request) -> Result, ProtocolError> { + let mut data_list = Vec::new(); + if request.http_version != b"1.1" { + return Err(ProtocolError::LocalProtocolError( + "I only send HTTP/1.1".into(), + )); + } + data_list.append(&mut request.method.clone()); + data_list.append(&mut b" ".to_vec()); + data_list.append(&mut request.target.clone()); + data_list.append(&mut b" HTTP/1.1\r\n".to_vec()); + data_list.append(&mut (_write_headers(&request.headers)?)); + Ok(data_list) +} + +pub fn write_request(event: Event) -> Result, ProtocolError> { + match event { + Event::Request(request) => _write_request(&request), + _ => panic!("Expected Request event, got {:?}", event), + } +} + +fn _write_response(response: &Response) -> Result, ProtocolError> { + if response.http_version != b"1.1" { + return Err(ProtocolError::LocalProtocolError( + "I only send HTTP/1.1".into(), + )); + } + let status_code = response.status_code.to_string(); + let status_bytes = status_code.as_bytes(); + let mut data_list = Vec::new(); + data_list.append(&mut b"HTTP/1.1 ".to_vec()); + data_list.append(&mut status_bytes.to_vec()); + data_list.append(&mut b" ".to_vec()); + data_list.append(&mut response.reason.clone()); + data_list.append(&mut b"\r\n".to_vec()); + data_list.append(&mut (_write_headers(&response.headers))?); + Ok(data_list) +} + +pub fn write_response(event: Event) -> Result, ProtocolError> { + match event { + Event::NormalResponse(response) => _write_response(&response), + Event::InformationalResponse(response) => _write_response(&response), + _ => panic!("Expected Response event, got {:?}", event), + } +} + +trait BodyWriter { + fn call(&mut self, event: Event) -> Result, ProtocolError> { + match event { + Event::Data(data) => self.send_data(&data.data), + Event::EndOfMessage(eom) => self.send_eom(&eom.headers), + _ => panic!("Unknown event type {:?}", event), + } + } + + fn send_data(&mut self, data: &Vec) -> Result, ProtocolError>; + fn send_eom(&mut self, headers: &Headers) -> Result, ProtocolError>; +} + +struct ContentLengthWriter { + length: isize, +} + +impl BodyWriter for ContentLengthWriter { + fn send_data(&mut self, data: &Vec) -> Result, ProtocolError> { + self.length -= data.len() as isize; + if self.length < 0 { + Err(ProtocolError::LocalProtocolError( + "Too much data for declared Content-Length".into(), + )) + } else { + Ok(data.clone()) + } + } + + fn send_eom(&mut self, headers: &Headers) -> Result, ProtocolError> { + if self.length != 0 { + return Err(ProtocolError::LocalProtocolError( + "Too little data for declared Content-Length".into(), + )); + } + if headers.len() > 0 { + return Err(ProtocolError::LocalProtocolError( + "Content-Length and trailers don't mix".into(), + )); + } + Ok(Vec::new()) + } +} + +pub fn content_length_writer(length: isize) -> impl FnMut(Event) -> Result, ProtocolError> { + let mut writer = ContentLengthWriter { length }; + move |event: Event| writer.call(event) +} + +struct ChunkedWriter; + +impl BodyWriter for ChunkedWriter { + fn send_data(&mut self, data: &Vec) -> Result, ProtocolError> { + // if we encoded 0-length data in the naive way, it would look like an + // end-of-message. + if data.len() == 0 { + return Ok(Vec::new()); + } + // write(format!("{:x}\r\n", data.len()).as_bytes().to_vec()); + // write(data.clone()); + // write(b"\r\n".to_vec()); + let mut data_list = Vec::new(); + data_list.append(&mut format!("{:x}\r\n", data.len()).as_bytes().to_vec()); + data_list.append(&mut data.clone()); + data_list.append(&mut b"\r\n".to_vec()); + Ok(data_list) + } + + fn send_eom(&mut self, headers: &Headers) -> Result, ProtocolError> { + // write(b"0\r\n".to_vec()); + // _write_headers(headers); + let mut data_list = Vec::new(); + data_list.append(&mut b"0\r\n".to_vec()); + data_list.append(&mut (_write_headers(headers))?); + Ok(data_list) + } +} + +pub fn chunked_writer() -> impl FnMut(Event) -> Result, ProtocolError> { + let mut writer = ChunkedWriter {}; + move |event: Event| writer.call(event) +} + +struct Http10Writer; + +impl BodyWriter for Http10Writer { + fn send_data(&mut self, data: &Vec) -> Result, ProtocolError> { + Ok(data.clone()) + } + + fn send_eom(&mut self, headers: &Headers) -> Result, ProtocolError> { + if headers.len() > 0 { + Err(ProtocolError::LocalProtocolError( + "can't send trailers to HTTP/1.0 client".into(), + )) + } else { + Ok(Vec::new()) + } + // no need to close the socket ourselves, that will be taken care of by + // Connection: close machinery + } +} + +pub fn http10_writer() -> impl FnMut(Event) -> Result, ProtocolError> { + let mut writer = Http10Writer {}; + move |event: Event| writer.call(event) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::_events::{Data, EndOfMessage}; + + #[test] + fn test_content_length_writer() { + let mut w = ContentLengthWriter { length: 5 }; + assert_eq!( + w.call(Event::Data(Data { + data: b"123".to_vec(), + ..Default::default() + })) + .unwrap(), + b"123" + ); + assert_eq!( + w.call(Event::Data(Data { + data: b"45".to_vec(), + ..Default::default() + })) + .unwrap(), + b"45" + ); + assert_eq!( + w.call(Event::EndOfMessage(EndOfMessage::default())) + .unwrap(), + b"" + ); + + let mut w = ContentLengthWriter { length: 5 }; + assert!(w + .call(Event::Data(Data { + data: b"123456".to_vec(), + ..Default::default() + })) + .is_err()); + + let mut w = ContentLengthWriter { length: 5 }; + assert_eq!( + w.call(Event::Data(Data { + data: b"123".to_vec(), + ..Default::default() + })) + .unwrap(), + b"123" + ); + assert!(w + .call(Event::Data(Data { + data: b"456".to_vec(), + ..Default::default() + })) + .is_err()); + + let mut w = ContentLengthWriter { length: 5 }; + assert_eq!( + w.call(Event::Data(Data { + data: b"123".to_vec(), + ..Default::default() + })) + .unwrap(), + b"123" + ); + assert!(w + .call(Event::EndOfMessage(EndOfMessage::default())) + .is_err()); + + let mut w = ContentLengthWriter { length: 5 }; + assert_eq!( + w.call(Event::Data(Data { + data: b"123".to_vec(), + ..Default::default() + })) + .unwrap(), + b"123" + ); + assert_eq!( + w.call(Event::Data(Data { + data: b"45".to_vec(), + ..Default::default() + })) + .unwrap(), + b"45" + ); + assert!(w + .call(Event::EndOfMessage(EndOfMessage { + headers: vec![(b"Etag".to_vec(), b"asdf".to_vec())].into(), + })) + .is_err()); + } + + #[test] + fn test_chunked_writer() { + let mut w = ChunkedWriter {}; + assert_eq!( + w.call(Event::Data(Data { + data: b"aaa".to_vec(), + ..Default::default() + })) + .unwrap(), + b"3\r\naaa\r\n" + ); + + assert_eq!( + w.call(Event::Data(Data { + data: b"a".to_vec(), + ..Default::default() + })) + .unwrap(), + b"1\r\na\r\n" + ); + + assert_eq!( + w.call(Event::Data(Data { + data: b"b".to_vec(), + ..Default::default() + })) + .unwrap(), + b"1\r\nb\r\n" + ); + + assert_eq!( + w.call(Event::Data(Data { + data: b"".to_vec(), + ..Default::default() + })) + .unwrap(), + b"" + ); + + assert_eq!( + w.call(Event::EndOfMessage(EndOfMessage { + headers: vec![(b"Etag".to_vec(), b"asdf".to_vec())].into(), + })) + .unwrap(), + b"0\r\nEtag: asdf\r\n\r\n" + ); + } + + #[test] + fn test_http10_writer() { + let mut w = Http10Writer {}; + assert_eq!( + w.call(Event::Data(Data { + data: b"1234".to_vec(), + ..Default::default() + })) + .unwrap(), + b"1234" + ); + assert_eq!( + w.call(Event::EndOfMessage(EndOfMessage::default())) + .unwrap(), + b"" + ); + + let mut w = Http10Writer {}; + assert!(w + .call(Event::EndOfMessage(EndOfMessage { + headers: vec![(b"Etag".to_vec(), b"asdf".to_vec())].into(), + })) + .is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ebc78d5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,15 @@ +mod _abnf; +mod _connection; +mod _events; +mod _headers; +mod _readers; +mod _receivebuffer; +mod _state; +mod _util; +mod _writers; + +pub use _connection::Connection; +pub use _events::{ConnectionClosed, Data, EndOfMessage, Event, Request, Response}; +pub use _headers::Headers; +pub use _state::{EventType, Role, State, Switch}; +pub use _util::{LocalProtocolError, ProtocolError, RemoteProtocolError}; diff --git a/tests/connections.rs b/tests/connections.rs new file mode 100644 index 0000000..77391ea --- /dev/null +++ b/tests/connections.rs @@ -0,0 +1,3133 @@ +mod helper; +use std::collections::HashMap; + +use h11::{ + Connection, ConnectionClosed, Data, EndOfMessage, Event, EventType, Headers, ProtocolError, + Request, Response, Role, +}; +use helper::{get_all_events, receive_and_get, ConnectionPair}; + +#[test] +fn test_connection_basics_and_content_length() { + let mut p = ConnectionPair::new(); + assert_eq!( + p.send( + Role::Client, + vec![Request::new( + b"GET".to_vec(), + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Content-Length".to_vec(), b"10".to_vec()) + ] + .into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(),], + None, + ) + .unwrap(), + b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 10\r\n\r\n".to_vec() + ); + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::SendBody), + (Role::Server, h11::State::SendResponse), + ]) + ); + } + assert_eq!(p.conn[&Role::Client].get_our_state(), h11::State::SendBody); + assert_eq!( + p.conn[&Role::Client].get_their_state(), + h11::State::SendResponse + ); + assert_eq!( + p.conn[&Role::Server].get_our_state(), + h11::State::SendResponse + ); + assert_eq!( + p.conn[&Role::Server].get_their_state(), + h11::State::SendBody + ); + assert_eq!(p.conn[&Role::Client].their_http_version, None); + assert_eq!( + p.conn[&Role::Server].their_http_version, + Some(b"1.1".to_vec()) + ); + + assert_eq!( + p.send( + Role::Server, + vec![Response { + status_code: 100, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + None + ) + .unwrap(), + b"HTTP/1.1 100 \r\n\r\n".to_vec() + ); + + assert_eq!( + p.send( + Role::Server, + vec![Response { + status_code: 200, + headers: vec![(b"Content-Length".to_vec(), b"11".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + None + ) + .unwrap(), + b"HTTP/1.1 200 \r\nContent-Length: 11\r\n\r\n".to_vec() + ); + + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::SendBody), + (Role::Server, h11::State::SendBody), + ]) + ); + } + + assert_eq!( + p.conn[&Role::Client].their_http_version, + Some(b"1.1".to_vec()) + ); + assert_eq!( + p.conn[&Role::Server].their_http_version, + Some(b"1.1".to_vec()) + ); + + assert_eq!( + p.send( + Role::Client, + vec![Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + None + ) + .unwrap(), + b"12345".to_vec() + ); + + assert_eq!( + p.send( + Role::Client, + vec![Data { + data: b"67890".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + Some(vec![ + Data { + data: b"67890".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ]), + ) + .unwrap(), + b"67890".to_vec() + ); + + assert_eq!( + p.send( + Role::Client, + vec![EndOfMessage::default().into()], + Some(vec![]), + ) + .unwrap(), + b"".to_vec() + ); + + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Done), + (Role::Server, h11::State::SendBody), + ]) + ); + } + + assert_eq!( + p.send( + Role::Server, + vec![Data { + data: b"1234567890".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + None + ) + .unwrap(), + b"1234567890".to_vec() + ); + + assert_eq!( + p.send( + Role::Server, + vec![Data { + data: b"1".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + Some(vec![ + Data { + data: b"1".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ]), + ) + .unwrap(), + b"1".to_vec() + ); + + assert_eq!( + p.send( + Role::Server, + vec![EndOfMessage::default().into()], + Some(vec![]), + ) + .unwrap(), + b"".to_vec() + ); + + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Done), + (Role::Server, h11::State::Done), + ]) + ); + } +} + +#[test] +fn test_chunked() { + let mut p = ConnectionPair::new(); + assert_eq!( + p.send( + Role::Client, + vec![Request::new( + b"GET".to_vec(), + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()) + ] + .into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(),], + None, + ) + .unwrap(), + b"GET / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n".to_vec() + ); + assert_eq!( + p.send( + Role::Client, + vec![Data { + data: b"1234567890".to_vec(), + chunk_start: true, + chunk_end: true, + } + .into()], + None, + ) + .unwrap(), + b"a\r\n1234567890\r\n".to_vec() + ); + assert_eq!( + p.send( + Role::Client, + vec![Data { + data: b"abcde".to_vec(), + chunk_start: true, + chunk_end: true, + } + .into()], + None, + ) + .unwrap(), + b"5\r\nabcde\r\n".to_vec() + ); + assert_eq!( + p.send(Role::Client, vec![Data::default().into()], Some(vec![]),) + .unwrap(), + b"".to_vec() + ); + assert_eq!( + p.send( + Role::Client, + vec![EndOfMessage { + headers: vec![(b"hello".to_vec(), b"there".to_vec())].into(), + } + .into()], + None, + ) + .unwrap(), + b"0\r\nhello: there\r\n\r\n".to_vec() + ); + + assert_eq!( + p.send( + Role::Server, + vec![Response { + status_code: 200, + headers: vec![ + (b"hello".to_vec(), b"there".to_vec()), + (b"transfer-encoding".to_vec(), b"chunked".to_vec()), + ] + .into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + None, + ) + .unwrap(), + b"HTTP/1.1 200 \r\nhello: there\r\ntransfer-encoding: chunked\r\n\r\n".to_vec() + ); + assert_eq!( + p.send( + Role::Server, + vec![Data { + data: b"54321".to_vec(), + chunk_start: true, + chunk_end: true, + } + .into()], + None, + ) + .unwrap(), + b"5\r\n54321\r\n".to_vec() + ); + assert_eq!( + p.send( + Role::Server, + vec![Data { + data: b"12345".to_vec(), + chunk_start: true, + chunk_end: true, + } + .into()], + None, + ) + .unwrap(), + b"5\r\n12345\r\n".to_vec() + ); + assert_eq!( + p.send(Role::Server, vec![EndOfMessage::default().into()], None,) + .unwrap(), + b"0\r\n\r\n".to_vec() + ); + + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Done), + (Role::Server, h11::State::Done), + ]) + ); + } +} + +#[test] +fn test_chunk_boundaries() { + let mut conn = Connection::new(Role::Server, None); + + let request = b"POST / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n"; + conn.receive_data(request).unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::Request(Request { + method: b"POST".to_vec(), + target: b"/".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()) + ] + .into(), + http_version: b"1.1".to_vec(), + }) + ); + assert_eq!(conn.next_event().unwrap(), Event::NeedData {}); + + conn.receive_data(b"5\r\nhello\r\n").unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::Data(Data { + data: b"hello".to_vec(), + chunk_start: true, + chunk_end: true, + }) + ); + + conn.receive_data(b"5\r\nhel").unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::Data(Data { + data: b"hel".to_vec(), + chunk_start: true, + chunk_end: false, + }) + ); + + conn.receive_data(b"l").unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::Data(Data { + data: b"l".to_vec(), + chunk_start: false, + chunk_end: false, + }) + ); + + conn.receive_data(b"o\r\n").unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::Data(Data { + data: b"o".to_vec(), + chunk_start: false, + chunk_end: true, + }) + ); + + conn.receive_data(b"5\r\nhello").unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::Data(Data { + data: b"hello".to_vec(), + chunk_start: true, + chunk_end: true, + }) + ); + + conn.receive_data(b"\r\n").unwrap(); + assert_eq!(conn.next_event().unwrap(), Event::NeedData {}); + + conn.receive_data(b"0\r\n\r\n").unwrap(); + assert_eq!( + conn.next_event().unwrap(), + Event::EndOfMessage(EndOfMessage { + headers: vec![].into(), + }) + ); +} + +// def test_client_talking_to_http10_server() -> None: +// c = Connection(CLIENT) +// c.send(Request(method="GET", target="/", headers=[("Host", "example.com")])) +// c.send(EndOfMessage()) +// assert c.our_state is DONE +// # No content-length, so Http10 framing for body +// assert receive_and_get(c, b"HTTP/1.0 200 OK\r\n\r\n") == [ +// Response(status_code=200, headers=[], http_version="1.0", reason=b"OK") # type: ignore[arg-type] +// ] +// assert c.our_state is MUST_CLOSE +// assert receive_and_get(c, b"12345") == [Data(data=b"12345")] +// assert receive_and_get(c, b"67890") == [Data(data=b"67890")] +// assert receive_and_get(c, b"") == [EndOfMessage(), ConnectionClosed()] +// assert c.their_state is CLOSED + +#[test] +fn test_client_talking_to_http10_server() { + let mut c = Connection::new(Role::Client, None); + c.send( + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"example.com".to_vec())].into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + assert_eq!(c.get_our_state(), h11::State::Done); + assert_eq!( + receive_and_get(&mut c, b"HTTP/1.0 200 OK\r\n\r\n").unwrap(), + vec![Event::NormalResponse(Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.0".to_vec(), + reason: b"OK".to_vec(), + })], + ); + assert_eq!(c.get_our_state(), h11::State::MustClose); + assert_eq!( + receive_and_get(&mut c, b"12345").unwrap(), + vec![Event::Data(Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + })], + ); + assert_eq!( + receive_and_get(&mut c, b"67890").unwrap(), + vec![Event::Data(Data { + data: b"67890".to_vec(), + chunk_start: false, + chunk_end: false, + })], + ); + assert_eq!( + receive_and_get(&mut c, b"").unwrap(), + vec![ + Event::EndOfMessage(EndOfMessage::default()), + Event::ConnectionClosed(ConnectionClosed::default()), + ], + ); + assert_eq!(c.get_their_state(), h11::State::Closed); +} + +// def test_server_talking_to_http10_client() -> None: +// c = Connection(SERVER) +// # No content-length, so no body +// # NB: no host header +// assert receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") == [ +// Request(method="GET", target="/", headers=[], http_version="1.0"), # type: ignore[arg-type] +// EndOfMessage(), +// ] +// assert c.their_state is MUST_CLOSE + +// # We automatically Connection: close back at them +// assert ( +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" +// ) + +// assert c.send(Data(data=b"12345")) == b"12345" +// assert c.send(EndOfMessage()) == b"" +// assert c.our_state is MUST_CLOSE + +// # Check that it works if they do send Content-Length +// c = Connection(SERVER) +// # NB: no host header +// assert receive_and_get(c, b"POST / HTTP/1.0\r\nContent-Length: 10\r\n\r\n1") == [ +// Request( +// method="POST", +// target="/", +// headers=[("Content-Length", "10")], +// http_version="1.0", +// ), +// Data(data=b"1"), +// ] +// assert receive_and_get(c, b"234567890") == [Data(data=b"234567890"), EndOfMessage()] +// assert c.their_state is MUST_CLOSE +// assert receive_and_get(c, b"") == [ConnectionClosed()] + +#[test] +fn test_server_talking_to_http10_client() { + let mut c = Connection::new(Role::Server, None); + // No content-length, so no body + // NB: no host header + assert_eq!( + receive_and_get(&mut c, b"GET / HTTP/1.0\r\n\r\n").unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![].into(), + http_version: b"1.0".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + assert_eq!(c.get_their_state(), h11::State::MustClose); + + // We automatically Connection: close back at them + assert_eq!( + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into() + ) + .unwrap() + .unwrap(), + b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n".to_vec() + ); + + assert_eq!( + c.send( + Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into() + ) + .unwrap() + .unwrap(), + b"12345".to_vec() + ); + assert_eq!( + c.send(EndOfMessage::default().into()).unwrap().unwrap(), + b"".to_vec() + ); + assert_eq!(c.get_our_state(), h11::State::MustClose); + + // Check that it works if they do send Content-Length + let mut c = Connection::new(Role::Server, None); + // NB: no host header + assert_eq!( + receive_and_get(&mut c, b"POST / HTTP/1.0\r\nContent-Length: 10\r\n\r\n1").unwrap(), + vec![ + Event::Request(Request { + method: b"POST".to_vec(), + target: b"/".to_vec(), + headers: vec![(b"Content-Length".to_vec(), b"10".to_vec())].into(), + http_version: b"1.0".to_vec(), + }), + Event::Data(Data { + data: b"1".to_vec(), + chunk_start: false, + chunk_end: false, + }), + ], + ); + assert_eq!( + receive_and_get(&mut c, b"234567890").unwrap(), + vec![ + Event::Data(Data { + data: b"234567890".to_vec(), + chunk_start: false, + chunk_end: false, + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + assert_eq!(c.get_their_state(), h11::State::MustClose); + assert_eq!( + receive_and_get(&mut c, b"").unwrap(), + vec![Event::ConnectionClosed(ConnectionClosed::default())], + ); +} + +// def test_automatic_transfer_encoding_in_response() -> None: +// # Check that in responses, the user can specify either Transfer-Encoding: +// # chunked or no framing at all, and in both cases we automatically select +// # the right option depending on whether the peer speaks HTTP/1.0 or +// # HTTP/1.1 +// for user_headers in [ +// [("Transfer-Encoding", "chunked")], +// [], +// # In fact, this even works if Content-Length is set, +// # because if both are set then Transfer-Encoding wins +// [("Transfer-Encoding", "chunked"), ("Content-Length", "100")], +// ]: +// user_headers = cast(List[Tuple[str, str]], user_headers) +// p = ConnectionPair() +// p.send( +// CLIENT, +// [ +// Request(method="GET", target="/", headers=[("Host", "example.com")]), +// EndOfMessage(), +// ], +// ) +// # When speaking to HTTP/1.1 client, all of the above cases get +// # normalized to Transfer-Encoding: chunked +// p.send( +// SERVER, +// Response(status_code=200, headers=user_headers), +// expect=Response( +// status_code=200, headers=[("Transfer-Encoding", "chunked")] +// ), +// ) + +// # When speaking to HTTP/1.0 client, all of the above cases get +// # normalized to no-framing-headers +// c = Connection(SERVER) +// receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") +// assert ( +// c.send(Response(status_code=200, headers=user_headers)) +// == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" +// ) +// assert c.send(Data(data=b"12345")) == b"12345" + +#[test] +fn test_automatic_transfer_encoding_in_response() { + // Check that in responses, the user can specify either Transfer-Encoding: + // chunked or no framing at all, and in both cases we automatically select + // the right option depending on whether the peer speaks HTTP/1.0 or + // HTTP/1.1 + for user_headers in vec![ + vec![(b"Transfer-Encoding".to_vec(), b"chunked".to_vec())], + vec![], + // In fact, this even works if Content-Length is set, + // because if both are set then Transfer-Encoding wins + vec![ + (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()), + (b"Content-Length".to_vec(), b"100".to_vec()), + ], + ] { + let mut p = ConnectionPair::new(); + p.send( + Role::Client, + vec![ + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"example.com".to_vec())].into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); + // When speaking to HTTP/1.1 client, all of the above cases get + // normalized to Transfer-Encoding: chunked + p.send( + Role::Server, + vec![Response { + status_code: 200, + headers: user_headers.clone().into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + Some(vec![Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()]), + ) + .unwrap(); + + // When speaking to HTTP/1.0 client, all of the above cases get + // normalized to no-framing-headers + let mut c = Connection::new(Role::Server, None); + receive_and_get(&mut c, b"GET / HTTP/1.0\r\n\r\n").unwrap(); + assert_eq!( + c.send( + Response { + status_code: 200, + headers: user_headers.clone().into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into() + ) + .unwrap() + .unwrap(), + b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n".to_vec() + ); + assert_eq!( + c.send( + Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into() + ) + .unwrap() + .unwrap(), + b"12345".to_vec() + ); + } +} + +// def test_automagic_connection_close_handling() -> None: +// p = ConnectionPair() +// # If the user explicitly sets Connection: close, then we notice and +// # respect it +// p.send( +// CLIENT, +// [ +// Request( +// method="GET", +// target="/", +// headers=[("Host", "example.com"), ("Connection", "close")], +// ), +// EndOfMessage(), +// ], +// ) +// for conn in p.conns: +// assert conn.states[CLIENT] is MUST_CLOSE +// # And if the client sets it, the server automatically echoes it back +// p.send( +// SERVER, +// # no header here... +// [Response(status_code=204, headers=[]), EndOfMessage()], # type: ignore[arg-type] +// # ...but oh look, it arrived anyway +// expect=[ +// Response(status_code=204, headers=[("connection", "close")]), +// EndOfMessage(), +// ], +// ) +// for conn in p.conns: +// assert conn.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} + +#[test] +fn test_automagic_connection_close_handling() { + let mut p = ConnectionPair::new(); + // If the user explicitly sets Connection: close, then we notice and + // respect it + p.send( + Role::Client, + vec![ + Request::new( + b"GET".to_vec(), + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Connection".to_vec(), b"close".to_vec()), + ] + .into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states()[&Role::Client], + h11::State::MustClose + ); + } + // And if the client sets it, the server automatically echoes it back + p.send( + Role::Server, + vec![ + Response { + status_code: 204, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + EndOfMessage::default().into(), + ], + Some(vec![ + Response { + status_code: 204, + headers: vec![(b"connection".to_vec(), b"close".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + EndOfMessage::default().into(), + ]), + ) + .unwrap(); + for (_, connection) in &p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::MustClose), + (Role::Server, h11::State::MustClose), + ]) + ); + } +} + +// def test_100_continue() -> None: +// def setup() -> ConnectionPair: +// p = ConnectionPair() +// p.send( +// CLIENT, +// Request( +// method="GET", +// target="/", +// headers=[ +// ("Host", "example.com"), +// ("Content-Length", "100"), +// ("Expect", "100-continue"), +// ], +// ), +// ) +// for conn in p.conns: +// assert conn.get_client_is_waiting_for_100_continue() +// assert not p.conn[CLIENT].they_are_waiting_for_100_continue +// assert p.conn[SERVER].they_are_waiting_for_100_continue +// return p + +// # Disabled by 100 Continue +// p = setup() +// p.send(SERVER, InformationalResponse(status_code=100, headers=[])) # type: ignore[arg-type] +// for conn in p.conns: +// assert not conn.get_client_is_waiting_for_100_continue() +// assert not conn.they_are_waiting_for_100_continue + +// # Disabled by a real response +// p = setup() +// p.send( +// SERVER, Response(status_code=200, headers=[("Transfer-Encoding", "chunked")]) +// ) +// for conn in p.conns: +// assert not conn.get_client_is_waiting_for_100_continue() +// assert not conn.they_are_waiting_for_100_continue + +// # Disabled by the client going ahead and sending stuff anyway +// p = setup() +// p.send(CLIENT, Data(data=b"12345")) +// for conn in p.conns: +// assert not conn.get_client_is_waiting_for_100_continue() +// assert not conn.they_are_waiting_for_100_continue + +#[test] +fn test_100_continue() { + fn setup() -> ConnectionPair { + let mut p = ConnectionPair::new(); + p.send( + Role::Client, + vec![Request::new( + b"GET".to_vec(), + vec![ + (b"Host".to_vec(), b"example.com".to_vec()), + (b"Content-Length".to_vec(), b"100".to_vec()), + (b"Expect".to_vec(), b"100-continue".to_vec()), + ] + .into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into()], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert!(connection.get_client_is_waiting_for_100_continue()); + } + assert!(!p.conn[&Role::Client].get_they_are_waiting_for_100_continue()); + assert!(p.conn[&Role::Server].get_they_are_waiting_for_100_continue()); + p + } + + // Disabled by 100 Continue + let mut p = setup(); + p.send( + Role::Server, + vec![Response { + status_code: 100, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert!(!connection.get_client_is_waiting_for_100_continue()); + assert!(!connection.get_they_are_waiting_for_100_continue()); + } + + // Disabled by a real response + let mut p = setup(); + p.send( + Role::Server, + vec![Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert!(!connection.get_client_is_waiting_for_100_continue()); + assert!(!connection.get_they_are_waiting_for_100_continue()); + } + + // Disabled by the client going ahead and sending stuff anyway + let mut p = setup(); + p.send( + Role::Client, + vec![Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert!(!connection.get_client_is_waiting_for_100_continue()); + assert!(!connection.get_they_are_waiting_for_100_continue()); + } + + // Disabled by the client going ahead and sending stuff anyway + let mut p = setup(); + p.send( + Role::Client, + vec![Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert!(!connection.get_client_is_waiting_for_100_continue()); + assert!(!connection.get_they_are_waiting_for_100_continue()); + } + + // Disabled by the client going ahead and sending stuff anyway + let mut p = setup(); + p.send( + Role::Client, + vec![Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into()], + None, + ) + .unwrap(); + for (_, connection) in &p.conn { + assert!(!connection.get_client_is_waiting_for_100_continue()); + assert!(!connection.get_they_are_waiting_for_100_continue()); + } +} + +// def test_max_incomplete_event_size_countermeasure() -> None: +// # Infinitely long headers are definitely not okay +// c = Connection(SERVER) +// c.receive_data(b"GET / HTTP/1.0\r\nEndless: ") +// assert c.next_event() is NEED_DATA +// with pytest.raises(RemoteProtocolError): +// while True: +// c.receive_data(b"a" * 1024) +// c.next_event() + +// # Checking that the same header is accepted / rejected depending on the +// # max_incomplete_event_size setting: +// c = Connection(SERVER, max_incomplete_event_size=5000) +// c.receive_data(b"GET / HTTP/1.0\r\nBig: ") +// c.receive_data(b"a" * 4000) +// c.receive_data(b"\r\n\r\n") +// assert get_all_events(c) == [ +// Request( +// method="GET", target="/", http_version="1.0", headers=[("big", "a" * 4000)] +// ), +// EndOfMessage(), +// ] + +// c = Connection(SERVER, max_incomplete_event_size=4000) +// c.receive_data(b"GET / HTTP/1.0\r\nBig: ") +// c.receive_data(b"a" * 4000) +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +// # Temporarily exceeding the size limit is fine, as long as its done with +// # complete events: +// c = Connection(SERVER, max_incomplete_event_size=5000) +// c.receive_data(b"GET / HTTP/1.0\r\nContent-Length: 10000") +// c.receive_data(b"\r\n\r\n" + b"a" * 10000) +// assert get_all_events(c) == [ +// Request( +// method="GET", +// target="/", +// http_version="1.0", +// headers=[("Content-Length", "10000")], +// ), +// Data(data=b"a" * 10000), +// EndOfMessage(), +// ] + +// c = Connection(SERVER, max_incomplete_event_size=100) +// # Two pipelined requests to create a way-too-big receive buffer... but +// # it's fine because we're not checking +// c.receive_data( +// b"GET /1 HTTP/1.1\r\nHost: a\r\n\r\n" +// b"GET /2 HTTP/1.1\r\nHost: b\r\n\r\n" + b"X" * 1000 +// ) +// assert get_all_events(c) == [ +// Request(method="GET", target="/1", headers=[("host", "a")]), +// EndOfMessage(), +// ] +// # Even more data comes in, still no problem +// c.receive_data(b"X" * 1000) +// # We can respond and reuse to get the second pipelined request +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) +// c.start_next_cycle() +// assert get_all_events(c) == [ +// Request(method="GET", target="/2", headers=[("host", "b")]), +// EndOfMessage(), +// ] +// # But once we unpause and try to read the next message, and find that it's +// # incomplete and the buffer is *still* way too large, then *that's* a +// # problem: +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) +// c.start_next_cycle() +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +#[test] +fn test_max_incomplete_event_size_countermeasure() { + // Infinitely long headers are definitely not okay + let mut c = Connection::new(Role::Server, Some(5000)); + c.receive_data(b"GET / HTTP/1.0\r\nEndless: ").unwrap(); + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + + // Checking that the same header is accepted / rejected depending on the + // max_incomplete_event_size setting: + let mut c = Connection::new(Role::Server, Some(5000)); + c.receive_data(b"GET / HTTP/1.0\r\nBig: ").unwrap(); + c.receive_data(&vec![b'a'; 4000]).unwrap(); + c.receive_data(b"\r\n\r\n").unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![(b"Big".to_vec(), vec![b'a'; 4000])].into(), + http_version: b"1.0".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + + let mut c = Connection::new(Role::Server, Some(4000)); + c.receive_data(b"GET / HTTP/1.0\r\nBig: ").unwrap(); + c.receive_data(&vec![b'a'; 4000]).unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + + // Temporarily exceeding the size limit is fine, as long as its done with + // complete events: + let mut c = Connection::new(Role::Server, Some(5000)); + c.receive_data(b"GET / HTTP/1.0\r\nContent-Length: 10000") + .unwrap(); + c.receive_data(b"\r\n\r\n").unwrap(); + c.receive_data(&vec![b'a'; 10000]).unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![(b"Content-Length".to_vec(), b"10000".to_vec())].into(), + http_version: b"1.0".to_vec(), + }), + Event::Data(Data { + data: vec![b'a'; 10000], + chunk_start: false, + chunk_end: false, + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + + let mut c = Connection::new(Role::Server, Some(100)); + // Two pipelined requests to create a way-too-big receive buffer... but + // it's fine because we're not checking + c.receive_data( + b"GET /1 HTTP/1.1\r\nHost: a\r\n\r\n" + .to_vec() + .into_iter() + .chain(b"GET /2 HTTP/1.1\r\nHost: b\r\n\r\n".to_vec().into_iter()) + .chain(vec![b'X'; 1000].into_iter()) + .collect::>() + .as_slice(), + ) + .unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/1".to_vec(), + headers: vec![(b"Host".to_vec(), b"a".to_vec())].into(), + http_version: b"1.1".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + // Even more data comes in, still no problem + c.receive_data(&vec![b'X'; 1000]).unwrap(); + // We can respond and reuse to get the second pipelined request + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + c.start_next_cycle().unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/2".to_vec(), + headers: vec![(b"Host".to_vec(), b"b".to_vec())].into(), + http_version: b"1.1".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + // But once we unpause and try to read the next message, and find that it's + // incomplete and the buffer is *still* way too large, then *that's* a + // problem: + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + c.start_next_cycle().unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + + // Check that we can still send data after this happens + let mut c = Connection::new(Role::Server, Some(100)); + // Two pipelined requests to create a way-too-big receive buffer... but + // it's fine because we're not checking + c.receive_data( + b"GET /1 HTTP/1.1\r\nHost: a\r\n\r\n" + .to_vec() + .into_iter() + .chain(b"GET /2 HTTP/1.1\r\nHost: b\r\n\r\n".to_vec().into_iter()) + .chain(vec![b'X'; 1000].into_iter()) + .collect::>() + .as_slice(), + ) + .unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/1".to_vec(), + headers: vec![(b"Host".to_vec(), b"a".to_vec())].into(), + http_version: b"1.1".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + // Even more data comes in, still no problem + c.receive_data(&vec![b'X'; 1000]).unwrap(); + // We can respond and reuse to get the second pipelined request + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + c.start_next_cycle().unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/2".to_vec(), + headers: vec![(b"Host".to_vec(), b"b".to_vec())].into(), + http_version: b"1.1".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + // But once we unpause and try to read the next message, and find that it's + // incomplete and the buffer is *still* way too large, then *that's* a + // problem: + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + c.start_next_cycle().unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); +} + +// def test_reuse_simple() -> None: +// p = ConnectionPair() +// p.send( +// CLIENT, +// [Request(method="GET", target="/", headers=[("Host", "a")]), EndOfMessage()], +// ) +// p.send( +// SERVER, +// [ +// Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), +// EndOfMessage(), +// ], +// ) +// for conn in p.conns: +// assert conn.states == {CLIENT: DONE, SERVER: DONE} +// conn.start_next_cycle() + +// p.send( +// CLIENT, +// [ +// Request(method="DELETE", target="/foo", headers=[("Host", "a")]), +// EndOfMessage(), +// ], +// ) +// p.send( +// SERVER, +// [ +// Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), +// EndOfMessage(), +// ], +// ) + +#[test] +fn test_reuse_simple() { + let mut p = ConnectionPair::new(); + p.send( + Role::Client, + vec![ + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"a".to_vec())].into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); + p.send( + Role::Server, + vec![ + Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Done), + (Role::Server, h11::State::Done), + ]) + ); + connection.start_next_cycle().unwrap(); + } + + p.send( + Role::Client, + vec![ + Request::new( + b"DELETE".to_vec(), + vec![(b"Host".to_vec(), b"a".to_vec())].into(), + b"/foo".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); + p.send( + Role::Server, + vec![ + Response { + status_code: 404, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); +} + +// def test_pipelining() -> None: +// # Client doesn't support pipelining, so we have to do this by hand +// c = Connection(SERVER) +// assert c.next_event() is NEED_DATA +// # 3 requests all bunched up +// c.receive_data( +// b"GET /1 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n" +// b"12345" +// b"GET /2 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n" +// b"67890" +// b"GET /3 HTTP/1.1\r\nHost: a.com\r\n\r\n" +// ) +// assert get_all_events(c) == [ +// Request( +// method="GET", +// target="/1", +// headers=[("Host", "a.com"), ("Content-Length", "5")], +// ), +// Data(data=b"12345"), +// EndOfMessage(), +// ] +// assert c.their_state is DONE +// assert c.our_state is SEND_RESPONSE + +// assert c.next_event() is PAUSED + +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) +// assert c.their_state is DONE +// assert c.our_state is DONE + +// c.start_next_cycle() + +// assert get_all_events(c) == [ +// Request( +// method="GET", +// target="/2", +// headers=[("Host", "a.com"), ("Content-Length", "5")], +// ), +// Data(data=b"67890"), +// EndOfMessage(), +// ] +// assert c.next_event() is PAUSED +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) +// c.start_next_cycle() + +// assert get_all_events(c) == [ +// Request(method="GET", target="/3", headers=[("Host", "a.com")]), +// EndOfMessage(), +// ] +// # Doesn't pause this time, no trailing data +// assert c.next_event() is NEED_DATA +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) + +// # Arrival of more data triggers pause +// assert c.next_event() is NEED_DATA +// c.receive_data(b"SADF") +// assert c.next_event() is PAUSED +// assert c.trailing_data == (b"SADF", False) +// # If EOF arrives while paused, we don't see that either: +// c.receive_data(b"") +// assert c.trailing_data == (b"SADF", True) +// assert c.next_event() is PAUSED +// c.receive_data(b"") +// assert c.next_event() is PAUSED + +#[test] +fn test_pipelining() { + // Client doesn't support pipelining, so we have to do this by hand + let mut c = Connection::new(Role::Server, None); + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + + // 3 requests all bunched up + c.receive_data( + &vec![ + b"GET /1 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n".to_vec(), + b"12345".to_vec(), + b"GET /2 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n".to_vec(), + b"67890".to_vec(), + b"GET /3 HTTP/1.1\r\nHost: a.com\r\n\r\n".to_vec(), + ] + .into_iter() + .flatten() + .collect::>(), + ) + .unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/1".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"a.com".to_vec()), + (b"Content-Length".to_vec(), b"5".to_vec()) + ] + .into(), + http_version: b"1.1".to_vec(), + }), + Event::Data(Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + assert_eq!(c.get_their_state(), h11::State::Done); + assert_eq!(c.get_our_state(), h11::State::SendResponse); + + assert_eq!(c.next_event().unwrap(), Event::Paused {}); + + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + assert_eq!(c.get_their_state(), h11::State::Done); + assert_eq!(c.get_our_state(), h11::State::Done); + + c.start_next_cycle().unwrap(); + + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/2".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"a.com".to_vec()), + (b"Content-Length".to_vec(), b"5".to_vec()) + ] + .into(), + http_version: b"1.1".to_vec(), + }), + Event::Data(Data { + data: b"67890".to_vec(), + chunk_start: false, + chunk_end: false, + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + assert_eq!(c.next_event().unwrap(), Event::Paused {}); + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + c.start_next_cycle().unwrap(); + + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/3".to_vec(), + headers: vec![(b"Host".to_vec(), b"a.com".to_vec())].into(), + http_version: b"1.1".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + // Doesn't pause this time, no trailing data + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + + // Arrival of more data triggers pause + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + c.receive_data(b"SADF").unwrap(); + assert_eq!(c.next_event().unwrap(), Event::Paused {}); + assert_eq!(c.get_trailing_data(), (b"SADF".to_vec(), false)); + // If EOF arrives while paused, we don't see that either: + c.receive_data(b"").unwrap(); + assert_eq!(c.get_trailing_data(), (b"SADF".to_vec(), true)); + assert_eq!(c.next_event().unwrap(), Event::Paused {}); + c.receive_data(b"").unwrap(); + assert_eq!(c.next_event().unwrap(), Event::Paused {}); +} + +// def test_protocol_switch() -> None: +// for req, deny, accept in [ +// ( +// Request( +// method="CONNECT", +// target="example.com:443", +// headers=[("Host", "foo"), ("Content-Length", "1")], +// ), +// Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), +// Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), +// ), +// ( +// Request( +// method="GET", +// target="/", +// headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], +// ), +// Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), +// InformationalResponse(status_code=101, headers=[("Upgrade", "a")]), +// ), +// ( +// Request( +// method="CONNECT", +// target="example.com:443", +// headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], +// ), +// Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), +// # Accept CONNECT, not upgrade +// Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), +// ), +// ( +// Request( +// method="CONNECT", +// target="example.com:443", +// headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], +// ), +// Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), +// # Accept Upgrade, not CONNECT +// InformationalResponse(status_code=101, headers=[("Upgrade", "b")]), +// ), +// ]: + +// def setup() -> ConnectionPair: +// p = ConnectionPair() +// p.send(CLIENT, req) +// # No switch-related state change stuff yet; the client has to +// # finish the request before that kicks in +// for conn in p.conns: +// assert conn.states[CLIENT] is SEND_BODY +// p.send(CLIENT, [Data(data=b"1"), EndOfMessage()]) +// for conn in p.conns: +// assert conn.states[CLIENT] is MIGHT_SWITCH_PROTOCOL +// assert p.conn[SERVER].next_event() is PAUSED +// return p + +// # Test deny case +// p = setup() +// p.send(SERVER, deny) +// for conn in p.conns: +// assert conn.states == {CLIENT: DONE, SERVER: SEND_BODY} +// p.send(SERVER, EndOfMessage()) +// # Check that re-use is still allowed after a denial +// for conn in p.conns: +// conn.start_next_cycle() + +// # Test accept case +// p = setup() +// p.send(SERVER, accept) +// for conn in p.conns: +// assert conn.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} +// conn.receive_data(b"123") +// assert conn.next_event() is PAUSED +// conn.receive_data(b"456") +// assert conn.next_event() is PAUSED +// assert conn.trailing_data == (b"123456", False) + +// # Pausing in might-switch, then recovery +// # (weird artificial case where the trailing data actually is valid +// # HTTP for some reason, because this makes it easier to test the state +// # logic) +// p = setup() +// sc = p.conn[SERVER] +// sc.receive_data(b"GET / HTTP/1.0\r\n\r\n") +// assert sc.next_event() is PAUSED +// assert sc.trailing_data == (b"GET / HTTP/1.0\r\n\r\n", False) +// sc.send(deny) +// assert sc.next_event() is PAUSED +// sc.send(EndOfMessage()) +// sc.start_next_cycle() +// assert get_all_events(sc) == [ +// Request(method="GET", target="/", headers=[], http_version="1.0"), # type: ignore[arg-type] +// EndOfMessage(), +// ] + +// # When we're DONE, have no trailing data, and the connection gets +// # closed, we report ConnectionClosed(). When we're in might-switch or +// # switched, we don't. +// p = setup() +// sc = p.conn[SERVER] +// sc.receive_data(b"") +// assert sc.next_event() is PAUSED +// assert sc.trailing_data == (b"", True) +// p.send(SERVER, accept) +// assert sc.next_event() is PAUSED + +// p = setup() +// sc = p.conn[SERVER] +// sc.receive_data(b"") +// assert sc.next_event() is PAUSED +// sc.send(deny) +// assert sc.next_event() == ConnectionClosed() + +// # You can't send after switching protocols, or while waiting for a +// # protocol switch +// p = setup() +// with pytest.raises(LocalProtocolError): +// p.conn[CLIENT].send( +// Request(method="GET", target="/", headers=[("Host", "a")]) +// ) +// p = setup() +// p.send(SERVER, accept) +// with pytest.raises(LocalProtocolError): +// p.conn[SERVER].send(Data(data=b"123")) + +#[test] +fn test_protocol_switch() { + for (req, deny, accept) in vec![ + ( + Request::new( + b"CONNECT".to_vec(), + vec![ + (b"Host".to_vec(), b"foo".to_vec()), + (b"Content-Length".to_vec(), b"1".to_vec()), + ] + .into(), + b"example.com:443".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + Response { + status_code: 404, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ), + ( + Request::new( + b"GET".to_vec(), + vec![ + (b"Host".to_vec(), b"foo".to_vec()), + (b"Content-Length".to_vec(), b"1".to_vec()), + (b"Upgrade".to_vec(), b"a, b".to_vec()), + ] + .into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + Response { + status_code: 101, + headers: vec![(b"Upgrade".to_vec(), b"a".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ), + ( + Request::new( + b"CONNECT".to_vec(), + vec![ + (b"Host".to_vec(), b"foo".to_vec()), + (b"Content-Length".to_vec(), b"1".to_vec()), + (b"Upgrade".to_vec(), b"a, b".to_vec()), + ] + .into(), + b"example.com:443".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + Response { + status_code: 404, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + // Accept CONNECT, not upgrade + Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ), + ( + Request::new( + b"CONNECT".to_vec(), + vec![ + (b"Host".to_vec(), b"foo".to_vec()), + (b"Content-Length".to_vec(), b"1".to_vec()), + (b"Upgrade".to_vec(), b"a, b".to_vec()), + ] + .into(), + b"example.com:443".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + Response { + status_code: 404, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + // Accept Upgrade, not CONNECT + Response { + status_code: 101, + headers: vec![(b"Upgrade".to_vec(), b"b".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ), + ] { + let req: Event = req; + let deny: Event = deny; + let accept: Event = accept; + let setup = || { + let mut p = ConnectionPair::new(); + p.send(Role::Client, vec![req.clone()], None).unwrap(); + // No switch-related state change stuff yet; the client has to + // finish the request before that kicks in + for (_, connection) in &mut p.conn { + assert_eq!(connection.get_states()[&Role::Client], h11::State::SendBody); + } + p.send( + Role::Client, + vec![ + Data { + data: b"1".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + EndOfMessage::default().into(), + ], + None, + ) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states()[&Role::Client], + h11::State::MightSwitchProtocol + ); + } + assert_eq!( + p.conn.get_mut(&Role::Server).unwrap().next_event().unwrap(), + Event::Paused {} + ); + return p; + }; + + // Test deny case + let mut p = setup(); + p.send(Role::Server, vec![deny.clone()], None).unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Done), + (Role::Server, h11::State::SendBody) + ]) + ); + } + p.send(Role::Server, vec![EndOfMessage::default().into()], None) + .unwrap(); + // Check that re-use is still allowed after a denial + for (_, connection) in &mut p.conn { + connection.start_next_cycle().unwrap(); + } + + // Test accept case + let mut p = setup(); + p.send(Role::Server, vec![accept.clone()], None).unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::SwitchedProtocol), + (Role::Server, h11::State::SwitchedProtocol) + ]) + ); + connection.receive_data(b"123").unwrap(); + assert_eq!(connection.next_event().unwrap(), Event::Paused {}); + connection.receive_data(b"456").unwrap(); + assert_eq!(connection.next_event().unwrap(), Event::Paused {}); + assert_eq!(connection.get_trailing_data(), (b"123456".to_vec(), false)); + } + + // Pausing in might-switch, then recovery + // (weird artificial case where the trailing data actually is valid + // HTTP for some reason, because this makes it easier to test the state + // logic) + let mut p = setup(); + let sc = p.conn.get_mut(&Role::Server).unwrap(); + sc.receive_data(b"GET / HTTP/1.0\r\n\r\n").unwrap(); + assert_eq!(sc.next_event().unwrap(), Event::Paused {}); + assert_eq!( + sc.get_trailing_data(), + (b"GET / HTTP/1.0\r\n\r\n".to_vec(), false) + ); + sc.send(deny.clone()).unwrap(); + assert_eq!(sc.next_event().unwrap(), Event::Paused {}); + sc.send(EndOfMessage::default().into()).unwrap(); + sc.start_next_cycle().unwrap(); + assert_eq!( + get_all_events(sc).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/".to_vec(), + headers: vec![].into(), + http_version: b"1.0".to_vec(), + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + + // When we're DONE, have no trailing data, and the connection gets + // closed, we report ConnectionClosed(). When we're in might-switch or + // switched, we don't. + let mut p = setup(); + { + let sc = (p.conn).get_mut(&Role::Server).unwrap(); + sc.receive_data(b"").unwrap(); + assert_eq!(sc.next_event().unwrap(), Event::Paused {}); + assert_eq!(sc.get_trailing_data(), (b"".to_vec(), true)); + } + p.send(Role::Server, vec![accept.clone()], None).unwrap(); + assert_eq!( + (p.conn) + .get_mut(&Role::Server) + .unwrap() + .next_event() + .unwrap(), + Event::Paused {} + ); + + let mut p = setup(); + let sc = p.conn.get_mut(&Role::Server).unwrap(); + sc.receive_data(b"").unwrap(); + assert_eq!(sc.next_event().unwrap(), Event::Paused {}); + sc.send(deny).unwrap(); + assert_eq!( + sc.next_event().unwrap(), + Event::ConnectionClosed(ConnectionClosed::default()) + ); + + // You can't send after switching protocols, or while waiting for a + // protocol switch + let mut p = setup(); + let cc = p.conn.get_mut(&Role::Client).unwrap(); + assert!(match cc.send( + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"a".to_vec())].into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + ) { + Err(ProtocolError::LocalProtocolError(_)) => true, + _ => false, + }); + let mut p = setup(); + p.send(Role::Server, vec![accept], None).unwrap(); + let cc = p.conn.get_mut(&Role::Client).unwrap(); + assert!(match cc.send( + Data { + data: b"123".to_vec(), + chunk_start: false, + chunk_end: false, + } + .into(), + ) { + Err(ProtocolError::LocalProtocolError(_)) => true, + _ => false, + }); + } +} + +// def test_close_simple() -> None: +// # Just immediately closing a new connection without anything having +// # happened yet. +// for who_shot_first, who_shot_second in [(CLIENT, SERVER), (SERVER, CLIENT)]: + +// def setup() -> ConnectionPair: +// p = ConnectionPair() +// p.send(who_shot_first, ConnectionClosed()) +// for conn in p.conns: +// assert conn.states == { +// who_shot_first: CLOSED, +// who_shot_second: MUST_CLOSE, +// } +// return p + +// # You can keep putting b"" into a closed connection, and you keep +// # getting ConnectionClosed() out: +// p = setup() +// assert p.conn[who_shot_second].next_event() == ConnectionClosed() +// assert p.conn[who_shot_second].next_event() == ConnectionClosed() +// p.conn[who_shot_second].receive_data(b"") +// assert p.conn[who_shot_second].next_event() == ConnectionClosed() +// # Second party can close... +// p = setup() +// p.send(who_shot_second, ConnectionClosed()) +// for conn in p.conns: +// assert conn.our_state is CLOSED +// assert conn.their_state is CLOSED +// # But trying to receive new data on a closed connection is a +// # RuntimeError (not ProtocolError, because the problem here isn't +// # violation of HTTP, it's violation of physics) +// p = setup() +// with pytest.raises(RuntimeError): +// p.conn[who_shot_second].receive_data(b"123") +// # And receiving new data on a MUST_CLOSE connection is a ProtocolError +// p = setup() +// p.conn[who_shot_first].receive_data(b"GET") +// with pytest.raises(RemoteProtocolError): +// p.conn[who_shot_first].next_event() + +#[test] +fn test_close_simple() { + // Just immediately closing a new connection without anything having + // happened yet. + for (who_shot_first, who_shot_second) in + vec![(Role::Client, Role::Server), (Role::Server, Role::Client)] + { + let setup = || { + let mut p = ConnectionPair::new(); + p.send( + who_shot_first, + vec![ConnectionClosed::default().into()], + None, + ) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (who_shot_first, h11::State::Closed), + (who_shot_second, h11::State::MustClose) + ]) + ); + } + return p; + }; + + // You can keep putting b"" into a closed connection, and you keep + // getting ConnectionClosed() out: + let mut p = setup(); + assert_eq!( + p.conn + .get_mut(&who_shot_second) + .unwrap() + .next_event() + .unwrap(), + Event::ConnectionClosed(ConnectionClosed::default()) + ); + assert_eq!( + p.conn + .get_mut(&who_shot_second) + .unwrap() + .next_event() + .unwrap(), + Event::ConnectionClosed(ConnectionClosed::default()) + ); + p.conn + .get_mut(&who_shot_second) + .unwrap() + .receive_data(b"") + .unwrap(); + assert_eq!( + p.conn + .get_mut(&who_shot_second) + .unwrap() + .next_event() + .unwrap(), + Event::ConnectionClosed(ConnectionClosed::default()) + ); + // Second party can close... + let mut p = setup(); + p.send( + who_shot_second, + vec![ConnectionClosed::default().into()], + None, + ) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!(connection.get_our_state(), h11::State::Closed); + assert_eq!(connection.get_their_state(), h11::State::Closed); + } + // But trying to receive new data on a closed connection is a + // RuntimeError (not ProtocolError, because the problem here isn't + // violation of HTTP, it's violation of physics) + let mut p = setup(); + assert!(match p + .conn + .get_mut(&who_shot_second) + .unwrap() + .receive_data(b"123") + { + Err(message) => true, + _ => false, + }); + // And receiving new data on a MUST_CLOSE connection is a ProtocolError + let mut p = setup(); + p.conn + .get_mut(&who_shot_first) + .unwrap() + .receive_data(b"GET") + .unwrap(); + assert!(match p + .conn + .get_mut(&who_shot_first) + .unwrap() + .next_event() + .unwrap_err() + { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + } +} + +// def test_close_different_states() -> None: +// req = [ +// Request(method="GET", target="/foo", headers=[("Host", "a")]), +// EndOfMessage(), +// ] +// resp = [ +// Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), +// EndOfMessage(), +// ] + +// # Client before request +// p = ConnectionPair() +// p.send(CLIENT, ConnectionClosed()) +// for conn in p.conns: +// assert conn.states == {CLIENT: CLOSED, SERVER: MUST_CLOSE} + +// # Client after request +// p = ConnectionPair() +// p.send(CLIENT, req) +// p.send(CLIENT, ConnectionClosed()) +// for conn in p.conns: +// assert conn.states == {CLIENT: CLOSED, SERVER: SEND_RESPONSE} + +// # Server after request -> not allowed +// p = ConnectionPair() +// p.send(CLIENT, req) +// with pytest.raises(LocalProtocolError): +// p.conn[SERVER].send(ConnectionClosed()) +// p.conn[CLIENT].receive_data(b"") +// with pytest.raises(RemoteProtocolError): +// p.conn[CLIENT].next_event() + +// # Server after response +// p = ConnectionPair() +// p.send(CLIENT, req) +// p.send(SERVER, resp) +// p.send(SERVER, ConnectionClosed()) +// for conn in p.conns: +// assert conn.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED} + +// # Both after closing (ConnectionClosed() is idempotent) +// p = ConnectionPair() +// p.send(CLIENT, req) +// p.send(SERVER, resp) +// p.send(CLIENT, ConnectionClosed()) +// p.send(SERVER, ConnectionClosed()) +// p.send(CLIENT, ConnectionClosed()) +// p.send(SERVER, ConnectionClosed()) + +// # In the middle of sending -> not allowed +// p = ConnectionPair() +// p.send( +// CLIENT, +// Request( +// method="GET", target="/", headers=[("Host", "a"), ("Content-Length", "10")] +// ), +// ) +// with pytest.raises(LocalProtocolError): +// p.conn[CLIENT].send(ConnectionClosed()) +// p.conn[SERVER].receive_data(b"") +// with pytest.raises(RemoteProtocolError): +// p.conn[SERVER].next_event() + +#[test] +fn test_close_different_states() { + let req: Vec = vec![ + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"a".to_vec())].into(), + b"/foo".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + EndOfMessage::default().into(), + ]; + let resp: Vec = vec![ + Response { + status_code: 200, + headers: vec![(b"transfer-encoding".to_vec(), b"chunked".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + EndOfMessage::default().into(), + ]; + + // Client before request + let mut p = ConnectionPair::new(); + p.send(Role::Client, vec![ConnectionClosed::default().into()], None) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Closed), + (Role::Server, h11::State::MustClose) + ]) + ); + } + + // Client after request + let mut p = ConnectionPair::new(); + p.send(Role::Client, req.clone(), None).unwrap(); + p.send(Role::Client, vec![ConnectionClosed::default().into()], None) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::Closed), + (Role::Server, h11::State::SendResponse) + ]) + ); + } + + // Server after request -> not allowed + let mut p = ConnectionPair::new(); + p.send(Role::Client, req.clone(), None).unwrap(); + assert!(match p + .conn + .get_mut(&Role::Server) + .unwrap() + .send(ConnectionClosed::default().into()) + { + Err(ProtocolError::LocalProtocolError(_)) => true, + _ => false, + }); + p.conn + .get_mut(&Role::Client) + .unwrap() + .receive_data(b"") + .unwrap(); + assert!(match p + .conn + .get_mut(&Role::Client) + .unwrap() + .next_event() + .unwrap_err() + { + ProtocolError::RemoteProtocolError(_) => true, + ProtocolError::LocalProtocolError(m) => panic!("{:?}", m), + }); + + // Server after response + let mut p = ConnectionPair::new(); + p.send(Role::Client, req.clone(), None).unwrap(); + p.send(Role::Server, resp.clone(), None).unwrap(); + p.send(Role::Server, vec![ConnectionClosed::default().into()], None) + .unwrap(); + for (_, connection) in &mut p.conn { + assert_eq!( + connection.get_states(), + HashMap::from([ + (Role::Client, h11::State::MustClose), + (Role::Server, h11::State::Closed) + ]) + ); + } + + // Both after closing (ConnectionClosed() is idempotent) + let mut p = ConnectionPair::new(); + p.send(Role::Client, req.clone(), None).unwrap(); + p.send(Role::Server, resp.clone(), None).unwrap(); + p.send(Role::Client, vec![ConnectionClosed::default().into()], None) + .unwrap(); + p.send(Role::Server, vec![ConnectionClosed::default().into()], None) + .unwrap(); + p.send(Role::Client, vec![ConnectionClosed::default().into()], None) + .unwrap(); + p.send(Role::Server, vec![ConnectionClosed::default().into()], None) + .unwrap(); + + // In the middle of sending -> not allowed + let mut p = ConnectionPair::new(); + p.send( + Role::Client, + vec![Request::new( + b"GET".to_vec(), + vec![ + (b"Host".to_vec(), b"a".to_vec()), + (b"Content-Length".to_vec(), b"10".to_vec()), + ] + .into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into()], + None, + ) + .unwrap(); + assert!(match p + .conn + .get_mut(&Role::Client) + .unwrap() + .send(ConnectionClosed::default().into()) + { + Err(ProtocolError::LocalProtocolError(_)) => true, + _ => false, + }); + p.conn + .get_mut(&Role::Server) + .unwrap() + .receive_data(b"") + .unwrap(); + assert!(match p + .conn + .get_mut(&Role::Server) + .unwrap() + .next_event() + .unwrap_err() + { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); +} + +// # Receive several requests and then client shuts down their side of the +// # connection; we can respond to each +// def test_pipelined_close() -> None: +// c = Connection(SERVER) +// # 2 requests then a close +// c.receive_data( +// b"GET /1 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n" +// b"12345" +// b"GET /2 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n" +// b"67890" +// ) +// c.receive_data(b"") +// assert get_all_events(c) == [ +// Request( +// method="GET", +// target="/1", +// headers=[("host", "a.com"), ("content-length", "5")], +// ), +// Data(data=b"12345"), +// EndOfMessage(), +// ] +// assert c.states[CLIENT] is DONE +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) +// assert c.states[SERVER] is DONE +// c.start_next_cycle() +// assert get_all_events(c) == [ +// Request( +// method="GET", +// target="/2", +// headers=[("host", "a.com"), ("content-length", "5")], +// ), +// Data(data=b"67890"), +// EndOfMessage(), +// ConnectionClosed(), +// ] +// assert c.states == {CLIENT: CLOSED, SERVER: SEND_RESPONSE} +// c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] +// c.send(EndOfMessage()) +// assert c.states == {CLIENT: CLOSED, SERVER: MUST_CLOSE} +// c.send(ConnectionClosed()) +// assert c.states == {CLIENT: CLOSED, SERVER: CLOSED} + +#[test] +fn test_pipelined_close() { + let mut c = Connection::new(Role::Server, None); + // 2 requests then a close + c.receive_data( + &vec![ + b"GET /1 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n".to_vec(), + b"12345".to_vec(), + b"GET /2 HTTP/1.1\r\nHost: a.com\r\nContent-Length: 5\r\n\r\n".to_vec(), + b"67890".to_vec(), + ] + .into_iter() + .flatten() + .collect::>(), + ) + .unwrap(); + c.receive_data(b"").unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/1".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"a.com".to_vec()), + (b"Content-Length".to_vec(), b"5".to_vec()) + ] + .into(), + http_version: b"1.1".to_vec(), + }), + Event::Data(Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + }), + Event::EndOfMessage(EndOfMessage::default()), + ], + ); + assert_eq!(c.get_their_state(), h11::State::Done); + + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + assert_eq!(c.get_our_state(), h11::State::Done); + + c.start_next_cycle().unwrap(); + assert_eq!( + get_all_events(&mut c).unwrap(), + vec![ + Event::Request(Request { + method: b"GET".to_vec(), + target: b"/2".to_vec(), + headers: vec![ + (b"Host".to_vec(), b"a.com".to_vec()), + (b"Content-Length".to_vec(), b"5".to_vec()) + ] + .into(), + http_version: b"1.1".to_vec(), + }), + Event::Data(Data { + data: b"67890".to_vec(), + chunk_start: false, + chunk_end: false, + }), + Event::EndOfMessage(EndOfMessage::default()), + Event::ConnectionClosed(ConnectionClosed::default()), + ], + ); + assert_eq!(c.get_their_state(), h11::State::Closed); + assert_eq!(c.get_our_state(), h11::State::SendResponse); + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + .unwrap(); + c.send(EndOfMessage::default().into()).unwrap(); + assert_eq!(c.get_their_state(), h11::State::Closed); + assert_eq!(c.get_our_state(), h11::State::MustClose); + c.send(ConnectionClosed::default().into()).unwrap(); + assert_eq!(c.get_their_state(), h11::State::Closed); + assert_eq!(c.get_our_state(), h11::State::Closed); +} + +// def test_errors() -> None: +// # After a receive error, you can't receive +// for role in [CLIENT, SERVER]: +// c = Connection(our_role=role) +// c.receive_data(b"gibberish\r\n\r\n") +// with pytest.raises(RemoteProtocolError): +// c.next_event() +// # Now any attempt to receive continues to raise +// assert c.their_state is ERROR +// assert c.our_state is not ERROR +// print(c._cstate.states) +// with pytest.raises(RemoteProtocolError): +// c.next_event() +// # But we can still yell at the client for sending us gibberish +// if role is SERVER: +// assert ( +// c.send(Response(status_code=400, headers=[])) # type: ignore[arg-type] +// == b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n" +// ) + +// # After an error sending, you can no longer send +// # (This is especially important for things like content-length errors, +// # where there's complex internal state being modified) +// def conn(role: Type[Sentinel]) -> Connection: +// c = Connection(our_role=role) +// if role is SERVER: +// # Put it into the state where it *could* send a response... +// receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") +// assert c.our_state is SEND_RESPONSE +// return c + +// for role in [CLIENT, SERVER]: +// if role is CLIENT: +// # This HTTP/1.0 request won't be detected as bad until after we go +// # through the state machine and hit the writing code +// good = Request(method="GET", target="/", headers=[("Host", "example.com")]) +// bad = Request( +// method="GET", +// target="/", +// headers=[("Host", "example.com")], +// http_version="1.0", +// ) +// elif role is SERVER: +// good = Response(status_code=200, headers=[]) # type: ignore[arg-type,assignment] +// bad = Response(status_code=200, headers=[], http_version="1.0") # type: ignore[arg-type,assignment] +// # Make sure 'good' actually is good +// c = conn(role) +// c.send(good) +// assert c.our_state is not ERROR +// # Do that again, but this time sending 'bad' first +// c = conn(role) +// with pytest.raises(LocalProtocolError): +// c.send(bad) +// assert c.our_state is ERROR +// assert c.their_state is not ERROR +// # Now 'good' is not so good +// with pytest.raises(LocalProtocolError): +// c.send(good) + +// # And check send_failed() too +// c = conn(role) +// c.send_failed() +// assert c.our_state is ERROR +// assert c.their_state is not ERROR +// # This is idempotent +// c.send_failed() +// assert c.our_state is ERROR +// assert c.their_state is not ERROR + +#[test] +fn test_errors() { + // After a receive error, you can't receive + for role in vec![Role::Client, Role::Server] { + let mut c = Connection::new(role, None); + c.receive_data(b"gibberish\r\n\r\n").unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + // Now any attempt to receive continues to raise + assert_eq!(c.get_their_state(), h11::State::Error); + assert_ne!(c.get_our_state(), h11::State::Error); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + // But we can still yell at the client for sending us gibberish + if role == Role::Server { + assert_eq!( + c.send( + Response { + status_code: 400, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into() + ) + .unwrap() + .unwrap(), + b"HTTP/1.1 400 \r\nconnection: close\r\n\r\n".to_vec() + ); + } + } + + // After an error sending, you can no longer send + // (This is especially important for things like content-length errors, + // where there's complex internal state being modified) + let conn = |role: Role| -> Connection { + let mut c = Connection::new(role, None); + if role == Role::Server { + // Put it into the state where it *could* send a response... + receive_and_get( + &mut c, + &b"GET / HTTP/1.0\r\n\r\n" + .to_vec() + .into_iter() + .collect::>(), + ) + .unwrap(); + assert_eq!(c.get_our_state(), h11::State::SendResponse); + } + return c; + }; + + for role in vec![Role::Client, Role::Server] { + let (good, bad): (Event, Event) = if role == Role::Client { + // This HTTP/1.0 request won't be detected as bad until after we go + // through the state machine and hit the writing code + ( + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"example.com".to_vec())].into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"example.com".to_vec())].into(), + b"/".to_vec(), + b"1.0".to_vec(), + ) + .unwrap() + .into(), + ) + } else { + ( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into(), + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.0".to_vec(), + reason: b"".to_vec(), + } + .into(), + ) + }; + // Make sure 'good' actually is good + let mut c = conn(role); + c.send(good.clone()).unwrap(); + assert_ne!(c.get_our_state(), h11::State::Error); + // Do that again, but this time sending 'bad' first + let mut c = conn(role); + assert!(match c.send(bad.clone()) { + Err(ProtocolError::LocalProtocolError(_)) => true, + _ => false, + }); + assert_eq!(c.get_our_state(), h11::State::Error); + assert_ne!(c.get_their_state(), h11::State::Error); + // Now 'good' is not so good + assert!(match c.send(good.clone()) { + Err(ProtocolError::LocalProtocolError(_)) => true, + _ => false, + }); + + // And check send_failed() too + let mut c = conn(role); + c.send_failed(); + assert_eq!(c.get_our_state(), h11::State::Error); + assert_ne!(c.get_their_state(), h11::State::Error); + // This is idempotent + c.send_failed(); + assert_eq!(c.get_our_state(), h11::State::Error); + assert_ne!(c.get_their_state(), h11::State::Error); + } +} + +// def test_idle_receive_nothing() -> None: +// # At one point this incorrectly raised an error +// for role in [CLIENT, SERVER]: +// c = Connection(role) +// assert c.next_event() is NEED_DATA + +#[test] +fn test_idle_receive_nothing() { + // At one point this incorrectly raised an error + for role in vec![Role::Client, Role::Server] { + let mut c = Connection::new(role, None); + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + } +} + +// def test_connection_drop() -> None: +// c = Connection(SERVER) +// c.receive_data(b"GET /") +// assert c.next_event() is NEED_DATA +// c.receive_data(b"") +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +#[test] +fn test_connection_drop() { + let mut c = Connection::new(Role::Server, None); + c.receive_data(b"GET /").unwrap(); + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + c.receive_data(b"").unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); +} + +// def test_408_request_timeout() -> None: +// # Should be able to send this spontaneously as a server without seeing +// # anything from client +// p = ConnectionPair() +// p.send(SERVER, Response(status_code=408, headers=[(b"connection", b"close")])) + +#[test] +fn test_408_request_timeout() { + // Should be able to send this spontaneously as a server without seeing + // anything from client + let mut p = ConnectionPair::new(); + p.send( + Role::Server, + vec![Response { + status_code: 408, + headers: vec![(b"connection".to_vec(), b"close".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into()], + None, + ) + .unwrap(); +} + +// # This used to raise IndexError +// def test_empty_request() -> None: +// c = Connection(SERVER) +// c.receive_data(b"\r\n") +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +#[test] +fn test_empty_request() { + let mut c = Connection::new(Role::Server, None); + c.receive_data(b"\r\n").unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); +} + +// # This used to raise IndexError +// def test_empty_response() -> None: +// c = Connection(CLIENT) +// c.send(Request(method="GET", target="/", headers=[("Host", "a")])) +// c.receive_data(b"\r\n") +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +#[test] +fn test_empty_response() { + let mut c = Connection::new(Role::Client, None); + c.send( + Request::new( + b"GET".to_vec(), + vec![(b"Host".to_vec(), b"a".to_vec())].into(), + b"/".to_vec(), + b"1.1".to_vec(), + ) + .unwrap() + .into(), + ) + .unwrap(); + c.receive_data(b"\r\n").unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); +} + +// @pytest.mark.parametrize( +// "data", +// [ +// b"\x00", +// b"\x20", +// b"\x16\x03\x01\x00\xa5", # Typical start of a TLS Client Hello +// ], +// ) +// def test_early_detection_of_invalid_request(data: bytes) -> None: +// c = Connection(SERVER) +// # Early detection should occur before even receiving a `\r\n` +// c.receive_data(data) +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +#[test] +fn test_early_detection_of_invalid_request() { + let data = vec![ + b"\x00".to_vec(), + b"\x20".to_vec(), + b"\x16\x03\x01\x00\xa5".to_vec(), // Typical start of a TLS Client Hello + ]; + for data in data { + let mut c = Connection::new(Role::Server, None); + // Early detection should occur before even receiving a `\r\n` + c.receive_data(&data).unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + } +} + +// @pytest.mark.parametrize( +// "data", +// [ +// b"\x00", +// b"\x20", +// b"\x16\x03\x03\x00\x31", # Typical start of a TLS Server Hello +// ], +// ) +// def test_early_detection_of_invalid_response(data: bytes) -> None: +// c = Connection(CLIENT) +// # Early detection should occur before even receiving a `\r\n` +// c.receive_data(data) +// with pytest.raises(RemoteProtocolError): +// c.next_event() + +#[test] +fn test_early_detection_of_invalid_response() { + let data = vec![ + b"\x00".to_vec(), + b"\x20".to_vec(), + b"\x16\x03\x03\x00\x31".to_vec(), // Typical start of a TLS Server Hello + ]; + for data in data { + let mut c = Connection::new(Role::Client, None); + // Early detection should occur before even receiving a `\r\n` + c.receive_data(&data).unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + } +} + +// # This used to give different headers for HEAD and GET. +// # The correct way to handle HEAD is to put whatever headers we *would* have +// # put if it were a GET -- even though we know that for HEAD, those headers +// # will be ignored. +// def test_HEAD_framing_headers() -> None: +// def setup(method: bytes, http_version: bytes) -> Connection: +// c = Connection(SERVER) +// c.receive_data( +// method + b" / HTTP/" + http_version + b"\r\n" + b"Host: example.com\r\n\r\n" +// ) +// assert type(c.next_event()) is Request +// assert type(c.next_event()) is EndOfMessage +// return c + +// for method in [b"GET", b"HEAD"]: +// # No Content-Length, HTTP/1.1 peer, should use chunked +// c = setup(method, b"1.1") +// assert ( +// c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" # type: ignore[arg-type] +// b"Transfer-Encoding: chunked\r\n\r\n" +// ) + +// # No Content-Length, HTTP/1.0 peer, frame with connection: close +// c = setup(method, b"1.0") +// assert ( +// c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" # type: ignore[arg-type] +// b"Connection: close\r\n\r\n" +// ) + +// # Content-Length + Transfer-Encoding, TE wins +// c = setup(method, b"1.1") +// assert ( +// c.send( +// Response( +// status_code=200, +// headers=[ +// ("Content-Length", "100"), +// ("Transfer-Encoding", "chunked"), +// ], +// ) +// ) +// == b"HTTP/1.1 200 \r\n" +// b"Transfer-Encoding: chunked\r\n\r\n" +// ) + +#[test] +fn test_head_framing_headers() { + let setup = |method: &[u8], http_version: &[u8]| -> Connection { + let mut c = Connection::new(Role::Server, None); + c.receive_data( + &vec![ + method.to_vec(), + b" / HTTP/".to_vec(), + http_version.to_vec(), + b"\r\n".to_vec(), + b"Host: example.com\r\n\r\n".to_vec(), + ] + .into_iter() + .flatten() + .collect::>(), + ) + .unwrap(); + assert!(match c.next_event().unwrap() { + Event::Request(_) => true, + _ => false, + }); + assert!(match c.next_event().unwrap() { + Event::EndOfMessage(_) => true, + _ => false, + }); + return c; + }; + + for method in vec![b"GET".to_vec(), b"HEAD".to_vec()] { + // No Content-Length, HTTP/1.1 peer, should use chunked + let mut c = setup(&method, &b"1.1".to_vec()); + assert_eq!( + c.send( + Response { + status_code: 200, + headers: vec![].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into() + ) + .unwrap() + .unwrap(), + b"HTTP/1.1 200 \r\ntransfer-encoding: chunked\r\n\r\n".to_vec() + ); + + // No Content-Length, HTTP/1.0 peer, frame with connection: close + let mut c = setup(&method, &b"1.0".to_vec()); + assert_eq!( + c.send( + Response { + status_code: 200, + headers: vec![(b"connection".to_vec(), b"close".to_vec())].into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into() + ) + .unwrap() + .unwrap(), + b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n".to_vec() + ); + + // Content-Length + Transfer-Encoding, TE wins + let mut c = setup(&method, &b"1.1".to_vec()); + assert_eq!( + c.send( + Response { + status_code: 200, + headers: vec![ + (b"Content-Length".to_vec(), b"100".to_vec()), + (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()) + ] + .into(), + http_version: b"1.1".to_vec(), + reason: b"".to_vec(), + } + .into() + ) + .unwrap() + .unwrap(), + b"HTTP/1.1 200 \r\ntransfer-encoding: chunked\r\n\r\n".to_vec() + ); + } +} + +// def test_special_exceptions_for_lost_connection_in_message_body() -> None: +// c = Connection(SERVER) +// c.receive_data( +// b"POST / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 100\r\n\r\n" +// ) +// assert type(c.next_event()) is Request +// assert c.next_event() is NEED_DATA +// c.receive_data(b"12345") +// assert c.next_event() == Data(data=b"12345") +// c.receive_data(b"") +// with pytest.raises(RemoteProtocolError) as excinfo: +// c.next_event() +// assert "received 5 bytes" in str(excinfo.value) +// assert "expected 100" in str(excinfo.value) + +// c = Connection(SERVER) +// c.receive_data( +// b"POST / HTTP/1.1\r\n" +// b"Host: example.com\r\n" +// b"Transfer-Encoding: chunked\r\n\r\n" +// ) +// assert type(c.next_event()) is Request +// assert c.next_event() is NEED_DATA +// c.receive_data(b"8\r\n012345") +// assert c.next_event().data == b"012345" # type: ignore +// c.receive_data(b"") +// with pytest.raises(RemoteProtocolError) as excinfo: +// c.next_event() +// assert "incomplete chunked read" in str(excinfo.value) + +#[test] +fn test_special_exceptions_for_lost_connection_in_message_body() { + let mut c = Connection::new(Role::Server, None); + c.receive_data( + &vec![ + b"POST / HTTP/1.1\r\n".to_vec(), + b"Host: example.com\r\n".to_vec(), + b"Content-Length: 100\r\n\r\n".to_vec(), + ] + .into_iter() + .flatten() + .collect::>(), + ) + .unwrap(); + assert!(match c.next_event().unwrap() { + Event::Request(_) => true, + _ => false, + }); + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + c.receive_data(b"12345").unwrap(); + assert_eq!( + c.next_event().unwrap(), + Event::Data(Data { + data: b"12345".to_vec(), + chunk_start: false, + chunk_end: false, + }) + ); + c.receive_data(b"").unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); + + let mut c = Connection::new(Role::Server, None); + c.receive_data( + &vec![ + b"POST / HTTP/1.1\r\n".to_vec(), + b"Host: example.com\r\n".to_vec(), + b"Transfer-Encoding: chunked\r\n\r\n".to_vec(), + ] + .into_iter() + .flatten() + .collect::>(), + ) + .unwrap(); + assert!(match c.next_event().unwrap() { + Event::Request(_) => true, + _ => false, + }); + assert_eq!(c.next_event().unwrap(), Event::NeedData {}); + c.receive_data(b"8\r\n012345").unwrap(); + assert_eq!( + match c.next_event().unwrap() { + Event::Data(d) => d.data, + _ => panic!(), + }, + b"012345".to_vec() + ); + c.receive_data(b"").unwrap(); + assert!(match c.next_event().unwrap_err() { + ProtocolError::RemoteProtocolError(_) => true, + _ => false, + }); +} diff --git a/tests/helper.rs b/tests/helper.rs new file mode 100644 index 0000000..97cc68f --- /dev/null +++ b/tests/helper.rs @@ -0,0 +1,201 @@ +// from typing import cast, List, Type, Union, ValuesView + +// from .._connection import Connection, NEED_DATA, PAUSED +// from .._events import ( +// ConnectionClosed, +// Data, +// EndOfMessage, +// Event, +// InformationalResponse, +// Request, +// Response, +// ) +// from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER +// from .._util import Sentinel + +// try: +// from typing import Literal +// except ImportError: +// from typing_extensions import Literal # type: ignore + +// def get_all_events(conn: Connection) -> List[Event]: +// got_events = [] +// while True: +// event = conn.next_event() +// if event in (NEED_DATA, PAUSED): +// break +// event = cast(Event, event) +// got_events.append(event) +// if type(event) is ConnectionClosed: +// break +// return got_events + +// def receive_and_get(conn: Connection, data: bytes) -> List[Event]: +// conn.receive_data(data) +// return get_all_events(conn) + +// # Merges adjacent Data events, converts payloads to bytestrings, and removes +// # chunk boundaries. +// def normalize_data_events(in_events: List[Event]) -> List[Event]: +// out_events: List[Event] = [] +// for event in in_events: +// if type(event) is Data: +// event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False) +// if out_events and type(out_events[-1]) is type(event) is Data: +// out_events[-1] = Data( +// data=out_events[-1].data + event.data, +// chunk_start=out_events[-1].chunk_start, +// chunk_end=out_events[-1].chunk_end, +// ) +// else: +// out_events.append(event) +// return out_events + +// # Given that we want to write tests that push some events through a Connection +// # and check that its state updates appropriately... we might as make a habit +// # of pushing them through two Connections with a fake network link in +// # between. +// class ConnectionPair: +// def __init__(self) -> None: +// self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)} +// self.other = {CLIENT: SERVER, SERVER: CLIENT} + +// @property +// def conns(self) -> ValuesView[Connection]: +// return self.conn.values() + +// # expect="match" if expect=send_events; expect=[...] to say what expected +// def send( +// self, +// role: Type[Sentinel], +// send_events: Union[List[Event], Event], +// expect: Union[List[Event], Event, Literal["match"]] = "match", +// ) -> bytes: +// if not isinstance(send_events, list): +// send_events = [send_events] +// data = b"" +// closed = False +// for send_event in send_events: +// new_data = self.conn[role].send(send_event) +// if new_data is None: +// closed = True +// else: +// data += new_data +// # send uses b"" to mean b"", and None to mean closed +// # receive uses b"" to mean closed, and None to mean "try again" +// # so we have to translate between the two conventions +// if data: +// self.conn[self.other[role]].receive_data(data) +// if closed: +// self.conn[self.other[role]].receive_data(b"") +// got_events = get_all_events(self.conn[self.other[role]]) +// if expect == "match": +// expect = send_events +// if not isinstance(expect, list): +// expect = [expect] +// assert got_events == expect +// return data + +use h11::{Connection, Data, Event, EventType, ProtocolError, Role}; +use std::collections::HashMap; + +pub fn get_all_events(conn: &mut Connection) -> Result, ProtocolError> { + let mut got_events = Vec::new(); + loop { + let event = conn.next_event()?; + let event_type = EventType::from(&event); + if event_type == EventType::NeedData || event_type == EventType::Paused { + break; + } + got_events.push(event); + if event_type == EventType::ConnectionClosed { + break; + } + } + return Ok(got_events); +} + +pub fn receive_and_get(conn: &mut Connection, data: &[u8]) -> Result, ProtocolError> { + conn.receive_data(data).unwrap(); + return get_all_events(conn); +} + +pub fn normalize_data_events(in_events: Vec) -> Vec { + let mut out_events = Vec::new(); + for in_event in in_events { + let event = match in_event { + Event::Data(data) => Event::Data(Data { + data: data.data.clone(), + chunk_start: false, + chunk_end: false, + }), + _ => in_event.clone(), + }; + if !out_events.is_empty() { + let event_type = EventType::from(&event); + let last_event = out_events.last().unwrap(); + let last_event_type = EventType::from(last_event); + if last_event_type == event_type && event_type == EventType::Data { + let l = out_events.len(); + out_events[l - 1] = event; + continue; + } + } + out_events.push(event); + } + return out_events; +} + +pub struct ConnectionPair { + pub conn: HashMap, + pub other: HashMap, +} + +impl ConnectionPair { + pub fn new() -> Self { + Self { + conn: HashMap::from([ + (Role::Client, Connection::new(Role::Client, None)), + (Role::Server, Connection::new(Role::Server, None)), + ]), + other: HashMap::from([(Role::Client, Role::Server), (Role::Server, Role::Client)]), + } + } + + pub fn send( + &mut self, + role: Role, + send_events: Vec, + expect: Option>, + ) -> Result, ProtocolError> { + let mut data = Vec::new(); + let mut closed = false; + for send_event in &send_events { + match self.conn.get_mut(&role).unwrap().send(send_event.clone())? { + Some(new_data) => data.extend(new_data), + None => closed = true, + } + } + if !data.is_empty() { + self.conn + .get_mut(&self.other[&role]) + .unwrap() + .receive_data(&data) + .unwrap(); + } + if closed { + self.conn + .get_mut(&self.other[&role]) + .unwrap() + .receive_data(b"") + .unwrap(); + } + let got_events = get_all_events(self.conn.get_mut(&self.other[&role]).unwrap())?; + match expect { + Some(expect) => assert_eq!(got_events, expect), + None => assert_eq!(got_events, send_events), + }; + + Ok(data) + } +}