Skip to content

Commit

Permalink
Initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ladeak committed Apr 23, 2024
1 parent 7033ec7 commit e7e932d
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 86 deletions.
83 changes: 54 additions & 29 deletions src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ internal sealed class Http2FrameWriter
// This is only set to true by tests.
private readonly bool _scheduleInline;

private uint _maxFrameSize = Http2PeerSettings.MinAllowedMaxFrameSize;
private byte[] _headerEncodingBuffer;
private int _maxFrameSize = Http2PeerSettings.MinAllowedMaxFrameSize;
private readonly ArrayBufferWriter<byte> _headerEncodingBuffer;
private long _unflushedBytes;

private bool _completed;
Expand Down Expand Up @@ -107,7 +107,7 @@ public Http2FrameWriter(
_flusher = new TimingPipeFlusher(timeoutControl, serviceContext.Log);
_flusher.Initialize(_outputWriter);
_outgoingFrame = new Http2Frame();
_headerEncodingBuffer = new byte[_maxFrameSize];
_headerEncodingBuffer = new ArrayBufferWriter<byte>(_maxFrameSize);

_scheduleInline = serviceContext.Scheduler == PipeScheduler.Inline;

Expand Down Expand Up @@ -373,8 +373,9 @@ public void UpdateMaxFrameSize(uint maxFrameSize)
{
if (_maxFrameSize != maxFrameSize)
{
_maxFrameSize = maxFrameSize;
_headerEncodingBuffer = new byte[_maxFrameSize];
// Safe cast, MaxFrameSize is limited to 2^24-1 bytes by the protocol and by Http2PeerSettings.
// Ref: https://datatracker.ietf.org/doc/html/rfc7540#section-4.2
_maxFrameSize = (int)maxFrameSize;
}
}
}
Expand Down Expand Up @@ -485,8 +486,11 @@ private void WriteResponseHeadersUnsynchronized(int streamId, int statusCode, Ht
{
_headersEnumerator.Initialize(headers);
_outgoingFrame.PrepareHeaders(headerFrameFlags, streamId);
var buffer = _headerEncodingBuffer.AsSpan();
_headerEncodingBuffer.ResetWrittenCount();
var buffer = _headerEncodingBuffer.GetSpan(_maxFrameSize)[0.._maxFrameSize];
var done = HPackHeaderWriter.BeginEncodeHeaders(statusCode, _hpackEncoder, _headersEnumerator, buffer, out var payloadLength);
Debug.Assert(done != HeaderWriteResult.BufferTooSmall, "Oversized frames should not be returned, beucase this always writes the status.");
_headerEncodingBuffer.Advance(payloadLength);
FinishWritingHeaders(streamId, payloadLength, done);
}
// Any exception from the HPack encoder can leave the dynamic table in a corrupt state.
Expand Down Expand Up @@ -524,10 +528,18 @@ private ValueTask<FlushResult> WriteDataAndTrailersAsync(Http2Stream stream, in

try
{
_headersEnumerator.Initialize(headers);
_outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_STREAM, streamId);
var buffer = _headerEncodingBuffer.AsSpan();
var done = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out var payloadLength);
HeaderWriteResult done = HeaderWriteResult.MoreHeaders;
int payloadLength;
do
{
_headersEnumerator.Initialize(headers);
_headerEncodingBuffer.ResetWrittenCount();
var bufferSize = done == HeaderWriteResult.BufferTooSmall ? _headerEncodingBuffer.Capacity * 2 : _headerEncodingBuffer.Capacity;
var buffer = _headerEncodingBuffer.GetSpan(bufferSize)[0..bufferSize];
done = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out payloadLength);
} while (done == HeaderWriteResult.BufferTooSmall);
_headerEncodingBuffer.Advance(payloadLength);
FinishWritingHeaders(streamId, payloadLength, done);
}
// Any exception from the HPack encoder can leave the dynamic table in a corrupt state.
Expand All @@ -542,32 +554,45 @@ private ValueTask<FlushResult> WriteDataAndTrailersAsync(Http2Stream stream, in
}
}

private void FinishWritingHeaders(int streamId, int payloadLength, bool done)
private void SplitHeaderFramesToOutput(int streamId, HeaderWriteResult done, bool isFramePrepared)
{
var buffer = _headerEncodingBuffer.AsSpan();
_outgoingFrame.PayloadLength = payloadLength;
if (done)
var dataToFrame = _headerEncodingBuffer.WrittenSpan;
var shouldPrepareFrame = !isFramePrepared;
while (dataToFrame.Length > 0)
{
_outgoingFrame.HeadersFlags |= Http2HeadersFrameFlags.END_HEADERS;
}

WriteHeaderUnsynchronized();
_outputWriter.Write(buffer.Slice(0, payloadLength));

while (!done)
{
_outgoingFrame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId);

done = HPackHeaderWriter.ContinueEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out payloadLength);
_outgoingFrame.PayloadLength = payloadLength;
if (shouldPrepareFrame)
{
_outgoingFrame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId);
}
else
{
shouldPrepareFrame = true;
}

