diff --git a/src/DotNext.Tests/Threading/LinkedCancellationTokenSourceTests.cs b/src/DotNext.Tests/Threading/LinkedCancellationTokenSourceTests.cs index aa0a86645..a252b3665 100644 --- a/src/DotNext.Tests/Threading/LinkedCancellationTokenSourceTests.cs +++ b/src/DotNext.Tests/Threading/LinkedCancellationTokenSourceTests.cs @@ -45,4 +45,37 @@ public static async Task DirectCancellation() Equal(linked.CancellationOrigin, linked.Token); } } + + [Fact] + public static async Task CancellationWithTimeout() + { + using var source1 = new CancellationTokenSource(); + var token = new CancellationToken(canceled: false); + using var cts = token.LinkTo(DefaultTimeout, source1.Token); + NotNull(cts); + source1.Cancel(); + + await token.WaitAsync(); + } + + [Fact] + public static async Task ConcurrentCancellation() + { + using var source1 = new CancellationTokenSource(); + using var source2 = new CancellationTokenSource(); + using var source3 = new CancellationTokenSource(); + var token = source3.Token; + + using var cts = token.LinkTo([source1.Token, source2.Token]); + NotNull(cts); + ThreadPool.UnsafeQueueUserWorkItem(Cancel, source1, preferLocal: false); + ThreadPool.UnsafeQueueUserWorkItem(Cancel, source2, preferLocal: false); + ThreadPool.UnsafeQueueUserWorkItem(Cancel, source3, preferLocal: false); + + await token.WaitAsync(); + + Contains(cts.CancellationOrigin, new[] { source1.Token, source2.Token, source3.Token }); + + static void Cancel(CancellationTokenSource cts) => cts.Cancel(); + } } \ No newline at end of file diff --git a/src/DotNext.Threading/Threading/LinkedTokenSourceFactory.cs b/src/DotNext.Threading/Threading/LinkedTokenSourceFactory.cs index 3c9c0b45f..3650e3128 100644 --- a/src/DotNext.Threading/Threading/LinkedTokenSourceFactory.cs +++ b/src/DotNext.Threading/Threading/LinkedTokenSourceFactory.cs @@ -1,3 +1,5 @@ +using System.Buffers; +using DotNext.Buffers; using Debug = System.Diagnostics.Debug; namespace DotNext.Threading; @@ -8,7 +10,7 @@ namespace DotNext.Threading; public static class LinkedTokenSourceFactory { /// - /// Links two cancellation tokens. + /// Links two cancellation tokens together. /// /// The first cancellation token. Can be modified by this method. /// The second cancellation token. @@ -34,6 +36,36 @@ public static class LinkedTokenSourceFactory return result; } + /// + /// Links multiple cancellation tokens together. + /// + /// The first cancellation token. Can be modified by this method. + /// A list of cancellation tokens to link together. + /// The linked token source; or if or are not cancelable. + public static LinkedCancellationTokenSource? LinkTo(this ref CancellationToken first, ReadOnlySpan tokens) // TODO: Add params + { + LinkedCancellationTokenSource? result; + if (tokens.IsEmpty) + { + result = null; + } + else + { + result = new MultipleLinkedCancellationTokenSource(tokens, out var isEmpty, first); + if (isEmpty) + { + result.Dispose(); + result = null; + } + else + { + first = result.Token; + } + } + + return result; + } + /// /// Links cancellation token with the timeout. /// @@ -117,11 +149,60 @@ protected override void Dispose(bool disposing) { if (disposing) { - registration1.Dispose(); - registration2.Dispose(); + registration1.Unregister(); + registration2.Unregister(); } base.Dispose(disposing); } } + + private sealed class MultipleLinkedCancellationTokenSource : LinkedCancellationTokenSource + { + private MemoryOwner registrations; + + internal MultipleLinkedCancellationTokenSource(ReadOnlySpan tokens, out bool isEmpty, CancellationToken first) + { + Debug.Assert(!tokens.IsEmpty); + + var writer = new BufferWriterSlim(tokens.Length); + try + { + foreach (var token in tokens) + { + if (token != first && token.CanBeCanceled) + { + writer.Add(token.UnsafeRegister(CancellationCallback, this)); + } + } + + if (first.CanBeCanceled && writer.WrittenCount > 0) + { + writer.Add(first.UnsafeRegister(CancellationCallback, this)); + } + + registrations = writer.DetachOrCopyBuffer(); + isEmpty = registrations.IsEmpty; + } + finally + { + writer.Dispose(); + } + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + foreach (ref readonly var registration in registrations.Span) + { + registration.Unregister(); + } + + registrations.Dispose(); + } + + base.Dispose(disposing); + } + } } \ No newline at end of file