Skip to content

Commit

Permalink
Fix opening new streams over max concurrent (hyperium#707)
Browse files Browse the repository at this point in the history
There was a bug where opening streams over the max concurrent streams
was possible if max_concurrent_streams were lowered beyond the current
number of open streams and there were already new streams adding to the
pending_send queue.

There was two mechanisms for streams to end up in that queue.
1. send_headers would push directly onto pending_send when below
   max_concurrent_streams
2. prioritize would pop from pending_open until max_concurrent_streams
   was reached.

For case 1, a settings frame could be received after pushing many
streams onto pending_send and before the socket was ready to write
again. For case 2, the pending_send queue could have Headers frames
queued going into a Not Ready state with the socket, a settings frame
could be received, and then the headers would be written anyway after
the ack.

The fix is therefore also two fold. Fixing case 1 is as simple as
letting Prioritize decide when to transition streams from `pending_open`
to `pending_send` since only it knows the readiness of the socket and
whether the headers can be written immediately. This is slightly
complicated by the fact that previously SendRequest would block when
streams would be added as "pending open". That was addressed by
guessing when to block based on max concurrent streams rather than the
stream state.

The fix for Prioritize was to conservatively pop streams from
pending_open when the socket is immediately available for writing a
headers frame. This required a change to queuing to support pushing on
the front of pending_send to ensure headers frames don't linger in
pending_send.

The alternative to this was adding a check to pending_send whether a new
stream would exceed max concurrent. In that case, headers frames would
need to carefully be reenqueued. This seemed to impose more complexity
to ensure ordering of stream IDs would be maintained.

Closes hyperium#704
Closes hyperium#706 

Co-authored-by: Joe Wilm <[email protected]>
  • Loading branch information
2 people authored and 0xE282B0 committed Jan 11, 2024
1 parent c62cf49 commit ab75d02
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 26 deletions.
6 changes: 4 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,10 @@ where
self.inner
.send_request(request, end_of_stream, self.pending.as_ref())
.map_err(Into::into)
.map(|stream| {
if stream.is_pending_open() {
.map(|(stream, is_full)| {
if stream.is_pending_open() && is_full {
// Only prevent sending another request when the request queue
// is not full.
self.pending = Some(stream.clone_to_opaque());
}

Expand Down
8 changes: 8 additions & 0 deletions src/proto/streams/counts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ impl Counts {
}
}

/// Returns true when the next opened stream will reach capacity of outbound streams
///
/// The number of client send streams is incremented in prioritize; send_request has to guess if
/// it should wait before allowing another request to be sent.
pub fn next_send_stream_will_reach_capacity(&self) -> bool {
self.max_send_streams <= (self.num_send_streams + 1)
}

/// Returns the current peer
pub fn peer(&self) -> peer::Dyn {
self.peer
Expand Down
18 changes: 12 additions & 6 deletions src/proto/streams/prioritize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,9 @@ impl Prioritize {
tracing::trace!("poll_complete");

loop {
self.schedule_pending_open(store, counts);
if let Some(mut stream) = self.pop_pending_open(store, counts) {
self.pending_send.push_front(&mut stream);
}

match self.pop_frame(buffer, store, max_frame_len, counts) {
Some(frame) => {
Expand Down Expand Up @@ -874,20 +876,24 @@ impl Prioritize {
}
}

fn schedule_pending_open(&mut self, store: &mut Store, counts: &mut Counts) {
fn pop_pending_open<'s>(
&mut self,
store: &'s mut Store,
counts: &mut Counts,
) -> Option<store::Ptr<'s>> {
tracing::trace!("schedule_pending_open");
// check for any pending open streams
while counts.can_inc_num_send_streams() {
if counts.can_inc_num_send_streams() {
if let Some(mut stream) = self.pending_open.pop(store) {
tracing::trace!("schedule_pending_open; stream={:?}", stream.id);

counts.inc_num_send_streams(&mut stream);
self.pending_send.push(&mut stream);
stream.notify_send();
} else {
return;
return Some(stream);
}
}

None
}
}

Expand Down
25 changes: 15 additions & 10 deletions src/proto/streams/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,27 @@ impl Send {
// Update the state
stream.state.send_open(end_stream)?;

if counts.peer().is_local_init(frame.stream_id()) {
// If we're waiting on a PushPromise anyway
// handle potentially queueing the stream at that point
if !stream.is_pending_push {
if counts.can_inc_num_send_streams() {
counts.inc_num_send_streams(stream);
} else {
self.prioritize.queue_open(stream);
}
}
let mut pending_open = false;
if counts.peer().is_local_init(frame.stream_id()) && !stream.is_pending_push {
self.prioritize.queue_open(stream);
pending_open = true;
}

// Queue the frame for sending
//
// This call expects that, since new streams are in the open queue, new
// streams won't be pushed on pending_send.
self.prioritize
.queue_frame(frame.into(), buffer, stream, task);

// Need to notify the connection when pushing onto pending_open since
// queue_frame only notifies for pending_send.
if pending_open {
if let Some(task) = task.take() {
task.wake();
}
}

Ok(())
}

Expand Down
42 changes: 41 additions & 1 deletion src/proto/streams/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ where
///
/// If the stream is already contained by the list, return `false`.
pub fn push(&mut self, stream: &mut store::Ptr) -> bool {
tracing::trace!("Queue::push");
tracing::trace!("Queue::push_back");

if N::is_queued(stream) {
tracing::trace!(" -> already queued");
Expand Down Expand Up @@ -292,6 +292,46 @@ where
true
}

/// Queue the stream
///
/// If the stream is already contained by the list, return `false`.
pub fn push_front(&mut self, stream: &mut store::Ptr) -> bool {
tracing::trace!("Queue::push_front");

if N::is_queued(stream) {
tracing::trace!(" -> already queued");
return false;
}

N::set_queued(stream, true);

// The next pointer shouldn't be set
debug_assert!(N::next(stream).is_none());

// Queue the stream
match self.indices {
Some(ref mut idxs) => {
tracing::trace!(" -> existing entries");

// Update the provided stream to point to the head node
let head_key = stream.resolve(idxs.head).key();
N::set_next(stream, Some(head_key));

// Update the head pointer
idxs.head = stream.key();
}
None => {
tracing::trace!(" -> first entry");
self.indices = Some(store::Indices {
head: stream.key(),
tail: stream.key(),
});
}
}

true
}

pub fn pop<'a, R>(&mut self, store: &'a mut R) -> Option<store::Ptr<'a>>
where
R: Resolve,
Expand Down
14 changes: 9 additions & 5 deletions src/proto/streams/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ where
mut request: Request<()>,
end_of_stream: bool,
pending: Option<&OpaqueStreamRef>,
) -> Result<StreamRef<B>, SendError> {
) -> Result<(StreamRef<B>, bool), SendError> {
use super::stream::ContentLength;
use http::Method;

Expand Down Expand Up @@ -298,10 +298,14 @@ where
// the lock, so it can't.
me.refs += 1;

Ok(StreamRef {
opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream),
send_buffer: self.send_buffer.clone(),
})
let is_full = me.counts.next_send_stream_will_reach_capacity();
Ok((
StreamRef {
opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream),
send_buffer: self.send_buffer.clone(),
},
is_full,
))
}

pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
Expand Down
88 changes: 88 additions & 0 deletions tests/h2-tests/tests/client_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ async fn request_over_max_concurrent_streams_errors() {

// first request is allowed
let (resp1, mut stream1) = client.send_request(request, false).unwrap();
// as long as we let the connection internals tick
client = h2.drive(client.ready()).await.unwrap();

let request = Request::builder()
.method(Method::POST)
Expand Down Expand Up @@ -284,6 +286,90 @@ async fn request_over_max_concurrent_streams_errors() {
join(srv, h2).await;
}

#[tokio::test]
async fn recv_decrement_max_concurrent_streams_when_requests_queued() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();

let srv = async move {
let settings = srv.assert_client_handshake().await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("POST", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;

srv.ping_pong([0; 8]).await;

// limit this server later in life
srv.send_frame(frames::settings().max_concurrent_streams(1))
.await;
srv.recv_frame(frames::settings_ack()).await;
srv.recv_frame(
frames::headers(3)
.request("POST", "https://example.com/")
.eos(),
)
.await;
srv.ping_pong([1; 8]).await;
srv.send_frame(frames::headers(3).response(200).eos()).await;

srv.recv_frame(
frames::headers(5)
.request("POST", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(5).response(200).eos()).await;
};

let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.expect("handshake");
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::builder()
.method(Method::POST)
.uri("https://example.com/")
.body(())
.unwrap();
// first request is allowed
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();

let request = Request::builder()
.method(Method::POST)
.uri("https://example.com/")
.body(())
.unwrap();

// first request is allowed
let (resp1, _) = client.send_request(request, true).unwrap();

let request = Request::builder()
.method(Method::POST)
.uri("https://example.com/")
.body(())
.unwrap();

// second request is put into pending_open
let (resp2, _) = client.send_request(request, true).unwrap();

h2.drive(async move {
resp1.await.expect("req");
})
.await;
join(async move { h2.await.unwrap() }, async move {
resp2.await.unwrap()
})
.await;
};

join(srv, h2).await;
}

#[tokio::test]
async fn send_request_poll_ready_when_connection_error() {
h2_support::trace_init!();
Expand Down Expand Up @@ -336,6 +422,8 @@ async fn send_request_poll_ready_when_connection_error() {

// first request is allowed
let (resp1, _) = client.send_request(request, true).unwrap();
// as long as we let the connection internals tick
client = h2.drive(client.ready()).await.unwrap();

let request = Request::builder()
.method(Method::POST)
Expand Down
4 changes: 2 additions & 2 deletions tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ async fn push_request_against_concurrency() {
.await;
client.recv_frame(frames::data(2, &b""[..]).eos()).await;
client
.recv_frame(frames::headers(1).response(200).eos())
.recv_frame(frames::headers(4).response(200).eos())
.await;
client
.recv_frame(frames::headers(4).response(200).eos())
.recv_frame(frames::headers(1).response(200).eos())
.await;
};

Expand Down

0 comments on commit ab75d02

Please sign in to comment.