if (done)
var currentSize = dataToFrame.Length > _maxFrameSize ? _maxFrameSize : dataToFrame.Length;
_outgoingFrame.PayloadLength = currentSize;
if (done == HeaderWriteResult.Done && dataToFrame.Length == currentSize)
{
_outgoingFrame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS;
_outgoingFrame.HeadersFlags |= Http2HeadersFrameFlags.END_HEADERS;
}

WriteHeaderUnsynchronized();
_outputWriter.Write(buffer.Slice(0, payloadLength));
_outputWriter.Write(dataToFrame[..currentSize]);
dataToFrame = dataToFrame.Slice(currentSize);
}
}

private void FinishWritingHeaders(int streamId, int payloadLength, HeaderWriteResult done)
{
SplitHeaderFramesToOutput(streamId, done, isFramePrepared: true);
while (done != HeaderWriteResult.Done)
{
_headerEncodingBuffer.ResetWrittenCount();
var bufferSize = done == HeaderWriteResult.BufferTooSmall ? _headerEncodingBuffer.Capacity * 2 : _headerEncodingBuffer.Capacity;
var buffer = _headerEncodingBuffer.GetSpan(bufferSize)[0..bufferSize];
done = HPackHeaderWriter.ContinueEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out payloadLength);
_headerEncodingBuffer.Advance(payloadLength);
SplitHeaderFramesToOutput(streamId, done, isFramePrepared: false);
}
}

Expand Down Expand Up @@ -994,4 +1019,4 @@ private void EnqueueWaitingForMoreConnectionWindow(Http2OutputProducer producer)
_http2Connection.Abort(new ConnectionAbortedException("HTTP/2 connection exceeded the outgoing flow control maximum queue size."));
}
}
}
}
94 changes: 76 additions & 18 deletions src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void BeginEncodeHeaders_Status302_NewIndexValue()
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));

var result = buffer.Slice(0, length).ToArray();
var hex = BitConverter.ToString(result);
Expand All @@ -52,7 +52,7 @@ public void BeginEncodeHeaders_CacheControlPrivate_NewIndexValue()
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));

var result = buffer.Slice(5, length - 5).ToArray();
var hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -81,7 +81,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit()

// First response
enumerator.Initialize(headers);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));

var result = buffer.Slice(0, length).ToArray();
var hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -123,7 +123,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit()

// Second response
enumerator.Initialize(headers);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length));

result = buffer.Slice(0, length).ToArray();
hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -164,7 +164,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeExceeded_EvictionsToFit()
headers.SetCookie = "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1";

enumerator.Initialize(headers);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length));

result = buffer.Slice(0, length).ToArray();
hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -225,7 +225,7 @@ public void BeginEncodeHeadersCustomEncoding_MaxHeaderTableSizeExceeded_Eviction

// First response
enumerator.Initialize((HttpResponseHeaders)headers);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(302, hpackEncoder, enumerator, buffer, out var length));

var result = buffer.Slice(0, length).ToArray();
var hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -267,7 +267,7 @@ public void BeginEncodeHeadersCustomEncoding_MaxHeaderTableSizeExceeded_Eviction

// Second response
enumerator.Initialize(headers);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(307, hpackEncoder, enumerator, buffer, out length));

result = buffer.Slice(0, length).ToArray();
hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -308,7 +308,7 @@ public void BeginEncodeHeadersCustomEncoding_MaxHeaderTableSizeExceeded_Eviction
headers.SetCookie = "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1";

enumerator.Initialize(headers);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out length));

result = buffer.Slice(0, length).ToArray();
hex = BitConverter.ToString(result);
Expand Down Expand Up @@ -366,7 +366,7 @@ public void BeginEncodeHeaders_ExcludedHeaders_NotAddedToTable(string headerName
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder(maxHeaderTableSize: Http2PeerSettings.DefaultHeaderTableSize);
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out _));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out _));

if (neverIndex)
{
Expand All @@ -392,7 +392,7 @@ public void BeginEncodeHeaders_HeaderExceedHeaderTableSize_NoIndexAndNoHeaderEnt
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length));

Assert.Empty(GetHeaderEntries(hpackEncoder));
}
Expand Down Expand Up @@ -482,11 +482,11 @@ public void EncodesHeadersInSinglePayloadWhenSpaceAvailable(KeyValuePair<string,
var length = 0;
if (statusCode.HasValue)
{
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(statusCode.Value, hpackEncoder, GetHeadersEnumerator(headers), payload, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(statusCode.Value, hpackEncoder, GetHeadersEnumerator(headers), payload, out length));
}
else
{
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, GetHeadersEnumerator(headers), payload, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, GetHeadersEnumerator(headers), payload, out length));
}
Assert.Equal(expectedPayload.Length, length);

