Skip to content

Commit

Permalink
Streams ensure cts cancel fix statecheck (akkadotnet#6935)
Browse files Browse the repository at this point in the history
* Added AsyncEnumerableSpec.AsyncEnumerableSource_Disposes_OnCancel to confirm failure and provide future regression checks

* Fix AsyncEnumerable Stage disposal as well as fast path race condition

* Add regression unit test case

* Make sure that any exception during shutdown is logged

* Make PostStop non-blocking

* Improve code readability by wrapping code in async local method

---------

Co-authored-by: Aaron Stannard <[email protected]>
Co-authored-by: Gregorius Soedharmo <[email protected]>
  • Loading branch information
3 people authored Oct 3, 2023
1 parent 1af82a7 commit 9bf14af
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 25 deletions.
178 changes: 178 additions & 0 deletions src/core/Akka.Streams.Tests/Dsl/AsyncEnumerableSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Akka.Actor;
using Akka.Pattern;
using Akka.Streams.Dsl;
using Akka.Streams.TestKit;
Expand Down Expand Up @@ -269,6 +270,150 @@ await EventFilter.Warning().ExpectAsync(0, async () =>
});
}

/// <summary>
/// Reproduction for https://github.com/akkadotnet/akka.net/issues/6903
/// </summary>
[Fact(DisplayName = "AsyncEnumerable Source should dispose underlying async enumerator on kill switch signal")]
public async Task AsyncEnumerableSource_Disposes_On_KillSwitch()
{
await this.AssertAllStagesStoppedAsync(async () =>
{
var probe = CreateTestProbe();
var enumerable = new TestAsyncEnumerable(500.Milliseconds());
var src = Source.From(() => enumerable)
.ViaMaterialized(KillSwitches.Single<int>(), Keep.Right)
.ToMaterialized(Sink.ActorRefWithAck<int>(probe, "init", "ack", "complete"), Keep.Left);
var killSwitch = src.Run(Materializer);

// assert init was sent
await probe.ExpectMsgAsync<string>(msg => msg == "init");
probe.Sender.Tell("ack");

// assert enumerator is working
foreach (var i in Enumerable.Range(0, 5))
{
await probe.ExpectMsgAsync<int>(msg => msg == i);
probe.Sender.Tell("ack");
}

// last message was not ack-ed
await probe.ExpectMsgAsync<int>(msg => msg == 5);

killSwitch.Shutdown();

// assert that enumerable resource was disposed
await AwaitConditionAsync(() => enumerable.Disposed);
}, Materializer);
}

[Fact(DisplayName = "AsyncEnumerable Source should dispose underlying async enumerator on kill switch signal even after ActorSystem termination")]
public async Task AsyncEnumerableSource_Disposes_On_KillSwitch2()
{
var probe = CreateTestProbe();
// A long disposing enumerable source
var enumerable = new TestAsyncEnumerable(2.Seconds());
var src = Source.From(() => enumerable)
.ViaMaterialized(KillSwitches.Single<int>(), Keep.Right)
.ToMaterialized(Sink.ActorRefWithAck<int>(probe, "init", "ack", "complete"), Keep.Left);
var killSwitch = src.Run(Materializer);

// assert init was sent
await probe.ExpectMsgAsync<string>(msg => msg == "init");
probe.Sender.Tell("ack");

// assert enumerator is working
foreach (var i in Enumerable.Range(0, 5))
{
await probe.ExpectMsgAsync<int>(msg => msg == i);
probe.Sender.Tell("ack");
}

// last message was not ack-ed
await probe.ExpectMsgAsync<int>(msg => msg == 5);

killSwitch.Shutdown();

await Sys.Terminate();

// enumerable was not disposed even after system termination
enumerable.Disposed.Should().BeFalse();

// assert that enumerable resource can still be disposed even after system termination
// (Not guaranteed if process was already killed)
await AwaitConditionAsync(() => enumerable.Disposed);
}

private class TestAsyncEnumerable: IAsyncEnumerable<int>
{
private readonly AsyncEnumerator _enumerator;

public bool Disposed => _enumerator.Disposed;

public TestAsyncEnumerable(TimeSpan shutdownDelay)
{
_enumerator = new AsyncEnumerator(shutdownDelay);
}

public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken token = default)
{
token.ThrowIfCancellationRequested();
return _enumerator;
}

private sealed class AsyncEnumerator: IAsyncEnumerator<int>
{
private readonly TimeSpan _shutdownDelay;
private int _current = -1;

public AsyncEnumerator(TimeSpan shutdownDelay)
{
_shutdownDelay = shutdownDelay;
}

public bool Disposed { get; private set; }

public async ValueTask DisposeAsync()
{
await Task.Delay(_shutdownDelay);
Disposed = true;
}

public async ValueTask<bool> MoveNextAsync()
{
await Task.Delay(100);
_current++;
return true;
}

public int Current
{
get
{
if (_current == -1)
throw new IndexOutOfRangeException("MoveNextAsync has not been called");
if (Disposed)
throw new ObjectDisposedException("Enumerator already disposed");
return _current;
}
}
}
}

