|
29 | 29 | // Copyright (c) 2007-2025 Broadcom. All Rights Reserved.
|
30 | 30 | //---------------------------------------------------------------------------
|
31 | 31 |
|
| 32 | +using System; |
| 33 | +using System.Runtime.CompilerServices; |
32 | 34 | using System.Threading;
|
33 | 35 | using System.Threading.Tasks;
|
| 36 | +using System.Threading.Tasks.Sources; |
34 | 37 |
|
35 | 38 | namespace RabbitMQ.Client.Impl
|
36 | 39 | {
|
37 |
| - sealed class AsyncManualResetEvent |
| 40 | + sealed class AsyncManualResetEvent(bool initialState = false) : IValueTaskSource |
38 | 41 | {
|
39 |
| - volatile TaskCompletionSource<bool> _taskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously); |
| 42 | + private readonly object _lock = new(); |
40 | 43 |
|
41 |
| - public AsyncManualResetEvent(bool initialState = false) |
| 44 | + private State _state = initialState ? State.Set : State.Reset; |
| 45 | + |
| 46 | + // Do not make this field readonly |
| 47 | + private ManualResetValueTaskSourceCore<bool> _valueTaskSource = new() { RunContinuationsAsynchronously = true }; |
| 48 | + |
| 49 | + public bool IsSet => (State)Volatile.Read(ref Unsafe.As<State, byte>(ref _state)) == State.Set; |
| 50 | + |
| 51 | + public async ValueTask WaitAsync(CancellationToken cancellationToken = default) |
42 | 52 | {
|
43 |
| - if (initialState) |
| 53 | + ValueTask valueTask; |
| 54 | + CancellationTokenRegistration tokenRegistration = default; |
| 55 | + |
| 56 | + lock (_lock) |
| 57 | + { |
| 58 | + // If already set, return immediately |
| 59 | + if (_state == State.Set) |
| 60 | + { |
| 61 | + return; |
| 62 | + } |
| 63 | + |
| 64 | + // Create the ValueTask with current version |
| 65 | + valueTask = new ValueTask(this, _valueTaskSource.Version); |
| 66 | + |
| 67 | + // Only transition to Awaiting if we're in Reset state |
| 68 | + if (_state == State.Reset) |
| 69 | + { |
| 70 | + _state = State.Awaiting; |
| 71 | + |
| 72 | + // Register cancellation if token can be cancelled |
| 73 | + if (cancellationToken.CanBeCanceled) |
| 74 | + { |
| 75 | +#if NET |
| 76 | + tokenRegistration = cancellationToken.UnsafeRegister( |
| 77 | + static state => |
| 78 | + { |
| 79 | + (AsyncManualResetEvent amre, CancellationToken token) = |
| 80 | + (Tuple<AsyncManualResetEvent, CancellationToken>)state!; |
| 81 | + amre.SetCancelled(token); |
| 82 | + }, state: Tuple.Create(this, cancellationToken)); |
| 83 | +#else |
| 84 | + tokenRegistration = cancellationToken.Register( |
| 85 | + static state => |
| 86 | + { |
| 87 | + (AsyncManualResetEvent amre, CancellationToken token) = |
| 88 | + (Tuple<AsyncManualResetEvent, CancellationToken>)state!; |
| 89 | + amre.SetCancelled(token); |
| 90 | + }, |
| 91 | + state: Tuple.Create(this, cancellationToken), useSynchronizationContext: false); |
| 92 | +#endif |
| 93 | + } |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + try |
44 | 98 | {
|
45 |
| - _taskCompletionSource.SetResult(true); |
| 99 | + await valueTask.ConfigureAwait(false); |
| 100 | + } |
| 101 | + finally |
| 102 | + { |
| 103 | + // Dispose cancellation registration outside of lock to avoid deadlock |
| 104 | + if (tokenRegistration != default) |
| 105 | + { |
| 106 | +#if NET |
| 107 | + await tokenRegistration.DisposeAsync().ConfigureAwait(false); |
| 108 | +#else |
| 109 | + tokenRegistration.Dispose(); |
| 110 | +#endif |
| 111 | + } |
46 | 112 | }
|
47 | 113 | }
|
48 | 114 |
|
49 |
| - public bool IsSet => _taskCompletionSource.Task.IsCompleted; |
50 |
| - |
51 |
| - public Task WaitAsync(CancellationToken cancellationToken = default) |
| 115 | + public void Set() |
52 | 116 | {
|
53 |
| - Task<bool> task = _taskCompletionSource.Task; |
54 |
| - return task.IsCompleted ? task : task.WaitAsync(cancellationToken); |
55 |
| - } |
| 117 | + lock (_lock) |
| 118 | + { |
| 119 | + State previousState = _state; |
| 120 | + _state = State.Set; |
56 | 121 |
|
57 |
| - public void Set() => _taskCompletionSource.TrySetResult(true); |
| 122 | + // Only set result if we were in Awaiting state |
| 123 | + if (previousState == State.Awaiting) |
| 124 | + { |
| 125 | + _valueTaskSource.SetResult(true); |
| 126 | + } |
| 127 | + } |
| 128 | + } |
58 | 129 |
|
59 | 130 | public void Reset()
|
60 | 131 | {
|
61 |
| - while (true) |
| 132 | + lock (_lock) |
62 | 133 | {
|
63 |
| - TaskCompletionSource<bool> currentTcs = _taskCompletionSource; |
64 |
| - if (!currentTcs.Task.IsCompleted || |
65 |
| - Interlocked.CompareExchange(ref _taskCompletionSource, new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously), currentTcs) == currentTcs) |
| 134 | + if (_state == State.Set) |
66 | 135 | {
|
67 |
| - return; |
| 136 | + _state = State.Reset; |
| 137 | + _valueTaskSource.Reset(); |
| 138 | + } |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + void SetCancelled(CancellationToken cancellationToken) |
| 143 | + { |
| 144 | + lock (_lock) |
| 145 | + { |
| 146 | + if (_state == State.Awaiting) |
| 147 | + { |
| 148 | + _state = State.Reset; // Reset to allow future waits |
| 149 | + _valueTaskSource.SetException(new OperationCanceledException(cancellationToken)); |
68 | 150 | }
|
69 | 151 | }
|
70 | 152 | }
|
| 153 | + |
| 154 | + void IValueTaskSource.GetResult(short token) => _valueTaskSource.GetResult(token); |
| 155 | + |
| 156 | + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _valueTaskSource.GetStatus(token); |
| 157 | + |
| 158 | + void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, |
| 159 | + ValueTaskSourceOnCompletedFlags flags) => _valueTaskSource.OnCompleted(continuation, state, token, flags); |
| 160 | + |
| 161 | + enum State : byte |
| 162 | + { |
| 163 | + Reset, |
| 164 | + Set, |
| 165 | + Awaiting |
| 166 | + } |
71 | 167 | }
|
72 | 168 | }
|
0 commit comments