Expand Down Expand Up @@ -548,28 +548,28 @@ public void EncodesHeadersInMultiplePayloadsWhenSpaceNotAvailable(bool exactSize

// When !exactSize, slices are one byte short of fitting the next header
var sliceLength = expectedStatusCodePayload.Length + (exactSize ? 0 : expectedDateHeaderPayload.Length - 1);
Assert.False(HPackHeaderWriter.BeginEncodeHeaders(statusCode, hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out var length));
Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(statusCode, hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out var length));
Assert.Equal(expectedStatusCodePayload.Length, length);
Assert.Equal(expectedStatusCodePayload, payload.Slice(0, length).ToArray());

offset += length;

sliceLength = expectedDateHeaderPayload.Length + (exactSize ? 0 : expectedContentTypeHeaderPayload.Length - 1);
Assert.False(HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length));
Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length));
Assert.Equal(expectedDateHeaderPayload.Length, length);
Assert.Equal(expectedDateHeaderPayload, payload.Slice(offset, length).ToArray());

offset += length;

sliceLength = expectedContentTypeHeaderPayload.Length + (exactSize ? 0 : expectedServerHeaderPayload.Length - 1);
Assert.False(HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length));
Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length));
Assert.Equal(expectedContentTypeHeaderPayload.Length, length);
Assert.Equal(expectedContentTypeHeaderPayload, payload.Slice(offset, length).ToArray());

offset += length;

sliceLength = expectedServerHeaderPayload.Length;
Assert.True(HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.ContinueEncodeHeaders(hpackEncoder, headerEnumerator, payload.Slice(offset, sliceLength), out length));
Assert.Equal(expectedServerHeaderPayload.Length, length);
Assert.Equal(expectedServerHeaderPayload, payload.Slice(offset, length).ToArray());
}
Expand All @@ -586,7 +586,7 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeUpdated_SizeUpdateInHeaders()

// First request
enumerator.Initialize(new Dictionary<string, StringValues>());
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length));

Assert.Equal(2, length);

Expand All @@ -600,11 +600,69 @@ public void BeginEncodeHeaders_MaxHeaderTableSizeUpdated_SizeUpdateInHeaders()

// Second request
enumerator.Initialize(new Dictionary<string, StringValues>());
Assert.True(HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out length));
Assert.Equal(HeaderWriteResult.Done, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out length));

Assert.Equal(0, length);
}

[Fact]
public void WithStatusCode_TooLargeHeader_ReturnsNotDone()
{
Span<byte> buffer = new byte[1024 * 16];

IHeaderDictionary headers = new HttpResponseHeaders();
headers.Cookie = new string('a', buffer.Length + 1);
var enumerator = new Http2HeadersEnumerator();
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length));
}

[Fact]
public void NoStatusCodeLargeHeader_ReturnsOversized()
{
Span<byte> buffer = new byte[1024 * 16];

IHeaderDictionary headers = new HttpResponseHeaders();
headers.Cookie = new string('a', buffer.Length + 1);
var enumerator = new Http2HeadersEnumerator();
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.Equal(HeaderWriteResult.BufferTooSmall, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length));
}

Check failure on line 635 in src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs

View check run for this annotation

Azure Pipelines / aspnetcore-quarantined-pr (Tests: Ubuntu x64)

src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs#L635

src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs(635,1): error IDE2000: (NETCORE_ENGINEERING_TELEMETRY=Build) Avoid multiple blank lines (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide2000)

Check failure on line 635 in src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs

View check run for this annotation

Azure Pipelines / aspnetcore-ci (Build Test: Ubuntu x64)

src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs#L635

src/Servers/Kestrel/Core/test/Http2/Http2HPackEncoderTests.cs(635,1): error IDE2000: (NETCORE_ENGINEERING_TELEMETRY=Build) Avoid multiple blank lines (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide2000)

[Fact]
public void WithStatusCode_JustFittingHeaderNoSpace_ReturnsNotDone()
{
Span<byte> buffer = new byte[1024 * 16];

IHeaderDictionary headers = new HttpResponseHeaders();
headers.Cookie = new string('a', buffer.Length - 1);
var enumerator = new Http2HeadersEnumerator();
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(200, hpackEncoder, enumerator, buffer, out var length));
}

[Fact]
public void NoStatusCode_JustFittingHeaderNoSpace_ReturnsNotDone()
{
Span<byte> buffer = new byte[1024 * 16];

IHeaderDictionary headers = new HttpResponseHeaders();
headers.Accept = "application/json;";
headers.Cookie = new string('a', buffer.Length - 1);
var enumerator = new Http2HeadersEnumerator();
enumerator.Initialize(headers);

var hpackEncoder = new DynamicHPackEncoder();
Assert.Equal(HeaderWriteResult.MoreHeaders, HPackHeaderWriter.BeginEncodeHeaders(hpackEncoder, enumerator, buffer, out var length));
}

private static Http2HeadersEnumerator GetHeadersEnumerator(IEnumerable<KeyValuePair<string, string>> headers)
{
var groupedHeaders = headers
Expand Down
Loading

0 comments on commit e7e932d

Please sign in to comment.