[Fact]
public async Task AsyncEnumerableSource_Disposes_OnCancel()
{
var resource = new Resource();
var tcs = new System.Threading.Tasks.TaskCompletionSource<NotUsed>(TaskCreationOptions
.RunContinuationsAsynchronously);
var src = Source.From(() =>
CancelTestGenerator(tcs, resource, default));
src.To(Sink.Ignore<int>()).Run(Materializer);
await tcs.Task;
Materializer.Shutdown();
await Task.Delay(500);
Assert.False(resource.IsActive);
}

private static async IAsyncEnumerable<int> RangeAsync(int start, int count,
[EnumeratorCancellation] CancellationToken token = default)
{
Expand Down Expand Up @@ -308,6 +453,39 @@ private static async IAsyncEnumerable<int> ProbeableRangeAsync(int start, int co
yield return i;
}
}

public static async IAsyncEnumerable<int> CancelTestGenerator(
TaskCompletionSource<NotUsed> tcs,
Resource resource,
[EnumeratorCancellation] CancellationToken token
)
{
await using var res = resource;
int i = 0;
bool isSet = false;
while (true)
{
await Task.Delay(1, token).ConfigureAwait(false);
yield return i++;
if (isSet == false)
{
tcs.TrySetResult(NotUsed.Instance);
isSet = true;
}
}
// ReSharper disable once IteratorNeverReturns
}

public class Resource : IAsyncDisposable
{
public bool IsActive = true;
public ValueTask DisposeAsync()
{
IsActive = false;
Console.WriteLine("Enumerator completed and resource disposed");
return new ValueTask();
}
}
}
#endif
}
74 changes: 49 additions & 25 deletions src/core/Akka.Streams/Implementation/Fusing/Ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3812,18 +3812,19 @@ private sealed class Logic : OutGraphStageLogic
private readonly Action<T> _onSuccess;
private readonly Action<Exception> _onFailure;
private readonly Action _onComplete;
private readonly CancellationTokenSource _completionCts;

private CancellationTokenSource _completionCts;
private IAsyncEnumerator<T> _enumerator;

public Logic(SourceShape<T> shape, IAsyncEnumerable<T> enumerable) : base(shape)
{

_enumerable = enumerable;
_outlet = shape.Outlet;
_onSuccess = GetAsyncCallback<T>(OnSuccess);
_onFailure = GetAsyncCallback<Exception>(OnFailure);
_onComplete = GetAsyncCallback(OnComplete);

_completionCts = new CancellationTokenSource();
SetHandler(_outlet, this);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -3838,9 +3839,50 @@ public Logic(SourceShape<T> shape, IAsyncEnumerable<T> enumerable) : base(shape)
public override void PreStart()
{
base.PreStart();
_completionCts = new CancellationTokenSource();
_enumerator = _enumerable.GetAsyncEnumerator(_completionCts.Token);
}

public override void PostStop()
{
try
{
_completionCts.Cancel();
_completionCts.Dispose();
}
catch(Exception ex)
{
// This should never happen
Log.Debug(ex, "AsyncEnumerable threw while cancelling CancellationTokenSource");
}

try
{
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
// Intentionally creating a detached dispose task
DisposeEnumeratorAsync();
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
}
catch(Exception ex)
{
Log.Debug(ex, "Underlying async enumerator threw an exception while being disposed.");
}
base.PostStop();
return;

async Task DisposeEnumeratorAsync()
{
try
{
await _enumerator.DisposeAsync();
}
catch (Exception ex)
{
// This is best effort exception logging, this log will never appear if the ActorSystem
// was shut down before we reach this code (BusEvent was not emitting new logs anymore)
Log.Debug(ex, "Underlying async enumerator threw an exception while being disposed.");
}
}
}

public override void OnPull()
{
Expand All @@ -3859,26 +3901,12 @@ public override void OnPull()
// if result is false, it means enumerator was closed. Complete stage in that case.
CompleteStage();
}
}
else if (vtask.IsCompleted) // IsCompleted covers Faulted, Cancelled, and RanToCompletion async state
{
// vtask will always contains an exception because we know we're not successful and always throws
try
{
// This does not block because we know that the task already completed
// Using GetAwaiter().GetResult() to automatically unwraps AggregateException inner exception
vtask.GetAwaiter().GetResult();
}
catch (Exception ex)
{
FailStage(ex);
return;
}

throw new InvalidOperationException("Should never reach this code");
}
else
{
//We immediately fall into wait case.
//Unlike Task, we don't have a 'status' Enum to switch off easily,
//And Error cases can just live with the small cost of async callback.
async Task ProcessTask()
{
// Since this Action is used as task continuation, we cannot safely call corresponding
Expand All @@ -3897,16 +3925,12 @@ async Task ProcessTask()
}
}

#pragma warning disable CS4014
ProcessTask();
#pragma warning restore CS4014
_ = ProcessTask();
}
}

public override void OnDownstreamFinish(Exception cause)
{
_completionCts.Cancel();
_completionCts.Dispose();
CompleteStage();
base.OnDownstreamFinish(cause);
}
Expand Down

0 comments on commit 9bf14af

Please sign in to comment.