Skip to content

Commit

Permalink
Added support of multiple tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
sakno committed Oct 24, 2024
1 parent a37db34 commit 4fa27d6
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 3 deletions.
33 changes: 33 additions & 0 deletions src/DotNext.Tests/Threading/LinkedCancellationTokenSourceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,37 @@ public static async Task DirectCancellation()
Equal(linked.CancellationOrigin, linked.Token);
}
}

[Fact]
public static async Task CancellationWithTimeout()
{
using var source1 = new CancellationTokenSource();
var token = new CancellationToken(canceled: false);
using var cts = token.LinkTo(DefaultTimeout, source1.Token);
NotNull(cts);
source1.Cancel();

await token.WaitAsync();
}

[Fact]
public static async Task ConcurrentCancellation()
{
using var source1 = new CancellationTokenSource();
using var source2 = new CancellationTokenSource();
using var source3 = new CancellationTokenSource();
var token = source3.Token;

using var cts = token.LinkTo([source1.Token, source2.Token]);
NotNull(cts);
ThreadPool.UnsafeQueueUserWorkItem(Cancel, source1, preferLocal: false);
ThreadPool.UnsafeQueueUserWorkItem(Cancel, source2, preferLocal: false);
ThreadPool.UnsafeQueueUserWorkItem(Cancel, source3, preferLocal: false);

await token.WaitAsync();

Contains(cts.CancellationOrigin, new[] { source1.Token, source2.Token, source3.Token });

static void Cancel(CancellationTokenSource cts) => cts.Cancel();
}
}
87 changes: 84 additions & 3 deletions src/DotNext.Threading/Threading/LinkedTokenSourceFactory.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Buffers;
using DotNext.Buffers;
using Debug = System.Diagnostics.Debug;

namespace DotNext.Threading;
Expand All @@ -8,7 +10,7 @@ namespace DotNext.Threading;
public static class LinkedTokenSourceFactory
{
/// <summary>
/// Links two cancellation tokens.
/// Links two cancellation tokens together.
/// </summary>
/// <param name="first">The first cancellation token. Can be modified by this method.</param>
/// <param name="second">The second cancellation token.</param>
Expand All @@ -34,6 +36,36 @@ public static class LinkedTokenSourceFactory
return result;
}

/// <summary>
/// Links multiple cancellation tokens together.
/// </summary>
/// <param name="first">The first cancellation token. Can be modified by this method.</param>
/// <param name="tokens">A list of cancellation tokens to link together.</param>
/// <returns>The linked token source; or <see langword="null"/> if <paramref name="first"/> or <paramref name="tokens"/> are not cancelable.</returns>
public static LinkedCancellationTokenSource? LinkTo(this ref CancellationToken first, ReadOnlySpan<CancellationToken> tokens) // TODO: Add params
{
LinkedCancellationTokenSource? result;
if (tokens.IsEmpty)
{
result = null;
}
else
{
result = new MultipleLinkedCancellationTokenSource(tokens, out var isEmpty, first);
if (isEmpty)
{
result.Dispose();
result = null;
}
else
{
first = result.Token;
}
}

return result;
}

/// <summary>
/// Links cancellation token with the timeout.
/// </summary>
Expand Down Expand Up @@ -117,11 +149,60 @@ protected override void Dispose(bool disposing)
{
if (disposing)
{
registration1.Dispose();
registration2.Dispose();
registration1.Unregister();
registration2.Unregister();
}

base.Dispose(disposing);
}
}

private sealed class MultipleLinkedCancellationTokenSource : LinkedCancellationTokenSource
{
private MemoryOwner<CancellationTokenRegistration> registrations;

internal MultipleLinkedCancellationTokenSource(ReadOnlySpan<CancellationToken> tokens, out bool isEmpty, CancellationToken first)
{
Debug.Assert(!tokens.IsEmpty);

var writer = new BufferWriterSlim<CancellationTokenRegistration>(tokens.Length);
try
{
foreach (var token in tokens)
{
if (token != first && token.CanBeCanceled)
{
writer.Add(token.UnsafeRegister(CancellationCallback, this));
}
}

if (first.CanBeCanceled && writer.WrittenCount > 0)
{
writer.Add(first.UnsafeRegister(CancellationCallback, this));
}

registrations = writer.DetachOrCopyBuffer();
isEmpty = registrations.IsEmpty;
}
finally
{
writer.Dispose();
}
}

protected override void Dispose(bool disposing)
{
if (disposing)
{
foreach (ref readonly var registration in registrations.Span)
{
registration.Unregister();
}

registrations.Dispose();
}

base.Dispose(disposing);
}
}
}

0 comments on commit 4fa27d6

Please sign in to comment.