diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs index 44fd7ad794b8..890cf2aed347 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs @@ -938,7 +938,9 @@ private Task ProcessSettingsFrameAsync(in ReadOnlySequence payload) if (_clientSettings.MaxFrameSize != previousMaxFrameSize) { // Don't let the client choose an arbitrarily large size, this will be used for response buffers. - _frameWriter.UpdateMaxFrameSize(Math.Min(_clientSettings.MaxFrameSize, _serverSettings.MaxFrameSize)); + // Safe cast, MaxFrameSize is limited to 2^24-1 bytes by the protocol and by Http2PeerSettings. + // Ref: https://datatracker.ietf.org/doc/html/rfc7540#section-4.2 + _frameWriter.UpdateMaxFrameSize((int)Math.Min(_clientSettings.MaxFrameSize, _serverSettings.MaxFrameSize)); } // This difference can be negative. @@ -1829,4 +1831,4 @@ private static class GracefulCloseInitiator public const int Server = 1; public const int Client = 2; } -} \ No newline at end of file +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs index 802c68c4126f..b799ef02797c 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs @@ -29,6 +29,8 @@ internal sealed class Http2FrameWriter /// TODO (https://github.com/dotnet/aspnetcore/issues/51309): eliminate this limit. private const string MaximumFlowControlQueueSizeProperty = "Microsoft.AspNetCore.Server.Kestrel.Http2.MaxConnectionFlowControlQueueSize"; + private const int HeaderBufferSizeMultiplier = 2; + private static readonly int? AppContextMaximumFlowControlQueueSize = GetAppContextMaximumFlowControlQueueSize(); private static int? GetAppContextMaximumFlowControlQueueSize() @@ -71,8 +73,12 @@ internal sealed class Http2FrameWriter // This is only set to true by tests. private readonly bool _scheduleInline; - private uint _maxFrameSize = Http2PeerSettings.MinAllowedMaxFrameSize; + private int _maxFrameSize = Http2PeerSettings.MinAllowedMaxFrameSize; private byte[] _headerEncodingBuffer; + + // Keep track of the high-water mark of _headerEncodingBuffer's size so we don't have to grow + // through intermediate sizes repeatedly. + private int _headersEncodingLargeBufferSize = Http2PeerSettings.MinAllowedMaxFrameSize * HeaderBufferSizeMultiplier; private long _unflushedBytes; private bool _completed; @@ -110,7 +116,6 @@ public Http2FrameWriter( _headerEncodingBuffer = new byte[_maxFrameSize]; _scheduleInline = serviceContext.Scheduler == PipeScheduler.Inline; - _hpackEncoder = new DynamicHPackEncoder(serviceContext.ServerOptions.AllowResponseHeaderCompression); _maximumFlowControlQueueSize = AppContextMaximumFlowControlQueueSize is null @@ -367,12 +372,15 @@ public void UpdateMaxHeaderTableSize(uint maxHeaderTableSize) } } - public void UpdateMaxFrameSize(uint maxFrameSize) + public void UpdateMaxFrameSize(int maxFrameSize) { lock (_writeLock) { if (_maxFrameSize != maxFrameSize) { + // Safe multiply, MaxFrameSize is limited to 2^24-1 bytes by the protocol and by Http2PeerSettings. + // Ref: https://datatracker.ietf.org/doc/html/rfc7540#section-4.2 + _headersEncodingLargeBufferSize = int.Max(_headersEncodingLargeBufferSize, maxFrameSize * HeaderBufferSizeMultiplier); _maxFrameSize = maxFrameSize; _headerEncodingBuffer = new byte[_maxFrameSize]; } @@ -507,11 +515,12 @@ private void WriteResponseHeadersUnsynchronized(int streamId, int statusCode, Ht { try { + // In the case of the headers, there is always a status header to be returned, so BeginEncodeHeaders will not return BufferTooSmall. _headersEnumerator.Initialize(headers); _outgoingFrame.PrepareHeaders(headerFrameFlags, streamId); - var buffer = _headerEncodingBuffer.AsSpan(); - var done = HPackHeaderWriter.BeginEncodeHeaders(statusCode, _hpackEncoder, _headersEnumerator, buffer, out var payloadLength); - FinishWritingHeadersUnsynchronized(streamId, payloadLength, done); + var writeResult = HPackHeaderWriter.BeginEncodeHeaders(statusCode, _hpackEncoder, _headersEnumerator, _headerEncodingBuffer, out var payloadLength); + Debug.Assert(writeResult != HeaderWriteResult.BufferTooSmall, "This always writes the status as the first header, and it should never be an over the buffer size."); + FinishWritingHeadersUnsynchronized(streamId, payloadLength, writeResult); } // Any exception from the HPack encoder can leave the dynamic table in a corrupt state. // Since we allow custom header encoders we don't know what type of exceptions to expect. @@ -548,11 +557,11 @@ private ValueTask WriteDataAndTrailersAsync(Http2Stream stream, in try { - _headersEnumerator.Initialize(headers); + // In the case of the trailers, there is no status header to be written, so even the first call to BeginEncodeHeaders can return BufferTooSmall. _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_STREAM, streamId); - var buffer = _headerEncodingBuffer.AsSpan(); - var done = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out var payloadLength); - FinishWritingHeadersUnsynchronized(streamId, payloadLength, done); + _headersEnumerator.Initialize(headers); + var writeResult = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, _headerEncodingBuffer, out var payloadLength); + FinishWritingHeadersUnsynchronized(streamId, payloadLength, writeResult); } // Any exception from the HPack encoder can leave the dynamic table in a corrupt state. // Since we allow custom header encoders we don't know what type of exceptions to expect. @@ -566,32 +575,102 @@ private ValueTask WriteDataAndTrailersAsync(Http2Stream stream, in } } - private void FinishWritingHeadersUnsynchronized(int streamId, int payloadLength, bool done) + private void SplitHeaderAcrossFrames(int streamId, ReadOnlySpan dataToFrame, bool endOfHeaders, bool isFramePrepared) { - var buffer = _headerEncodingBuffer.AsSpan(); - _outgoingFrame.PayloadLength = payloadLength; - if (done) + var shouldPrepareFrame = !isFramePrepared; + while (dataToFrame.Length > 0) { - _outgoingFrame.HeadersFlags |= Http2HeadersFrameFlags.END_HEADERS; - } + if (shouldPrepareFrame) + { + _outgoingFrame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); + } - WriteHeaderUnsynchronized(); - _outputWriter.Write(buffer.Slice(0, payloadLength)); + // Should prepare continuation frames. + shouldPrepareFrame = true; + var currentSize = Math.Min(dataToFrame.Length, _maxFrameSize); + _outgoingFrame.PayloadLength = currentSize; + if (endOfHeaders && dataToFrame.Length == currentSize) + { + _outgoingFrame.HeadersFlags |= Http2HeadersFrameFlags.END_HEADERS; + } - while (!done) - { - _outgoingFrame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); + WriteHeaderUnsynchronized(); + _outputWriter.Write(dataToFrame[..currentSize]); + dataToFrame = dataToFrame.Slice(currentSize); + } + } - done = HPackHeaderWriter.ContinueEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out payloadLength); + private void FinishWritingHeadersUnsynchronized(int streamId, int payloadLength, HeaderWriteResult writeResult) + { + Debug.Assert(payloadLength <= _maxFrameSize, "The initial payload lengths is written to _headerEncodingBuffer with size of _maxFrameSize"); + byte[]? largeHeaderBuffer = null; + Span buffer; + if (writeResult == HeaderWriteResult.Done) + { + // Fast path, only a single HEADER frame. _outgoingFrame.PayloadLength = payloadLength; - - if (done) + _outgoingFrame.HeadersFlags |= Http2HeadersFrameFlags.END_HEADERS; + WriteHeaderUnsynchronized(); + _outputWriter.Write(_headerEncodingBuffer.AsSpan(0, payloadLength)); + return; + } + else if (writeResult == HeaderWriteResult.MoreHeaders) + { + _outgoingFrame.PayloadLength = payloadLength; + WriteHeaderUnsynchronized(); + _outputWriter.Write(_headerEncodingBuffer.AsSpan(0, payloadLength)); + } + else + { + // This may happen in case of the TRAILERS after the initial encode operation. + // The _maxFrameSize sized _headerEncodingBuffer was too small. + while (writeResult == HeaderWriteResult.BufferTooSmall) + { + Debug.Assert(payloadLength == 0, "Payload written even though buffer is too small"); + largeHeaderBuffer = ArrayPool.Shared.Rent(_headersEncodingLargeBufferSize); + buffer = largeHeaderBuffer.AsSpan(0, _headersEncodingLargeBufferSize); + writeResult = HPackHeaderWriter.RetryBeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out payloadLength); + if (writeResult != HeaderWriteResult.BufferTooSmall) + { + SplitHeaderAcrossFrames(streamId, buffer[..payloadLength], endOfHeaders: writeResult == HeaderWriteResult.Done, isFramePrepared: true); + } + else + { + _headersEncodingLargeBufferSize = checked(_headersEncodingLargeBufferSize * HeaderBufferSizeMultiplier); + } + ArrayPool.Shared.Return(largeHeaderBuffer); + largeHeaderBuffer = null; + } + if (writeResult == HeaderWriteResult.Done) { - _outgoingFrame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS; + return; } + } - WriteHeaderUnsynchronized(); - _outputWriter.Write(buffer.Slice(0, payloadLength)); + // HEADERS and zero or more CONTINUATIONS sent - all subsequent frames are (unprepared) CONTINUATIONs + buffer = _headerEncodingBuffer; + while (writeResult != HeaderWriteResult.Done) + { + writeResult = HPackHeaderWriter.ContinueEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out payloadLength); + if (writeResult == HeaderWriteResult.BufferTooSmall) + { + if (largeHeaderBuffer != null) + { + ArrayPool.Shared.Return(largeHeaderBuffer); + _headersEncodingLargeBufferSize = checked(_headersEncodingLargeBufferSize * HeaderBufferSizeMultiplier); + } + largeHeaderBuffer = ArrayPool.Shared.Rent(_headersEncodingLargeBufferSize); + buffer = largeHeaderBuffer.AsSpan(0, _headersEncodingLargeBufferSize); + } + else + { + // In case of Done or MoreHeaders: write to output. + SplitHeaderAcrossFrames(streamId, buffer[..payloadLength], endOfHeaders: writeResult == HeaderWriteResult.Done, isFramePrepared: false); + } + } + if (largeHeaderBuffer != null) + { + ArrayPool.Shared.Return(largeHeaderBuffer); } } @@ -1023,4 +1102,4 @@ private void EnqueueWaitingForMoreConnectionWindow(Http2OutputProducer producer) _http2Connection.Abort(new ConnectionAbortedException("HTTP/2 connection exceeded the outgoing flow control maximum queue size.")); } } -} \ No newline at end of file +} diff --git a/src/Servers/Kestrel/Core/test/HPackHeaderWriterTests.cs b/src/Servers/Kestrel/Core/test/HPackHeaderWriterTests.cs deleted file mode 100644 index cc6d85be4950..000000000000 --- a/src/Servers/Kestrel/Core/test/HPackHeaderWriterTests.cs +++ /dev/null @@ -1,199 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -//using System; -//using System.Collections.Generic; -//using System.Linq; -//using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; -//using Microsoft.Extensions.Primitives; -//using Xunit; - -//namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests -//{ -// public class HPackHeaderWriterTests -// { -// public static TheoryData[], byte[], int?> SinglePayloadData -// { -// get -// { -// var data = new TheoryData[], byte[], int?>(); - -// // Lowercase header name letters only -// data.Add( -// new[] -// { -// new KeyValuePair("CustomHeader", "CustomValue"), -// }, -// new byte[] -// { -// // 0 12 c u s t o m -// 0x00, 0x0c, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, -// // h e a d e r 11 C -// 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x0b, 0x43, -// // u s t o m V a l -// 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x56, 0x61, 0x6c, -// // u e -// 0x75, 0x65 -// }, -// null); -// // Lowercase header name letters only -// data.Add( -// new[] -// { -// new KeyValuePair("CustomHeader!#$%&'*+-.^_`|~", "CustomValue"), -// }, -// new byte[] -// { -// // 0 27 c u s t o m -// 0x00, 0x1b, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, -// // h e a d e r ! # -// 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x21, 0x23, -// // $ % & ' * + - . -// 0x24, 0x25, 0x26, 0x27, 0x2a, 0x2b, 0x2d, 0x2e, -// // ^ _ ` | ~ 11 C u -// 0x5e, 0x5f, 0x60, 0x7c, 0x7e, 0x0b, 0x43, 0x75, -// // s t o m V a l u -// 0x73, 0x74, 0x6f, 0x6d, 0x56, 0x61, 0x6c, 0x75, -// // e -// 0x65 -// }, -// null); -// // Single Payload -// data.Add( -// new[] -// { -// new KeyValuePair("date", "Mon, 24 Jul 2017 19:22:30 GMT"), -// new KeyValuePair("content-type", "text/html; charset=utf-8"), -// new KeyValuePair("server", "Kestrel") -// }, -// new byte[] -// { -// 0x88, 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x1d, -// 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x34, 0x20, -// 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x31, 0x37, -// 0x20, 0x31, 0x39, 0x3a, 0x32, 0x32, 0x3a, 0x33, -// 0x30, 0x20, 0x47, 0x4d, 0x54, 0x00, 0x0c, 0x63, -// 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x74, -// 0x79, 0x70, 0x65, 0x18, 0x74, 0x65, 0x78, 0x74, -// 0x2f, 0x68, 0x74, 0x6d, 0x6c, 0x3b, 0x20, 0x63, -// 0x68, 0x61, 0x72, 0x73, 0x65, 0x74, 0x3d, 0x75, -// 0x74, 0x66, 0x2d, 0x38, 0x00, 0x06, 0x73, 0x65, -// 0x72, 0x76, 0x65, 0x72, 0x07, 0x4b, 0x65, 0x73, -// 0x74, 0x72, 0x65, 0x6c -// }, -// 200); - -// return data; -// } -// } - -// [Theory] -// [MemberData(nameof(SinglePayloadData))] -// public void EncodesHeadersInSinglePayloadWhenSpaceAvailable(KeyValuePair[] headers, byte[] expectedPayload, int? statusCode) -// { -// var payload = new byte[1024]; -// var length = 0; -// if (statusCode.HasValue) -// { -// Assert.True(HPackHeaderWriter.BeginEncodeHeaders(statusCode.Value, GetHeadersEnumerator(headers), payload, out length)); -// } -// else -// { -// Assert.True(HPackHeaderWriter.BeginEncodeHeaders(GetHeadersEnumerator(headers), payload, out length)); -// } -// Assert.Equal(expectedPayload.Length, length); - -// for (var i = 0; i < length; i++) -// { -// Assert.True(expectedPayload[i] == payload[i], $"{expectedPayload[i]} != {payload[i]} at {i} (len {length})"); -// } - -// Assert.Equal(expectedPayload, new ArraySegment(payload, 0, length)); -// } - -// [Theory] -// [InlineData(true)] -// [InlineData(false)] -// public void EncodesHeadersInMultiplePayloadsWhenSpaceNotAvailable(bool exactSize) -// { -// var statusCode = 200; -// var headers = new[] -// { -// new KeyValuePair("date", "Mon, 24 Jul 2017 19:22:30 GMT"), -// new KeyValuePair("content-type", "text/html; charset=utf-8"), -// new KeyValuePair("server", "Kestrel") -// }; - -// var expectedStatusCodePayload = new byte[] -// { -// 0x88 -// }; - -// var expectedDateHeaderPayload = new byte[] -// { -// 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x1d, 0x4d, -// 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x34, 0x20, 0x4a, -// 0x75, 0x6c, 0x20, 0x32, 0x30, 0x31, 0x37, 0x20, -// 0x31, 0x39, 0x3a, 0x32, 0x32, 0x3a, 0x33, 0x30, -// 0x20, 0x47, 0x4d, 0x54 -// }; - -// var expectedContentTypeHeaderPayload = new byte[] -// { -// 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, -// 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x18, 0x74, -// 0x65, 0x78, 0x74, 0x2f, 0x68, 0x74, 0x6d, 0x6c, -// 0x3b, 0x20, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, -// 0x74, 0x3d, 0x75, 0x74, 0x66, 0x2d, 0x38 -// }; - -// var expectedServerHeaderPayload = new byte[] -// { -// 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, -// 0x07, 0x4b, 0x65, 0x73, 0x74, 0x72, 0x65, 0x6c -// }; - -// Span payload = new byte[1024]; -// var offset = 0; -// var headerEnumerator = GetHeadersEnumerator(headers); - -// // When !exactSize, slices are one byte short of fitting the next header -// var sliceLength = expectedStatusCodePayload.Length + (exactSize ? 0 : expectedDateHeaderPayload.Length - 1); -// Assert.False(HPackHeaderWriter.BeginEncodeHeaders(statusCode, headerEnumerator, payload.Slice(offset, sliceLength), out var length)); -// Assert.Equal(expectedStatusCodePayload.Length, length); -// Assert.Equal(expectedStatusCodePayload, payload.Slice(0, length).ToArray()); - -// offset += length; - -// sliceLength = expectedDateHeaderPayload.Length + (exactSize ? 0 : expectedContentTypeHeaderPayload.Length - 1); -// Assert.False(HPackHeaderWriter.ContinueEncodeHeaders(headerEnumerator, payload.Slice(offset, sliceLength), out length)); -// Assert.Equal(expectedDateHeaderPayload.Length, length); -// Assert.Equal(expectedDateHeaderPayload, payload.Slice(offset, length).ToArray()); - -// offset += length; - -// sliceLength = expectedContentTypeHeaderPayload.Length + (exactSize ? 0 : expectedServerHeaderPayload.Length - 1); -// Assert.False(HPackHeaderWriter.ContinueEncodeHeaders(headerEnumerator, payload.Slice(offset, sliceLength), out length)); -// Assert.Equal(expectedContentTypeHeaderPayload.Length, length); -// Assert.Equal(expectedContentTypeHeaderPayload, payload.Slice(offset, length).ToArray()); - -// offset += length; - -// sliceLength = expectedServerHeaderPayload.Length; -// Assert.True(HPackHeaderWriter.ContinueEncodeHeaders(headerEnumerator, payload.Slice(offset, sliceLength), out length)); -// Assert.Equal(expectedServerHeaderPayload.Length, length); -// Assert.Equal(expectedServerHeaderPayload, payload.Slice(offset, length).ToArray()); -// } - -// private static Http2HeadersEnumerator GetHeadersEnumerator(IEnumerable> headers) -// { -// var groupedHeaders = headers -// .GroupBy(k => k.Key) -// .ToDictionary(g => g.Key, g => new StringValues(g.Select(gg => gg.Value).ToArray())); - -// var enumerator = new Http2HeadersEnumerator(); -// enumerator.Initialize(groupedHeaders); -// return enumerator; -// } -// } -//} diff --git a/src/Servers/Kestrel/Core/test/Http2/Http2FrameWriterTests.cs b/src/Servers/Kestrel/Core/test/Http2/Http2FrameWriterTests.cs index cafd8a98a0f2..ea2ec9ba5609 100644 --- a/src/Servers/Kestrel/Core/test/Http2/Http2FrameWriterTests.cs +++ b/src/Servers/Kestrel/Core/test/Http2/Http2FrameWriterTests.cs @@ -1,17 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; using System.IO.Pipelines; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; using Moq; -using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; @@ -92,6 +86,13 @@ public async Task WriteHeader_UnsetsReservedBit() Assert.Equal(new byte[] { 0x00, 0x00, 0x00, 0x00 }, payload.Skip(5).Take(4).ToArray()); } + + [Fact] + public void UpdateMaxFrameSize_To_ProtocolMaximum() + { + var sut = CreateFrameWriter(new Pipe()); + sut.UpdateMaxFrameSize((int)Math.Pow(2, 24) - 1); + } } public static class PipeReaderExtensions diff --git a/src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs b/src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs index e65d33eaf0fd..2edf4628129e 100644 --- a/src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs +++ b/src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs @@ -1,19 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Linq; using System.Net.Http.HPack; using System.Text; -using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; using Microsoft.Extensions.Primitives; -using Microsoft.Net.Http.Headers; - -using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; @@ -29,7 +22,7 @@ public void BeginEncodeHeaders_Status302_NewIndexValue() enumerator.Initialize(headers); var hpackEncoder = new DynamicHPackEncoder(); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); var result = buffer.Slice(0, length).ToArray(); var hex = BitConverter.ToString(result); @@ -52,7 +45,7 @@ public void BeginEncodeHeaders_CacheControlPrivate_NewIndexValue() enumerator.Initialize(headers); var hpackEncoder = new DynamicHPackEncoder(); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); var result = buffer.Slice(5, length - 5).ToArray(); var hex = BitConverter.ToString(result); @@ -67,7 +60,6 @@ public void BeginEncodeHeaders_CacheControlPrivate_NewIndexValue() public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit() { // Test follows example https://tools.ietf.org/html/rfc7541#appendix-C.5 - Span buffer = new byte[1024 * 16]; var headers = (IHeaderDictionary)new HttpResponseHeaders(); @@ -81,7 +73,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit() // First response enumerator.Initialize(headers); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); var result = buffer.Slice(0, length).ToArray(); var hex = BitConverter.ToString(result); @@ -123,7 +115,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit() // Second response enumerator.Initialize(headers); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length)); result = buffer.Slice(0, length).ToArray(); hex = BitConverter.ToString(result); @@ -164,7 +156,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit() headers.SetCookie = "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"; enumerator.Initialize(headers); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length)); result = buffer.Slice(0, length).ToArray(); hex = BitConverter.ToString(result); @@ -225,7 +217,7 @@ public void BeginEncodeHeadersCustomEncoding_MaxHeaderTableSizeExceeded_Eviction // First response enumerator.Initialize((HttpResponseHeaders)headers); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length)); var result = buffer.Slice(0, length).ToArray(); var hex = BitConverter.ToString(result); @@ -267,7 +259,7 @@ public void BeginEncodeHeadersCustomEncoding_MaxHeaderTableSizeExceeded_Eviction // Second response enumerator.Initialize(headers); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length)); result = buffer.Slice(0, length).ToArray(); hex = BitConverter.ToString(result); @@ -308,7 +300,7 @@ public void BeginEncodeHeadersCustomEncoding_MaxHeaderTableSizeExceeded_Eviction headers.SetCookie = "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"; enumerator.Initialize(headers); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length)); result = buffer.Slice(0, length).ToArray(); hex = BitConverter.ToString(result); @@ -366,7 +358,7 @@ public void BeginEncodeHeaders_ExcludedHeaders_NotAddedToTable(string headerName enumerator.Initialize(headers); var hpackEncoder = new DynamicHPackEncoder(maxHeaderTableSize: Http2PeerSettings.DefaultHeaderTableSize); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out _)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out _)); if (neverIndex) { @@ -392,7 +384,7 @@ public void BeginEncodeHeaders_HeaderExceedHeaderTableSize_NoIndexAndNoHeaderEnt enumerator.Initialize(headers); var hpackEncoder = new DynamicHPackEncoder(); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length)); Assert.Empty(GetHeaderEntries(hpackEncoder)); } @@ -477,16 +469,15 @@ public void BeginEncodeHeaders_HeaderExceedHeaderTableSize_NoIndexAndNoHeaderEnt public void EncodesHeadersInSinglePayloadWhenSpaceAvailable(KeyValuePair[] headers, byte[] expectedPayload, int? statusCode) { var hpackEncoder = new DynamicHPackEncoder(); - var payload = new byte[1024]; var length = 0; if (statusCode.HasValue) { - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(statusCode.Value, hpackEncoder, GetHeadersEnumerator(headers), payload, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(statusCode.Value, hpackEncoder, GetHeadersEnumerator(headers), payload, out length)); } else { - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, GetHeadersEnumerator(headers), payload, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, GetHeadersEnumerator(headers), payload, out length)); } Assert.Equal(expectedPayload.Length, length); @@ -548,28 +539,28 @@ public void EncodesHeadersInMultiplePayloadsWhenSpaceNotAvailable(bool exactSize // When !exactSize, slices are one byte short of fitting the next header var sliceLength = expectedStatusCodePayload.Length + (exactSize ? 0 : expectedDateHeaderPayload.Length - 1); - Assert.False(HPackHeaderWriter.BeginEncodeHeaders(statusCode, hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out var length)); + Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(statusCode, hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out var length)); Assert.Equal(expectedStatusCodePayload.Length, length); Assert.Equal(expectedStatusCodePayload, payload.Slice(0, length).ToArray()); offset += length; sliceLength = expectedDateHeaderPayload.Length + (exactSize ? 0 : expectedContentTypeHeaderPayload.Length - 1); - Assert.False(HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length)); + Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length)); Assert.Equal(expectedDateHeaderPayload.Length, length); Assert.Equal(expectedDateHeaderPayload, payload.Slice(offset, length).ToArray()); offset += length; sliceLength = expectedContentTypeHeaderPayload.Length + (exactSize ? 0 : expectedServerHeaderPayload.Length - 1); - Assert.False(HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length)); + Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length)); Assert.Equal(expectedContentTypeHeaderPayload.Length, length); Assert.Equal(expectedContentTypeHeaderPayload, payload.Slice(offset, length).ToArray()); offset += length; sliceLength = expectedServerHeaderPayload.Length; - Assert.True(HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length)); Assert.Equal(expectedServerHeaderPayload.Length, length); Assert.Equal(expectedServerHeaderPayload, payload.Slice(offset, length).ToArray()); } @@ -586,7 +577,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeUpdated_SizeUpdateInHeaders() // First request enumerator.Initialize(new Dictionary()); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length)); Assert.Equal(2, length); @@ -600,11 +591,72 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeUpdated_SizeUpdateInHeaders() // Second request enumerator.Initialize(new Dictionary()); - Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out length)); + Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out length)); + + Assert.Equal(0, length); + } + + [Fact] + public void WithStatusCode_TooLargeHeader_ReturnsMoreHeaders() + { + Span buffer = new byte[1024 * 16]; + + IHeaderDictionary headers = new HttpResponseHeaders(); + headers.Cookie = new string('a', buffer.Length + 1); + var enumerator = new Http2HeadersEnumerator(); + enumerator.Initialize(headers); + + var hpackEncoder = new DynamicHPackEncoder(); + Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(1, length); + } + + [Fact] + public void NoStatusCodeLargeHeader_ReturnsOversized() + { + Span buffer = new byte[1024 * 16]; + + IHeaderDictionary headers = new HttpResponseHeaders(); + headers.Cookie = new string('a', buffer.Length + 1); + var enumerator = new Http2HeadersEnumerator(); + enumerator.Initialize(headers); + var hpackEncoder = new DynamicHPackEncoder(); + Assert.Equal(HeaderWriteResult.BufferTooSmall, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length)); Assert.Equal(0, length); } + [Fact] + public void WithStatusCode_JustFittingHeaderNoSpace_ReturnsMoreHeaders() + { + Span buffer = new byte[1024 * 16]; + + IHeaderDictionary headers = new HttpResponseHeaders(); + headers.Cookie = new string('a', buffer.Length - 1); + var enumerator = new Http2HeadersEnumerator(); + enumerator.Initialize(headers); + + var hpackEncoder = new DynamicHPackEncoder(); + Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(1, length); + } + + [Fact] + public void NoStatusCode_JustFittingHeaderNoSpace_ReturnsMoreHeaders() + { + Span buffer = new byte[1024 * 16]; + + IHeaderDictionary headers = new HttpResponseHeaders(); + headers.Accept = "application/json;"; + headers.Cookie = new string('a', buffer.Length - 1); + var enumerator = new Http2HeadersEnumerator(); + enumerator.Initialize(headers); + + var hpackEncoder = new DynamicHPackEncoder(); + Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length)); + Assert.Equal(26, length); + } + private static Http2HeadersEnumerator GetHeadersEnumerator(IEnumerable> headers) { var groupedHeaders = headers diff --git a/src/Servers/Kestrel/perf/Microbenchmarks/Http2/HPackHeaderWriterBenchmark.cs b/src/Servers/Kestrel/perf/Microbenchmarks/Http2/HPackHeaderWriterBenchmark.cs index cff79ef3dd96..cff2aa9bed9d 100644 --- a/src/Servers/Kestrel/perf/Microbenchmarks/Http2/HPackHeaderWriterBenchmark.cs +++ b/src/Servers/Kestrel/perf/Microbenchmarks/Http2/HPackHeaderWriterBenchmark.cs @@ -1,16 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Linq; using System.Net.Http.HPack; using System.Text; -using System.Threading.Tasks; using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; -using Microsoft.Net.Http.Headers; namespace Microsoft.AspNetCore.Server.Kestrel.Microbenchmarks; diff --git a/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs b/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs index 9fe40252d54f..03589ed07b77 100644 --- a/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs +++ b/src/Servers/Kestrel/perf/Microbenchmarks/Http2/Http2FrameWriterBenchmark.cs @@ -44,11 +44,26 @@ public void GlobalSetup() "TestConnectionId", _memoryPool, serviceContext); + } + + private int _largeHeaderSize; - _responseHeaders = new HttpResponseHeaders(); - var headers = (IHeaderDictionary)_responseHeaders; - headers.ContentType = "application/json"; - headers.ContentLength = 1024; + [Params(0, 10, 20)] + public int LargeHeaderSize + { + get => _largeHeaderSize; + set + { + _largeHeaderSize = value; + _responseHeaders = new HttpResponseHeaders(); + var headers = (IHeaderDictionary)_responseHeaders; + headers.ContentType = "application/json"; + headers.ContentLength = 1024; + if (value > 0) + { + headers.Add("my", new string('a', value * 1024)); + } + } } [Benchmark] diff --git a/src/Servers/Kestrel/shared/HPackHeaderWriter.cs b/src/Servers/Kestrel/shared/HPackHeaderWriter.cs index 22cf2256cd53..454346e86c70 100644 --- a/src/Servers/Kestrel/shared/HPackHeaderWriter.cs +++ b/src/Servers/Kestrel/shared/HPackHeaderWriter.cs @@ -10,6 +10,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; #endif +internal enum HeaderWriteResult : byte +{ + // Not all headers written. + MoreHeaders = 0, + + // All headers written. + Done = 1, + + // Oversized header for the given buffer. + BufferTooSmall = 2, +} + // This file is used by Kestrel to write response headers and tests to write request headers. // To avoid adding test code to Kestrel this file is shared. Test specifc code is excluded from Kestrel by ifdefs. internal static class HPackHeaderWriter @@ -17,7 +29,7 @@ internal static class HPackHeaderWriter /// /// Begin encoding headers in the first HEADERS frame. /// - public static bool BeginEncodeHeaders(int statusCode, DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) + public static HeaderWriteResult BeginEncodeHeaders(int statusCode, DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) { length = 0; @@ -35,12 +47,12 @@ public static bool BeginEncodeHeaders(int statusCode, DynamicHPackEncoder hpackE if (!headersEnumerator.MoveNext()) { - return true; + return HeaderWriteResult.Done; } - // We're ok with not throwing if no headers were encoded because we've already encoded the status. + // Since we've already encoded the status, we know we didn't start with an empty buffer. We don't need to increase it immediately because // There is a small chance that the header will encode if there is no other content in the next HEADERS frame. - var done = EncodeHeadersCore(hpackEncoder, headersEnumerator, buffer.Slice(length), throwIfNoneEncoded: false, out var headersLength); + var done = EncodeHeadersCore(hpackEncoder, headersEnumerator, buffer.Slice(length), canRequestLargerBuffer: false, out var headersLength); length += headersLength; return done; } @@ -48,7 +60,19 @@ public static bool BeginEncodeHeaders(int statusCode, DynamicHPackEncoder hpackE /// /// Begin encoding headers in the first HEADERS frame. /// - public static bool BeginEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) + public static HeaderWriteResult BeginEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) => + BeginEncodeHeaders(hpackEncoder, headersEnumerator, buffer, iterateBeforeFirstElement: true, out length); + + /// + /// Begin encoding headers in the first HEADERS frame without stepping the iterator. + /// + public static HeaderWriteResult RetryBeginEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) => + BeginEncodeHeaders(hpackEncoder, headersEnumerator, buffer, iterateBeforeFirstElement: false, out length); + + /// + /// Begin encoding headers in the first HEADERS frame. + /// + private static HeaderWriteResult BeginEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, bool iterateBeforeFirstElement, out int length) { length = 0; @@ -58,12 +82,12 @@ public static bool BeginEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2Hea } length += sizeUpdateLength; - if (!headersEnumerator.MoveNext()) + if (iterateBeforeFirstElement && !headersEnumerator.MoveNext()) { - return true; + return HeaderWriteResult.Done; } - var done = EncodeHeadersCore(hpackEncoder, headersEnumerator, buffer.Slice(length), throwIfNoneEncoded: true, out var headersLength); + var done = EncodeHeadersCore(hpackEncoder, headersEnumerator, buffer.Slice(length), canRequestLargerBuffer: true, out var headersLength); length += headersLength; return done; } @@ -71,9 +95,9 @@ public static bool BeginEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2Hea /// /// Continue encoding headers in the next HEADERS frame. The enumerator should already have a current value. /// - public static bool ContinueEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) + public static HeaderWriteResult ContinueEncodeHeaders(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, out int length) { - return EncodeHeadersCore(hpackEncoder, headersEnumerator, buffer, throwIfNoneEncoded: true, out length); + return EncodeHeadersCore(hpackEncoder, headersEnumerator, buffer, canRequestLargerBuffer: true, out length); } private static bool EncodeStatusHeader(int statusCode, DynamicHPackEncoder hpackEncoder, Span buffer, out int length) @@ -91,7 +115,7 @@ private static bool EncodeStatusHeader(int statusCode, DynamicHPackEncoder hpack } } - private static bool EncodeHeadersCore(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, bool throwIfNoneEncoded, out int length) + private static HeaderWriteResult EncodeHeadersCore(DynamicHPackEncoder hpackEncoder, Http2HeadersEnumerator headersEnumerator, Span buffer, bool canRequestLargerBuffer, out int length) { var currentLength = 0; do @@ -115,22 +139,21 @@ private static bool EncodeHeadersCore(DynamicHPackEncoder hpackEncoder, Http2Hea out var headerLength)) { // If the header wasn't written, and no headers have been written, then the header is too large. - // Throw an error to avoid an infinite loop of attempting to write large header. - if (currentLength == 0 && throwIfNoneEncoded) + // Request for a larger buffer to write large header. + if (currentLength == 0 && canRequestLargerBuffer) { - throw new HPackEncodingException(SR.net_http_hpack_encode_failure); + length = 0; + return HeaderWriteResult.BufferTooSmall; } - length = currentLength; - return false; + return HeaderWriteResult.MoreHeaders; } currentLength += headerLength; } while (headersEnumerator.MoveNext()); - length = currentLength; - return true; + return HeaderWriteResult.Done; } private static HeaderEncodingHint ResolveHeaderEncodingHint(int staticTableId, string name) diff --git a/src/Servers/Kestrel/shared/test/PipeWriterHttp2FrameExtensions.cs b/src/Servers/Kestrel/shared/test/PipeWriterHttp2FrameExtensions.cs index 84ead387bc89..4480a06271a4 100644 --- a/src/Servers/Kestrel/shared/test/PipeWriterHttp2FrameExtensions.cs +++ b/src/Servers/Kestrel/shared/test/PipeWriterHttp2FrameExtensions.cs @@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; using Http2HeadersEnumerator = Microsoft.AspNetCore.Server.Kestrel.Core.Tests.Http2HeadersEnumerator; using HPackHeaderWriter = Microsoft.AspNetCore.Server.Kestrel.Core.Tests.HPackHeaderWriter; +using HeaderWriteResult = Microsoft.AspNetCore.Server.Kestrel.Core.Tests.HeaderWriteResult; namespace Microsoft.AspNetCore.InternalTesting; @@ -36,7 +37,7 @@ public static void WriteStartStream(this PipeWriter writer, int streamId, Dynami var done = HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, headers, buffer, out var length); frame.PayloadLength = length; - if (done) + if (done == HeaderWriteResult.Done) { frame.HeadersFlags = Http2HeadersFrameFlags.END_HEADERS; } @@ -49,14 +50,14 @@ public static void WriteStartStream(this PipeWriter writer, int streamId, Dynami Http2FrameWriter.WriteHeader(frame, writer); writer.Write(buffer.Slice(0, length)); - while (!done) + while (done != HeaderWriteResult.Done) { frame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); done = HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headers, buffer, out length); frame.PayloadLength = length; - if (done) + if (done == HeaderWriteResult.Done) { frame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS; } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index 1e5fb54fa0de..8628af4117d9 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -1,27 +1,18 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; -using System.Collections.Generic; using System.Globalization; -using System.IO; -using System.Linq; -using System.Net.Http; -using System.Net.Http.HPack; using System.Runtime.ExceptionServices; using System.Text; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; -using Microsoft.AspNetCore.InternalTesting; using Microsoft.Extensions.Logging; using Microsoft.Net.Http.Headers; -using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; @@ -2593,17 +2584,18 @@ await ExpectAsync(Http2FrameType.DATA, } [Fact] - public async Task ResponseTrailers_TooLong_Throws() + public async Task ResponseTrailers_SingleLong_SplitsTrailersToContinuationFrames() { + var trailerValue = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); await InitializeConnectionAsync(async context => { await context.Response.WriteAsync("Hello World"); - context.Response.AppendTrailer("too_long", new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize)); + context.Response.AppendTrailer("too_long", trailerValue); }); await StartStreamAsync(1, _browserRequestHeaders, endStream: true); - var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + await ExpectAsync(Http2FrameType.HEADERS, withLength: 32, withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, withStreamId: 1); @@ -2613,18 +2605,361 @@ await ExpectAsync(Http2FrameType.DATA, withFlags: (byte)Http2DataFrameFlags.NONE, withStreamId: 1); - var goAway = await ExpectAsync(Http2FrameType.GOAWAY, - withLength: 8, + var trailerFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + var trailierContinuation = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 13, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false).DefaultTimeout(); + + var buffer = new byte[trailerFrame.PayloadLength + trailierContinuation.PayloadLength]; + trailerFrame.PayloadSequence.CopyTo(buffer); + trailierContinuation.PayloadSequence.CopyTo(buffer.AsSpan(trailerFrame.PayloadLength)); + _hpackDecoder.Decode(buffer, endHeaders: true, handler: this); + Assert.Single(_decodedHeaders); + Assert.Equal(trailerValue, _decodedHeaders["too_long"]); + } + + [Fact] + public async Task ResponseTrailers_ShortHeadersBeforeSingleLong_MultipleRequests_ShortHeadersInDynamicTable() + { + var trailerValue = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + await InitializeConnectionAsync(async context => + { + await context.Response.WriteAsync("Hello World"); + context.Response.AppendTrailer("a-key", "a-value"); + context.Response.AppendTrailer("b-key", "b-value"); + context.Response.AppendTrailer("too_long", trailerValue); + }); + + // Request 1 + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 11, withFlags: (byte)Http2DataFrameFlags.NONE, - withStreamId: 0); + withStreamId: 1); + + var request1TrailerFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 30, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + var request1TrailierContinuation1 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + + var request1TrailierContinuation2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 13, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + _hpackDecoder.Decode(request1TrailerFrame.PayloadSequence, endHeaders: false, handler: this); + Assert.Equal("a-value", _decodedHeaders["a-key"]); + Assert.Equal("b-value", _decodedHeaders["b-key"]); + + _decodedHeaders.Clear(); + _hpackDecoder.Decode(request1TrailierContinuation1.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + _hpackDecoder.Decode(request1TrailierContinuation2.PayloadSequence, endHeaders: true, handler: this); + Assert.Equal(trailerValue, _decodedHeaders["too_long"]); + + // Request 2 + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 2, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 3); + + var request2TrailerFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 2, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 3); + + var request2TrailierContinuation1 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 3); + + var request2TrailierContinuation2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 13, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + + _hpackDecoder.Decode(request2TrailerFrame.PayloadSequence, endHeaders: false, handler: this); + Assert.Equal("a-value", _decodedHeaders["a-key"]); + Assert.Equal("b-value", _decodedHeaders["b-key"]); + + _decodedHeaders.Clear(); + _hpackDecoder.Decode(request2TrailierContinuation1.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + _hpackDecoder.Decode(request2TrailierContinuation2.PayloadSequence, endHeaders: true, handler: this); + Assert.Equal(trailerValue, _decodedHeaders["too_long"]); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false).DefaultTimeout(); + } + + [Fact] + public async Task ResponseTrailers_DoubleLong_SplitsTrailersToContinuationFrames() + { + var trailerValue = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + await InitializeConnectionAsync(async context => + { + await context.Response.WriteAsync("Hello World"); + context.Response.AppendTrailer("too_long", trailerValue); + context.Response.AppendTrailer("too_long2", trailerValue); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + var frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 13, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Equal(trailerValue, _decodedHeaders["too_long"]); + _decodedHeaders.Clear(); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 14, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: true, handler: this); + Assert.Equal(trailerValue, _decodedHeaders["too_long2"]); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false).DefaultTimeout(); + } + + [Fact] + public async Task ResponseTrailers_ShortThenLongThenShort_SplitsTrailers() + { + var trailerValue = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + string shortValue = "testValue"; + await InitializeConnectionAsync(async context => + { + await context.Response.WriteAsync("Hello World"); + context.Response.AppendTrailer("short", shortValue); + context.Response.AppendTrailer("long", trailerValue); + context.Response.AppendTrailer("short2", shortValue); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + var trailerFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 17, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + _hpackDecoder.Decode(trailerFrame.PayloadSequence, endHeaders: false, handler: this); + Assert.Single(_decodedHeaders); + Assert.Equal(shortValue, _decodedHeaders["short"]); + _decodedHeaders.Clear(); + + var trailierContinuation1 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + + _hpackDecoder.Decode(trailierContinuation1.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + var trailierContinuation2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 27, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + _hpackDecoder.Decode(trailierContinuation2.PayloadSequence, endHeaders: true, handler: this); + Assert.Equal(trailerValue, _decodedHeaders["long"]); + Assert.Equal(shortValue, _decodedHeaders["short2"]); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false).DefaultTimeout(); + } + + [Fact] + public async Task LongResponseHeader_FollowedBy_LongResponseTrailer_SplitsTrailersToContinuationFrames() + { + var value = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + await InitializeConnectionAsync(async context => + { + context.Response.Headers["too_long_header"] = value; + await context.Response.WriteAsync("Hello World"); + context.Response.AppendTrailer("too_long_trailer", value); + }); + + // Stream 1 + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + // Response headers + var frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Equal(2, _decodedHeaders.Count); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("Sat, 01 Jan 2000 00:00:00 GMT", _decodedHeaders["date"]); + _decodedHeaders.Clear(); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 20, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: true, handler: this); + Assert.Single(_decodedHeaders); + Assert.Equal(value, _decodedHeaders["too_long_header"]); + _decodedHeaders.Clear(); + + // Data + await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + // Trailers + frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 21, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); - VerifyGoAway(goAway, int.MaxValue, Http2ErrorCode.INTERNAL_ERROR); + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: true, handler: this); + Assert.Single(_decodedHeaders); + Assert.Equal(value, _decodedHeaders["too_long_trailer"]); + _decodedHeaders.Clear(); - _pair.Application.Output.Complete(); - await _connectionTask; + // Stream 3 + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); - var message = Assert.Single(LogMessages, m => m.Exception is HPackEncodingException); - Assert.Contains(SR.net_http_hpack_encode_failure, message.Exception.Message); + // Response headers + frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 2, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 3); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Equal(2, _decodedHeaders.Count); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("Sat, 01 Jan 2000 00:00:00 GMT", _decodedHeaders["date"]); + _decodedHeaders.Clear(); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 3); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 20, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: true, handler: this); + Assert.Single(_decodedHeaders); + Assert.Equal(value, _decodedHeaders["too_long_header"]); + _decodedHeaders.Clear(); + + // Data + await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 3); + + // Trailers + frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 3); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: false, handler: this); + Assert.Empty(_decodedHeaders); + + frame = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 21, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + + _hpackDecoder.Decode(frame.PayloadSequence, endHeaders: true, handler: this); + Assert.Single(_decodedHeaders); + Assert.Equal(value, _decodedHeaders["too_long_trailer"]); + _decodedHeaders.Clear(); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false).DefaultTimeout(); } [Fact] @@ -3183,13 +3518,237 @@ await ExpectAsync(Http2FrameType.DATA, } [Fact] - public async Task ResponseWithHeadersTooLarge_AbortsConnection() + public async Task ResponseWithHeaderValueTooLarge_SplitsHeaderToContinuationFrames() { - var appFinished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + await InitializeConnectionAsync(async context => + { + context.Response.Headers.ETag = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + await context.Response.WriteAsync("Hello World"); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + // Just the StatusCode gets written before aborting in the continuation frame + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame3 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 5, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + + var temp = new byte[headersFrame.PayloadSequence.Length + headersFrame2.PayloadSequence.Length + headersFrame3.PayloadSequence.Length]; + headersFrame.PayloadSequence.CopyTo(temp.AsSpan()); + headersFrame2.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length)); + headersFrame3.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length + (int)headersFrame2.PayloadSequence.Length)); + + _hpackDecoder.Decode(temp, endHeaders: true, handler: this); + Assert.Equal((int)Http2PeerSettings.DefaultMaxFrameSize, _decodedHeaders[HeaderNames.ETag].Length); + } + + [Fact] + public async Task TooLargeHeaderFollowedByContinuationHeaders_Split() + { await InitializeConnectionAsync(async context => { - context.Response.Headers["too_long"] = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + context.Response.Headers.ETag = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + context.Response.Headers.TE = new string('a', 30); + await context.Response.WriteAsync("Hello World"); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + // Just the StatusCode gets written before aborting in the continuation frame + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame3 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 40, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + + var temp = new byte[headersFrame.PayloadSequence.Length + headersFrame2.PayloadSequence.Length + headersFrame3.PayloadSequence.Length]; + headersFrame.PayloadSequence.CopyTo(temp.AsSpan()); + headersFrame2.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length)); + headersFrame3.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length + (int)headersFrame2.PayloadSequence.Length)); + + _hpackDecoder.Decode(temp, endHeaders: true, handler: this); + Assert.Equal((int)Http2PeerSettings.DefaultMaxFrameSize, _decodedHeaders[HeaderNames.ETag].Length); + Assert.Equal(30, _decodedHeaders[HeaderNames.TE].Length); + } + + [Fact] + public async Task TwoTooLargeHeaderFollowedByContinuationHeaders_Split() + { + await InitializeConnectionAsync(async context => + { + context.Response.Headers.ETag = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + context.Response.Headers.TE = new string('b', (int)Http2PeerSettings.DefaultMaxFrameSize); + await context.Response.WriteAsync("Hello World"); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + var frames = new Http2FrameWithPayload[5]; + frames[0] = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + frames[1] = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + frames[2] = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 5, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + frames[3] = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + frames[4] = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 7, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + + var totalSize = frames.Sum(x => x.PayloadSequence.Length); + var temp = new byte[totalSize]; + var destinationIndex = 0; + for (var i = 0; i < frames.Length; i++) + { + frames[i].PayloadSequence.CopyTo(temp.AsSpan(destinationIndex)); + destinationIndex += (int)frames[i].PayloadSequence.Length; + } + _hpackDecoder.Decode(temp, endHeaders: true, handler: this); + Assert.Equal((int)Http2PeerSettings.DefaultMaxFrameSize, _decodedHeaders[HeaderNames.ETag].Length); + Assert.Equal((int)Http2PeerSettings.DefaultMaxFrameSize, _decodedHeaders[HeaderNames.TE].Length); + } + + [Fact] + public async Task ClientRequestedLargerFrame_HeadersSplitByRequestedSize() + { + _clientSettings.MaxFrameSize = 17000; + _serviceContext.ServerOptions.Limits.Http2.MaxFrameSize = 17001; + await InitializeConnectionAsync(async context => + { + context.Response.Headers.ETag = new string('a', 17002); + await context.Response.WriteAsync("Hello World"); + }, expectedSettingsCount: 5); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + // Just the StatusCode gets written before aborting in the continuation frame + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame1 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 17000, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 8, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task ResponseWithMultipleHeaderValueTooLargeForFrame_SplitsHeaderToContinuationFrames() + { + await InitializeConnectionAsync(async context => + { + // This size makes it fit to a single header, but not next to the response status etc. + context.Response.Headers.ETag = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize - 20); + await context.Response.WriteAsync("Hello World"); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + // Just the StatusCode gets written before aborting in the continuation frame + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16369, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + + var temp = new byte[headersFrame.PayloadSequence.Length + headersFrame2.PayloadSequence.Length]; + headersFrame.PayloadSequence.CopyTo(temp.AsSpan()); + headersFrame2.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length)); + + _hpackDecoder.Decode(temp, endHeaders: true, handler: this); + Assert.Equal((int)Http2PeerSettings.DefaultMaxFrameSize - 20, _decodedHeaders[HeaderNames.ETag].Length); + } + + [Fact] + public async Task ResponseWithHeaderNameTooLarge_SplitsHeaderToContinuationFrames() + { + var longHeaderName = new string('a', (int)Http2PeerSettings.DefaultMaxFrameSize); + var headerValue = "some value"; + await InitializeConnectionAsync(async context => + { + context.Response.Headers[longHeaderName] = headerValue; + await context.Response.WriteAsync("Hello World"); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + // Just the StatusCode gets written before aborting in the continuation frame + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var headersFrame3 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 15, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + + var temp = new byte[headersFrame.PayloadSequence.Length + headersFrame2.PayloadSequence.Length + headersFrame3.PayloadSequence.Length]; + headersFrame.PayloadSequence.CopyTo(temp.AsSpan()); + headersFrame2.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length)); + headersFrame3.PayloadSequence.CopyTo(temp.AsSpan((int)headersFrame.PayloadSequence.Length + (int)headersFrame2.PayloadSequence.Length)); + + _hpackDecoder.Decode(temp, endHeaders: true, handler: this); + Assert.Equal(headerValue, _decodedHeaders[longHeaderName]); + } + + [Fact] + public async Task ResponseHeader_OneMegaByte_SplitsHeaderToContinuationFrames() + { + int frameSize = (int)Http2PeerSettings.DefaultMaxFrameSize; + int count = 64; + var headerValue = new string('a', frameSize * count); // 1 MB value + await InitializeConnectionAsync(async context => + { + context.Response.Headers["my"] = headerValue; await context.Response.WriteAsync("Hello World"); }); @@ -3200,11 +3759,21 @@ await ExpectAsync(Http2FrameType.HEADERS, withLength: 32, withFlags: (byte)Http2HeadersFrameFlags.NONE, withStreamId: 1); + for (int i = 0; i < count; i++) + { + await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16384, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + } - _pair.Application.Output.Complete(); + // One more frame because of the header name + size of header value + size header name + 2 * H encoding + await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 8, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); - await WaitForConnectionErrorAsync(ignoreNonGoAwayFrames: false, expectedLastStreamId: int.MaxValue, Http2ErrorCode.INTERNAL_ERROR, - SR.net_http_hpack_encode_failure); + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); } [Fact] diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs index 70f47cf2e62b..c334ee588dab 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs @@ -14,12 +14,12 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.FlowControl; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Microsoft.AspNetCore.InternalTesting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Time.Testing; @@ -849,7 +849,7 @@ internal async Task SendHeadersAsync(int streamId, Http2HeadersFrameFlags Http2FrameWriter.WriteHeader(frame, outputWriter); await SendAsync(buffer.Span.Slice(0, length)); - return done; + return done == HeaderWriteResult.Done; } internal Task SendHeadersAsync(int streamId, Http2HeadersFrameFlags flags, IEnumerable> headers) @@ -919,7 +919,7 @@ internal async Task SendContinuationAsync(int streamId, Http2ContinuationF Http2FrameWriter.WriteHeader(frame, outputWriter); await SendAsync(buffer.Span.Slice(0, length)); - return done; + return done == HeaderWriteResult.Done; } internal async Task SendContinuationAsync(int streamId, Http2ContinuationFrameFlags flags, byte[] payload) @@ -947,7 +947,7 @@ internal async Task SendContinuationAsync(int streamId, Http2ContinuationF Http2FrameWriter.WriteHeader(frame, outputWriter); await SendAsync(buffer.Span.Slice(0, length)); - return done; + return done == HeaderWriteResult.Done; } internal Http2HeadersEnumerator GetHeadersEnumerator(IEnumerable> headers) diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs index 2d58d859b81b..86af821432b6 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2/Http2RequestTests.cs @@ -89,6 +89,46 @@ public async Task GET_Metrics_HttpProtocolAndTlsSet() } } + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + public async Task GET_LargeResponseHeader_Success(bool largeValue, bool largeKey) + { + // Arrange + var longKey = "key-" + new string('$', largeKey ? 128 * 1024 : 1); + var longValue = "value-" + new string('!', largeValue ? 128 * 1024 : 1); + var builder = CreateHostBuilder( + c => + { + c.Response.Headers["test"] = "abc"; + c.Response.Headers[longKey] = longValue; + return Task.CompletedTask; + }, + protocol: HttpProtocols.Http2, + plaintext: true); + + using (var host = builder.Build()) + { + await host.StartAsync(); + var client = HttpHelpers.CreateClient(maxResponseHeadersLength: 1024); + + // Act + var request1 = new HttpRequestMessage(HttpMethod.Get, $"http://127.0.0.1:{host.GetPort()}/"); + request1.Version = HttpVersion.Version20; + request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response = await client.SendAsync(request1, CancellationToken.None); + response.EnsureSuccessStatusCode(); + + // Assert + Assert.Equal("abc", response.Headers.GetValues("test").Single()); + Assert.Equal(longValue, response.Headers.GetValues(longKey).Single()); + + await host.StopAsync(); + } + } + [Fact] public async Task GET_NoTLS_Http11RequestToHttp2Endpoint_400Result() { diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs index cc6f7bacb4f2..92bb131c4ed4 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/HttpHelpers.cs @@ -35,7 +35,7 @@ public static HttpProtocolException GetProtocolException(this Exception ex) throw new Exception($"Couldn't find {nameof(HttpProtocolException)}. Original error: {ex}"); } - public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, TimeSpan? expect100ContinueTimeout = null, bool includeClientCert = false) + public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, TimeSpan? expect100ContinueTimeout = null, bool includeClientCert = false, int? maxResponseHeadersLength = null) { var handler = new SocketsHttpHandler(); handler.SslOptions = new System.Net.Security.SslClientAuthenticationOptions @@ -55,6 +55,11 @@ public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, Time handler.PooledConnectionIdleTimeout = idleTimeout.Value; } + if (maxResponseHeadersLength != null) + { + handler.MaxResponseHeadersLength = maxResponseHeadersLength.Value; + } + return new HttpMessageInvoker(handler); }