Skip to content

Commit

Permalink
Clear IAsyncStreamReader<T>.Current value before reading next value (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Aug 21, 2023
1 parent 89dd217 commit a970fec
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
return Task.FromException<bool>(new InvalidOperationException("Can't read messages after the request is complete."));
}

// Clear current before moving next. This prevents rooting the previous value while getting the next one.
// In a long running stream this can allow the previous value to be GCed.
Current = null!;

var request = _serverCallContext.HttpContext.Request.BodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
if (!request.IsCompletedSuccessfully)
{
Expand Down
4 changes: 4 additions & 0 deletions src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)

CompatibilityHelpers.Assert(_grpcEncoding != null, "Encoding should have been calculated from response.");

// Clear current before moving next. This prevents rooting the previous value while getting the next one.
// In a long running stream this can allow the previous value to be GCed.
Current = null!;

var readMessage = await _call.ReadMessageAsync(
_responseStream,
_grpcEncoding,
Expand Down
38 changes: 38 additions & 0 deletions test/Grpc.AspNetCore.Server.Tests/HttpContextStreamReaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,42 @@ public async Task MoveNext_TokenCancelledDuringMoveNext_CancelTask()
Assert.AreEqual(1, testSink.Writes.Count);
Assert.AreEqual("ReadingMessage", testSink.Writes.First().EventId.Name);
}

[Test]
public async Task MoveNext_MultipleCalls_CurrentClearedBetweenCalls()
{
// Arrange
var ms = new SyncPointMemoryStream();

var testSink = new TestSink();
var testLoggerFactory = new TestLoggerFactory(testSink, enabled: true);

var httpContext = new DefaultHttpContext();
httpContext.Features.Set<IRequestBodyPipeFeature>(new TestRequestBodyPipeFeature(PipeReader.Create(ms)));
var serverCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(httpContext, logger: testLoggerFactory.CreateLogger("Test"));
var reader = new HttpContextStreamReader<HelloReply>(serverCallContext, MessageHelpers.ServiceMethod.ResponseMarshaller.ContextualDeserializer);

// Act
var nextTask = reader.MoveNext(CancellationToken.None);

await ms.AddDataAndWait(new byte[]
{
0x00, // compression = 0
0x00,
0x00,
0x00,
0x00 // length = 0
}).DefaultTimeout();

Assert.IsTrue(await nextTask.DefaultTimeout());
Assert.IsNotNull(reader.Current);

nextTask = reader.MoveNext(CancellationToken.None);

Assert.IsFalse(nextTask.IsCompleted);
Assert.IsFalse(nextTask.IsCanceled);

// Assert
Assert.IsNull(reader.Current);
}
}
3 changes: 3 additions & 0 deletions test/Grpc.Net.Client.Tests/AsyncServerStreamingCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(
var moveNextTask2 = responseStream.MoveNext(CancellationToken.None);
Assert.IsFalse(moveNextTask2.IsCompleted);

// Current is cleared after MoveNext is called.
Assert.IsNull(responseStream.Current);

await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply
{
Message = "Hello world 2"
Expand Down

0 comments on commit a970fec

Please sign in to comment.