|
29 | 29 | // Copyright (c) 2007-2025 Broadcom. All Rights Reserved.
|
30 | 30 | //---------------------------------------------------------------------------
|
31 | 31 |
|
| 32 | +using System; |
| 33 | +using System.Runtime.CompilerServices; |
32 | 34 | using System.Threading;
|
33 | 35 | using System.Threading.Tasks;
|
| 36 | +using System.Threading.Tasks.Sources; |
34 | 37 |
|
35 | 38 | namespace RabbitMQ.Client.Impl
|
36 | 39 | {
|
37 |
| - sealed class AsyncManualResetEvent |
| 40 | + sealed class AsyncManualResetEvent(bool initialState = false) : IValueTaskSource |
38 | 41 | {
|
39 |
| - volatile TaskCompletionSource<bool> _taskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously); |
| 42 | + private readonly object _lock = new(); |
40 | 43 |
|
41 |
| - public AsyncManualResetEvent(bool initialState = false) |
| 44 | + private State _state = initialState ? State.Set : State.Reset; |
| 45 | + |
| 46 | + // Do not make this field readonly |
| 47 | + private ManualResetValueTaskSourceCore<bool> _valueTaskSource = new() { RunContinuationsAsynchronously = true }; |
| 48 | + |
| 49 | + public bool IsSet => (State)Volatile.Read(ref Unsafe.As<State, byte>(ref _state)) == State.Set; |
| 50 | + |
| 51 | + public async ValueTask WaitAsync(CancellationToken cancellationToken = default) |
42 | 52 | {
|
43 |
| - if (initialState) |
| 53 | + ValueTask valueTask; |
| 54 | + CancellationTokenRegistration tokenRegistration = default; |
| 55 | + |
| 56 | + lock (_lock) |
| 57 | + { |
| 58 | + // If already set, return immediately |
| 59 | + if (_state == State.Set) |
| 60 | + { |
| 61 | + return; |
| 62 | + } |
| 63 | + |
| 64 | + // Create the ValueTask with current version |
| 65 | + valueTask = new ValueTask(this, _valueTaskSource.Version); |
| 66 | + |
| 67 | + // Only transition to Awaiting if we're in Reset state |
| 68 | + if (_state == State.Reset) |
| 69 | + { |
| 70 | + _state = State.Awaiting; |
| 71 | + |
| 72 | + // Register cancellation if token can be cancelled |
| 73 | + if (cancellationToken.CanBeCanceled) |
| 74 | + { |
| 75 | + tokenRegistration = cancellationToken.UnsafeRegister( |
| 76 | + static state => |
| 77 | + { |
| 78 | + (AsyncManualResetEvent amre, CancellationToken token) = |
| 79 | + (Tuple<AsyncManualResetEvent, CancellationToken>)state!; |
| 80 | + amre.SetCancelled(token); |
| 81 | + }, Tuple.Create(this, cancellationToken)); |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + try |
44 | 87 | {
|
45 |
| - _taskCompletionSource.SetResult(true); |
| 88 | + await valueTask.ConfigureAwait(false); |
| 89 | + } |
| 90 | + finally |
| 91 | + { |
| 92 | + // Dispose cancellation registration outside of lock to avoid deadlock |
| 93 | + if (tokenRegistration != default) |
| 94 | + { |
| 95 | +#if NET |
| 96 | + await tokenRegistration.DisposeAsync().ConfigureAwait(false); |
| 97 | +#else |
| 98 | + tokenRegistration.Dispose(); |
| 99 | +#endif |
| 100 | + } |
46 | 101 | }
|
47 | 102 | }
|
48 | 103 |
|
49 |
| - public bool IsSet => _taskCompletionSource.Task.IsCompleted; |
50 |
| - |
51 |
| - public Task WaitAsync(CancellationToken cancellationToken = default) |
| 104 | + public void Set() |
52 | 105 | {
|
53 |
| - Task<bool> task = _taskCompletionSource.Task; |
54 |
| - return task.IsCompleted ? task : task.WaitAsync(cancellationToken); |
55 |
| - } |
| 106 | + lock (_lock) |
| 107 | + { |
| 108 | + State previousState = _state; |
| 109 | + _state = State.Set; |
56 | 110 |
|
57 |
| - public void Set() => _taskCompletionSource.TrySetResult(true); |
| 111 | + // Only set result if we were in Awaiting state |
| 112 | + if (previousState == State.Awaiting) |
| 113 | + { |
| 114 | + _valueTaskSource.SetResult(true); |
| 115 | + } |
| 116 | + } |
| 117 | + } |
58 | 118 |
|
59 | 119 | public void Reset()
|
60 | 120 | {
|
61 |
| - while (true) |
| 121 | + lock (_lock) |
62 | 122 | {
|
63 |
| - TaskCompletionSource<bool> currentTcs = _taskCompletionSource; |
64 |
| - if (!currentTcs.Task.IsCompleted || |
65 |
| - Interlocked.CompareExchange(ref _taskCompletionSource, new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously), currentTcs) == currentTcs) |
| 123 | + if (_state == State.Set) |
66 | 124 | {
|
67 |
| - return; |
| 125 | + _state = State.Reset; |
| 126 | + _valueTaskSource.Reset(); |
| 127 | + } |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + void SetCancelled(CancellationToken cancellationToken) |
| 132 | + { |
| 133 | + lock (_lock) |
| 134 | + { |
| 135 | + if (_state == State.Awaiting) |
| 136 | + { |
| 137 | + _state = State.Reset; // Reset to allow future waits |
| 138 | + _valueTaskSource.SetException(new OperationCanceledException(cancellationToken)); |
68 | 139 | }
|
69 | 140 | }
|
70 | 141 | }
|
| 142 | + |
| 143 | + void IValueTaskSource.GetResult(short token) => _valueTaskSource.GetResult(token); |
| 144 | + |
| 145 | + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _valueTaskSource.GetStatus(token); |
| 146 | + |
| 147 | + void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, |
| 148 | + ValueTaskSourceOnCompletedFlags flags) => _valueTaskSource.OnCompleted(continuation, state, token, flags); |
| 149 | + |
| 150 | + enum State : byte |
| 151 | + { |
| 152 | + Reset, |
| 153 | + Set, |
| 154 | + Awaiting |
| 155 | + } |
71 | 156 | }
|
72 | 157 | }
|
0 commit comments