Skip to content

Commit

Permalink
Cancel pending tasks on Dispose
Browse files Browse the repository at this point in the history
  • Loading branch information
mayuki committed Jun 7, 2024
1 parent 2e30768 commit 8151868
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 11 deletions.
7 changes: 4 additions & 3 deletions src/Multicaster/Remoting/DynamicRemoteProxyFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ static Core()
var (methodInvoke, cancellationTokenIndex) = MethodInvokeHelper.GetInvokeWithResultMethodInfo(method);
var il = methodBuilder.GetILGenerator();

var local_ctDefault = default(LocalBuilder);
if (!cancellationTokenIndex.HasValue)
{
il.DeclareLocal(typeof(CancellationToken));
local_ctDefault = il.DeclareLocal(typeof(CancellationToken));
}

il.Emit(OpCodes.Ldarg_0); // this
Expand All @@ -99,9 +100,9 @@ static Core()
}
else
{
il.Emit(OpCodes.Ldloca_S, 0);
il.Emit(OpCodes.Ldloca_S, local_ctDefault!);
il.Emit(OpCodes.Initobj, typeof(CancellationToken));
il.Emit(OpCodes.Ldloc_0);
il.Emit(OpCodes.Ldloc_S, local_ctDefault!);
}

il.Emit(OpCodes.Callvirt, methodInvoke); // base.Invoke(method.Name, methodId, arg1, arg2 ...);
Expand Down
41 changes: 36 additions & 5 deletions src/Multicaster/Remoting/IRemoteClientResultPendingTaskRegistry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@

namespace Cysharp.Runtime.Multicast.Remoting;

public interface IRemoteClientResultPendingTaskRegistry
public interface IRemoteClientResultPendingTaskRegistry : IDisposable
{
void Register(PendingTask pendingTask);
bool TryGetAndUnregisterPendingTask(Guid messageId, [NotNullWhen(true)] out PendingTask? pendingTask);
PendingTask CreateTask<TResult>(string methodName, int methodId, Guid messageId, object taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer);
PendingTask CreateTask(string methodName, int methodId, Guid messageId, object taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer);
PendingTask CreateTask<TResult>(string methodName, int methodId, Guid messageId, TaskCompletionSource<TResult> taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer);
PendingTask CreateTask(string methodName, int methodId, Guid messageId, TaskCompletionSource taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer);
}

public class RemoteClientResultPendingTaskRegistry : IRemoteClientResultPendingTaskRegistry
{
private readonly ConcurrentDictionary<Guid, (PendingTask Task, IDisposable CancelRegistration)> _pendingTasks = new();
private readonly TimeSpan _timeout;
private bool _disposed;

public int Count => _pendingTasks.Count; // for unit tests

Expand All @@ -23,14 +24,16 @@ public RemoteClientResultPendingTaskRegistry(TimeSpan? timeout = default)
_timeout = timeout ?? TimeSpan.FromSeconds(5);
}

