Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: limit number of CONTINUATION frames allowed #758

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/codec/framed_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub struct FramedRead<T> {

max_header_list_size: usize,

max_continuation_frames: usize,

partial: Option<Partial>,
}

Expand All @@ -41,6 +43,8 @@ struct Partial {

/// Partial header payload
buf: BytesMut,

continuation_frames_count: usize,
}

#[derive(Debug)]
Expand All @@ -51,10 +55,14 @@ enum Continuable {

impl<T> FramedRead<T> {
pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE;
let max_continuation_frames =
calc_max_continuation_frames(max_header_list_size, inner.decoder().max_frame_length());
FramedRead {
inner,
hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
max_header_list_size,
max_continuation_frames,
partial: None,
}
}
Expand All @@ -68,7 +76,6 @@ impl<T> FramedRead<T> {
}

/// Returns the current max frame size setting
#[cfg(feature = "unstable")]
#[inline]
pub fn max_frame_size(&self) -> usize {
self.inner.decoder().max_frame_length()
Expand All @@ -80,13 +87,17 @@ impl<T> FramedRead<T> {
#[inline]
pub fn set_max_frame_size(&mut self, val: usize) {
assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
self.inner.decoder_mut().set_max_frame_length(val)
self.inner.decoder_mut().set_max_frame_length(val);
// Update max CONTINUATION frames too, since its based on this
self.max_continuation_frames = calc_max_continuation_frames(self.max_header_list_size, val);
}

/// Update the max header list size setting.
#[inline]
pub fn set_max_header_list_size(&mut self, val: usize) {
self.max_header_list_size = val;
// Update max CONTINUATION frames too, since its based on this
self.max_continuation_frames = calc_max_continuation_frames(val, self.max_frame_size());
}

/// Update the header table size setting.
Expand All @@ -96,12 +107,22 @@ impl<T> FramedRead<T> {
}
}

fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize {
// At least this many frames needed to use max header list size
let min_frames_for_list = (header_max / frame_max).max(1);
// Some padding for imperfectly packed frames
// 25% without floats
let padding = min_frames_for_list >> 2;
min_frames_for_list.saturating_add(padding).max(5)
}

/// Decodes a frame.
///
/// This method is intentionally de-generified and outlined because it is very large.
fn decode_frame(
hpack: &mut hpack::Decoder,
max_header_list_size: usize,
max_continuation_frames: usize,
partial_inout: &mut Option<Partial>,
mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
Expand Down Expand Up @@ -169,6 +190,7 @@ fn decode_frame(
*partial_inout = Some(Partial {
frame: Continuable::$frame(frame),
buf: payload,
continuation_frames_count: 0,
});

return Ok(None);
Expand Down Expand Up @@ -273,6 +295,22 @@ fn decode_frame(
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}

// Check for CONTINUATION flood
if is_end_headers {
partial.continuation_frames_count = 0;
} else {
let cnt = partial.continuation_frames_count + 1;
if cnt > max_continuation_frames {
tracing::debug!("too_many_continuations, max = {}", max_continuation_frames);
return Err(Error::library_go_away_data(
Reason::ENHANCE_YOUR_CALM,
"too_many_continuations",
));
} else {
partial.continuation_frames_count = cnt;
}
}

// Extend the buf
if partial.buf.is_empty() {
partial.buf = bytes.split_off(frame::HEADER_LEN);
Expand Down Expand Up @@ -354,9 +392,16 @@ where
ref mut hpack,
max_header_list_size,
ref mut partial,
max_continuation_frames,
..
} = *self;
if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? {
if let Some(frame) = decode_frame(
hpack,
max_header_list_size,
max_continuation_frames,
partial,
bytes,
)? {
tracing::debug!(?frame, "received");
return Poll::Ready(Some(Ok(frame)));
}
Expand Down
49 changes: 49 additions & 0 deletions tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,55 @@ async fn too_big_headers_sends_reset_after_431_if_not_eos() {
join(client, srv).await;
}

#[tokio::test]
async fn too_many_continuation_frames_sends_goaway() {
h2_support::trace_init!();
let (io, mut client) = mock::new();

let client = async move {
let settings = client.assert_server_handshake().await;
assert_frame_eq(settings, frames::settings().max_header_list_size(1024 * 32));

// the mock impl automatically splits into CONTINUATION frames if the
// headers are too big for one frame. So without a max header list size
// set, we'll send a bunch of headers that will eventually get nuked.
client
.send_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.field("a".repeat(10_000), "b".repeat(10_000))
.field("c".repeat(10_000), "d".repeat(10_000))
.field("e".repeat(10_000), "f".repeat(10_000))
.field("g".repeat(10_000), "h".repeat(10_000))
.field("i".repeat(10_000), "j".repeat(10_000))
.field("k".repeat(10_000), "l".repeat(10_000))
.field("m".repeat(10_000), "n".repeat(10_000))
.field("o".repeat(10_000), "p".repeat(10_000))
.field("y".repeat(10_000), "z".repeat(10_000)),
)
.await;
client
.recv_frame(frames::go_away(0).calm().data("too_many_continuations"))
.await;
};

let srv = async move {
let mut srv = server::Builder::new()
// should mean ~3 continuation
.max_header_list_size(1024 * 32)
.handshake::<_, Bytes>(io)
.await
.expect("handshake");

let err = srv.next().await.unwrap().expect_err("server");
assert!(err.is_go_away());
assert!(err.is_library());
assert_eq!(err.reason(), Some(Reason::ENHANCE_YOUR_CALM));
};

join(client, srv).await;
}

#[tokio::test]
async fn pending_accept_recv_illegal_content_length_data() {
h2_support::trace_init!();
Expand Down