From 2b4c0468dec263dc8b5731f51160df8b9a8b5b0b Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Mon, 20 Jan 2025 14:06:21 +0100 Subject: [PATCH] Handle implicit resets at the right time A stream whose ref count reaches zero while open should not immediately decrease the number of active streams, otherwise MAX_CONCURRENT_STREAMS isn't respected anymore. --- src/client.rs | 2 +- src/proto/streams/prioritize.rs | 1 + src/proto/streams/recv.rs | 2 +- src/proto/streams/state.rs | 49 +++++----- tests/h2-tests/tests/stream_states.rs | 127 ++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 24 deletions(-) diff --git a/src/client.rs b/src/client.rs index ffeda6077..4b13e23c5 100644 --- a/src/client.rs +++ b/src/client.rs @@ -365,7 +365,7 @@ where /// /// [module]: index.html pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - ready!(self.inner.poll_pending_open(cx, self.pending.as_ref()))?; + ready!(self.inner.poll_pending_open(cx, dbg!(self.pending.as_ref())))?; self.pending = None; Poll::Ready(Ok(())) } diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 81825f404..08d2f94a3 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -839,6 +839,7 @@ impl Prioritize { }), None => { if let Some(reason) = stream.state.get_scheduled_reset() { + stream.state.did_schedule_reset(); stream.set_reset(reason, Initiator::Library); let frame = frame::Reset::new(stream.id, reason); diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index d8572d00a..6cdfbea48 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -914,7 +914,7 @@ impl Recv { tracing::trace!("enqueue_reset_expiration; {:?}", stream.id); - if counts.can_inc_num_reset_streams() { + if dbg!(counts.can_inc_num_reset_streams()) { counts.inc_num_reset_streams(); self.pending_reset_expired.push(stream); } diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 5256f09cf..9dea81aa4 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -58,9 +58,20 @@ enum Inner { // TODO: these states shouldn't count against concurrency limits: ReservedLocal, ReservedRemote, - Open { local: Peer, remote: Peer }, + Open { + local: Peer, + remote: Peer, + }, HalfClosedLocal(Peer), // TODO: explicitly name this value HalfClosedRemote(Peer), + /// This indicates to the connection that a reset frame must be sent out + /// once the send queue has been flushed. + /// + /// Examples of when this could happen: + /// - User drops all references to a stream, so we want to CANCEL the it. + /// - Header block size was too large, so we want to REFUSE, possibly + /// after sending a 431 response frame. + ScheduledReset(Reason), Closed(Cause), } @@ -75,15 +86,6 @@ enum Peer { enum Cause { EndStream, Error(Error), - - /// This indicates to the connection that a reset frame must be sent out - /// once the send queue has been flushed. - /// - /// Examples of when this could happen: - /// - User drops all references to a stream, so we want to CANCEL the it. - /// - Header block size was too large, so we want to REFUSE, possibly - /// after sending a 431 response frame. - ScheduledLibraryReset(Reason), } impl State { @@ -339,24 +341,29 @@ impl State { /// Set the stream state to a scheduled reset. pub fn set_scheduled_reset(&mut self, reason: Reason) { debug_assert!(!self.is_closed()); - self.inner = Closed(Cause::ScheduledLibraryReset(reason)); + self.inner = ScheduledReset(reason) } pub fn get_scheduled_reset(&self) -> Option { match self.inner { - Closed(Cause::ScheduledLibraryReset(reason)) => Some(reason), + ScheduledReset(reason) => Some(reason), _ => None, } } pub fn is_scheduled_reset(&self) -> bool { - matches!(self.inner, Closed(Cause::ScheduledLibraryReset(..))) + matches!(self.inner, ScheduledReset(_)) + } + + pub fn did_schedule_reset(&mut self) { + debug_assert!(self.is_scheduled_reset()); + self.inner = Closed(Cause::EndStream); } pub fn is_local_error(&self) -> bool { match self.inner { + ScheduledReset(_) => true, Closed(Cause::Error(ref e)) => e.is_local(), - Closed(Cause::ScheduledLibraryReset(..)) => true, _ => false, } } @@ -416,14 +423,14 @@ impl State { pub fn is_recv_closed(&self) -> bool { matches!( self.inner, - Closed(..) | HalfClosedRemote(..) | ReservedLocal + ScheduledReset(_) | Closed(..) | HalfClosedRemote(..) | ReservedLocal ) } pub fn is_send_closed(&self) -> bool { matches!( self.inner, - Closed(..) | HalfClosedLocal(..) | ReservedRemote + ScheduledReset(_) | Closed(..) | HalfClosedLocal(..) | ReservedRemote ) } @@ -434,10 +441,8 @@ impl State { pub fn ensure_recv_open(&self) -> Result { // TODO: Is this correct? match self.inner { + ScheduledReset(reason) => Err(proto::Error::library_go_away(reason)), Closed(Cause::Error(ref e)) => Err(e.clone()), - Closed(Cause::ScheduledLibraryReset(reason)) => { - Err(proto::Error::library_go_away(reason)) - } Closed(Cause::EndStream) | HalfClosedRemote(..) | ReservedLocal => Ok(false), _ => Ok(true), } @@ -446,9 +451,9 @@ impl State { /// Returns a reason if the stream has been reset. pub(super) fn ensure_reason(&self, mode: PollReset) -> Result, crate::Error> { match self.inner { - Closed(Cause::Error(Error::Reset(_, reason, _))) - | Closed(Cause::Error(Error::GoAway(_, reason, _))) - | Closed(Cause::ScheduledLibraryReset(reason)) => Ok(Some(reason)), + ScheduledReset(reason) + | Closed(Cause::Error(Error::Reset(_, reason, _))) + | Closed(Cause::Error(Error::GoAway(_, reason, _))) => Ok(Some(reason)), Closed(Cause::Error(ref e)) => Err(e.clone().into()), Open { local: Streaming, .. diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index facd367e8..04b9d1aa3 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1218,3 +1218,130 @@ async fn reset_new_stream_before_send() { join(srv, client).await; } + +#[tokio::test] +async fn explicit_reset_with_max_concurrent_stream() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let mock = async move { + let settings = srv.assert_client_handshake_with_settings(frames::settings().max_concurrent_streams(1)).await; + assert_default_settings!(settings); + + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + + srv.recv_frame(frames::reset(1).cancel()).await; + + srv.recv_frame(frames::headers(3).request("POST", "https://www.example.com/").eos()) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, mut stream) = client.send_request(request, false).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + stream.send_reset(Reason::CANCEL); + }; + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, _) = client.send_request(request, true).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + }; + + h2.await.unwrap(); + }; + + join(mock, h2).await; +} + +#[tokio::test] +async fn implicit_cancel_with_max_concurrent_stream() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let mock = async move { + let settings = srv.assert_client_handshake_with_settings(frames::settings().max_concurrent_streams(1)).await; + assert_default_settings!(settings); + + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + + srv.recv_frame(frames::reset(1).cancel()).await; + + srv.recv_frame(frames::headers(3).request("POST", "https://www.example.com/").eos()) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, stream) = client.send_request(request, false).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + // This implicitly resets the stream with CANCEL. + drop(stream); + }; + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, _) = client.send_request(request, true).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + }; + + h2.await.unwrap(); + }; + + join(mock, h2).await; +}