public PendingTask CreateTask<TResult>(string methodName, int methodId, Guid messageId, object taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
public PendingTask CreateTask<TResult>(string methodName, int methodId, Guid messageId, TaskCompletionSource<TResult> taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
=> PendingTask.Create<TResult>(methodName, methodId, messageId, taskCompletionSource, timeoutCancellationToken.CanBeCanceled ? timeoutCancellationToken : new CancellationTokenSource(_timeout).Token, serializer);

public PendingTask CreateTask(string methodName, int methodId, Guid messageId, object taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
public PendingTask CreateTask(string methodName, int methodId, Guid messageId, TaskCompletionSource taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
=> PendingTask.Create(methodName, methodId, messageId, taskCompletionSource, timeoutCancellationToken.CanBeCanceled ? timeoutCancellationToken : new CancellationTokenSource(_timeout).Token, serializer);

public void Register(PendingTask pendingTask)
{
ThrowIfDisposed();

var registration = pendingTask.TimeoutCancellationToken.Register(() =>
{
pendingTask.TrySetCanceled(pendingTask.TimeoutCancellationToken);
Expand All @@ -53,4 +56,32 @@ public bool TryGetAndUnregisterPendingTask(Guid messageId, [NotNullWhen(true)] o
}
return removed;
}

public void Dispose()
{
_disposed = true;

DisposeAll:
foreach (var pendingTask in _pendingTasks)
{
if (_pendingTasks.TryRemove(pendingTask.Key, out _))
{
pendingTask.Value.CancelRegistration.Dispose();
pendingTask.Value.Task.TrySetCanceled();
}
}

if (!_pendingTasks.IsEmpty)
{
goto DisposeAll;
}
}

private void ThrowIfDisposed()
{
if (_disposed)
{
throw new ObjectDisposedException(nameof(IRemoteClientResultPendingTaskRegistry));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ public void Register(PendingTask pendingTask)
public bool TryGetAndUnregisterPendingTask(Guid messageId, [NotNullWhen(true)] out PendingTask? pendingTask)
=> throw new NotSupportedException("The group does not support client results.");

public PendingTask CreateTask<TResult>(string methodName, int methodId, Guid messageId, object taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
public PendingTask CreateTask<TResult>(string methodName, int methodId, Guid messageId, TaskCompletionSource<TResult> taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
=> throw new NotSupportedException("The group does not support client results.");

public PendingTask CreateTask(string methodName, int methodId, Guid messageId, object taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
public PendingTask CreateTask(string methodName, int methodId, Guid messageId, TaskCompletionSource taskCompletionSource, CancellationToken timeoutCancellationToken, IRemoteSerializer serializer)
=> throw new NotSupportedException("The group does not support client results.");

public void Dispose()
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Cysharp.Runtime.Multicast.Remoting;

namespace Multicaster.Tests;

public class RemoteClientResultPendingTaskRegistryTest
{
[Fact]
public void CancelAll_On_Dispose()
{
// Arrange
var reg = new RemoteClientResultPendingTaskRegistry();
var serializer = new TestJsonRemoteSerializer();
var tcs1 = new TaskCompletionSource<bool>();
var pendingTask1 = reg.CreateTask("Foo", 0, Guid.NewGuid(), tcs1, default, serializer);
reg.Register(pendingTask1);
var tcs2 = new TaskCompletionSource<bool>();
var pendingTask2 = reg.CreateTask("Foo", 0, Guid.NewGuid(), tcs2, default, serializer);
reg.Register(pendingTask2);
var tcs3 = new TaskCompletionSource();
var pendingTask3 = reg.CreateTask("Bar", 0, Guid.NewGuid(), tcs3, default, serializer);
reg.Register(pendingTask3);

// Act
reg.Dispose();

// Assert
Assert.True(tcs1.Task.IsCanceled);
Assert.True(tcs2.Task.IsCanceled);
Assert.True(tcs3.Task.IsCanceled);
}

[Fact]
public async Task Timeout()
{
// Arrange
using var reg = new RemoteClientResultPendingTaskRegistry(TimeSpan.FromMilliseconds(500));
var serializer = new TestJsonRemoteSerializer();
var tcs1 = new TaskCompletionSource<bool>();
var pendingTask1 = reg.CreateTask("Foo", 0, Guid.NewGuid(), tcs1, default, serializer);
reg.Register(pendingTask1);
var tcs2 = new TaskCompletionSource<bool>();
var pendingTask2 = reg.CreateTask("Foo", 0, Guid.NewGuid(), tcs2, default, serializer);
reg.Register(pendingTask2);
var tcs3 = new TaskCompletionSource();
var pendingTask3 = reg.CreateTask("Bar", 0, Guid.NewGuid(), tcs3, new CancellationTokenSource(TimeSpan.FromMilliseconds(10)).Token, serializer);
reg.Register(pendingTask3);

// Act
await Task.Delay(100);
var beforeSecondDelayTcs1IsCanceled = tcs1.Task.IsCanceled;
var beforeSecondDelayTcs2IsCanceled = tcs2.Task.IsCanceled;
var beforeSecondDelayTcs3IsCanceled = tcs3.Task.IsCanceled;
await Task.Delay(600);

// Assert
Assert.False(beforeSecondDelayTcs1IsCanceled);
Assert.False(beforeSecondDelayTcs2IsCanceled);
Assert.True(beforeSecondDelayTcs3IsCanceled); // The timeout of the pending task is overridden.
Assert.True(tcs1.Task.IsCanceled);
Assert.True(tcs2.Task.IsCanceled);
Assert.True(tcs3.Task.IsCanceled);
}
}
2 changes: 1 addition & 1 deletion test/Multicaster.Tests/RemoteGroupClientResultTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public async Task Cancellation()
Assert.NotEmpty(receiverWriterA.Written);
var invocationMessage = JsonSerializer.Deserialize<TestJsonRemoteSerializer.SerializedInvocation>(receiverWriterA.Written[0])!;
Assert.Equal(nameof(ITestReceiver.ClientResult_Cancellation), invocationMessage.MethodName);
Assert.Equal(1, invocationMessage.Arguments.Count);
Assert.Single(invocationMessage.Arguments);
Assert.Equal(5000, ((JsonElement)invocationMessage.Arguments[0]!).GetInt32());

Assert.Equal(0, pendingTasks.Count);
Expand Down

0 comments on commit 8151868

Please sign in to comment.