Skip to content

Commit 758b6fd

Browse files
authored
Fix the shutdown timeout hang (#2230)
* Fix the shutdown timeout hang * Fix tests
1 parent c38a0b9 commit 758b6fd

File tree

12 files changed

+87
-60
lines changed

12 files changed

+87
-60
lines changed

src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ public override Task CloseClientConnections(CancellationToken token)
7676
throw new NotSupportedException();
7777
}
7878

79-
protected override Task<ConnectionContext> CreateConnection(string target = null)
79+
protected override Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default)
8080
{
81-
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target);
81+
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target, cancellationToken);
8282
}
8383

8484
protected override Task DisposeConnection(ConnectionContext connection)

src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
55
using System.Collections.Generic;
66
using System.Security.Claims;
7+
using System.Threading;
78
using System.Threading.Tasks;
89

910
namespace Microsoft.Azure.SignalR;
@@ -34,5 +35,5 @@ public LocalTokenProvider(
3435
_tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
3536
}
3637

37-
public Task<string> ProvideAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm);
38+
public Task<string> ProvideAsync(CancellationToken cancellationToken) => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm, cancellationToken);
3839
}

src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
5+
using System.Threading;
56
using System.Threading.Tasks;
67

78
namespace Microsoft.Azure.SignalR;
@@ -15,5 +16,5 @@ public MicrosoftEntraTokenProvider(MicrosoftEntraAccessKey accessKey)
1516
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
1617
}
1718

18-
public Task<string> ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync();
19+
public Task<string> ProvideAsync(CancellationToken cancellationToken) => _accessKey.GetMicrosoftEntraTokenAsync(cancellationToken);
1920
}
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

4+
using System.Threading;
45
using System.Threading.Tasks;
56

67
namespace Microsoft.Azure.SignalR;
78

89
internal interface IAccessTokenProvider
910
{
10-
Task<string> ProvideAsync();
11+
Task<string> ProvideAsync(CancellationToken cancellationToken = default);
1112
}

src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public async Task StartAsync(Uri url, CancellationToken cancellationToken = defa
121121
// We don't need to capture to a local because we never change this delegate.
122122
if (_accessTokenProvider != null)
123123
{
124-
accessToken = await _accessTokenProvider.ProvideAsync();
124+
accessToken = await _accessTokenProvider.ProvideAsync(cancellationToken);
125125
if (!string.IsNullOrEmpty(accessToken))
126126
{
127127
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");

src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ internal abstract partial class ServiceConnectionBase : IServiceConnection
4444

4545
private readonly TaskCompletionSource<object> _serviceConnectionOfflineTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
4646

47+
private readonly CancellationTokenSource _connectionStartCts = new();
48+
4749
private readonly ServiceConnectionType _connectionType;
4850

4951
private readonly IServiceMessageHandler _serviceMessageHandler;
@@ -157,70 +159,75 @@ public async Task StartAsync(string target = null)
157159
}
158160

159161
Status = ServiceConnectionStatus.Connecting;
160-
161-
var connection = await EstablishConnectionAsync(target);
162-
if (connection != null)
162+
try
163163
{
164-
_connectionContext = connection;
165-
Status = ServiceConnectionStatus.Connected;
166-
_serviceConnectionStartTcs.TrySetResult(true);
167-
try
164+
var connection = await EstablishConnectionAsync(target, _connectionStartCts.Token);
165+
if (connection != null)
168166
{
169-
TimerAwaitable syncTimer = null;
167+
_connectionContext = connection;
168+
Status = ServiceConnectionStatus.Connected;
169+
_serviceConnectionStartTcs.TrySetResult(true);
170170
try
171171
{
172-
if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key)
172+
TimerAwaitable syncTimer = null;
173+
try
173174
{
174-
syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval);
175-
_ = UpdateAzureIdentityAsync(key, syncTimer);
175+
if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key)
176+
{
177+
syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval);
178+
_ = UpdateAzureIdentityAsync(key, syncTimer);
179+
}
180+
await ProcessIncomingAsync(connection);
176181
}
177-
await ProcessIncomingAsync(connection);
178-
}
179-
finally
180-
{
181-
// mark the status as Disconnected so that no one will write to this connection anymore
182-
Status = ServiceConnectionStatus.Disconnected;
183-
syncTimer?.Stop();
182+
finally
183+
{
184+
// mark the status as Disconnected so that no one will write to this connection anymore
185+
Status = ServiceConnectionStatus.Disconnected;
186+
syncTimer?.Stop();
184187

185-
// when ProcessIncoming completes, clean up the connection
188+
// when ProcessIncoming completes, clean up the connection
186189

187-
// TODO: Never cleanup connections unless Service asks us to do that
188-
// Current implementation is based on assumption that Service will drop clients
189-
// if server connection fails.
190-
await CleanupClientConnections();
190+
// TODO: Never cleanup connections unless Service asks us to do that
191+
// Current implementation is based on assumption that Service will drop clients
192+
// if server connection fails.
193+
await CleanupClientConnections();
194+
}
191195
}
192-
}
193-
catch (Exception ex)
194-
{
195-
Log.ConnectionDropped(Logger, _endpointName, ConnectionId, ex);
196-
}
197-
finally
198-
{
199-
// wait until all the connections are cleaned up to close the outgoing pipe
200-
// Don't allow write anymore when the connection is disconnected
201-
await _writeLock.WaitAsync();
202-
try
196+
catch (Exception ex)
203197
{
204-
// close the underlying connection
205-
await DisposeConnection(connection);
198+
Log.ConnectionDropped(Logger, _endpointName, ConnectionId, ex);
206199
}
207200
finally
208201
{
209-
_writeLock.Release();
202+
// wait until all the connections are cleaned up to close the outgoing pipe
203+
// Don't allow write anymore when the connection is disconnected
204+
await _writeLock.WaitAsync();
205+
try
206+
{
207+
// close the underlying connection
208+
await DisposeConnection(connection);
209+
}
210+
finally
211+
{
212+
_writeLock.Release();
213+
}
210214
}
211215
}
212216
}
213-
else
217+
finally
214218
{
215219
Status = ServiceConnectionStatus.Disconnected;
216220
_serviceConnectionStartTcs.TrySetResult(false);
221+
_serviceConnectionOfflineTcs.TrySetResult(false);
217222
}
218223
}
219224

220225
public Task StopAsync()
221226
{
222227
try
223228
{
229+
// cancel any in-progress connection attempt so the connection does not hang in the connecting state
230+
_connectionStartCts.Cancel();
224231
_connectionContext?.Transport.Input.CancelPendingRead();
225232
}
226233
catch (Exception ex)
@@ -277,7 +284,7 @@ public virtual async Task<bool> SafeWriteAsync(ServiceMessage serviceMessage)
277284

278285
public abstract bool TryRemoveClientConnection(string connectionId, out IClientConnection connection);
279286

280-
protected abstract Task<ConnectionContext> CreateConnection(string target = null);
287+
protected abstract Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default);
281288

282289
protected abstract Task DisposeConnection(ConnectionContext connection);
283290

@@ -492,11 +499,11 @@ private Task OnFlowControlMessageAsync(ConnectionFlowControlMessage flowControlM
492499
throw new NotImplementedException($"Unsupported connection type: {flowControlMessage.ConnectionType}");
493500
}
494501

495-
private async Task<ConnectionContext> EstablishConnectionAsync(string target)
502+
private async Task<ConnectionContext> EstablishConnectionAsync(string target, CancellationToken cancellationToken)
496503
{
497504
try
498505
{
499-
var connectionContext = await CreateConnection(target);
506+
var connectionContext = await CreateConnection(target, cancellationToken);
500507
try
501508
{
502509
if (await HandshakeAsync(connectionContext))

src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ public async IAsyncEnumerable<Page<SignalRGroupConnection>> ListConnectionsInGro
318318
public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token)
319319
{
320320
_terminated = true;
321-
return Task.WhenAll(ServiceConnections.Select(c => RemoveConnectionAsync(c, mode, token)));
321+
return Task.WhenAll(ServiceConnections.Select(c => RemoveConnectionFromServiceAsync(c, mode, token)));
322322
}
323323

324324
public virtual Task CloseClientConnections(CancellationToken token)
@@ -476,8 +476,23 @@ protected virtual ServiceConnectionStatus GetStatus()
476476
: ServiceConnectionStatus.Disconnected;
477477
}
478478

479-
protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdownMode mode, CancellationToken token)
479+
/// <summary>
480+
/// TODO: this logic seems like a better fit for the service connection class
481+
/// </summary>
482+
/// <param name="c">The service connection instance</param>
483+
/// <param name="mode">The graceful shutdown mode</param>
484+
/// <param name="token">The cancellation token</param>
485+
/// <returns></returns>
486+
protected async Task RemoveConnectionFromServiceAsync(IServiceConnection c, GracefulShutdownMode mode, CancellationToken token)
480487
{
488+
if (c.Status != ServiceConnectionStatus.Connected)
489+
{
490+
// if the connection is not yet connected
491+
// we stop the connection in case it is connecting
492+
// otherwise ConnectionOfflineTask should already be set
493+
await c.StopAsync();
494+
return;
495+
}
481496
var retry = 0;
482497
while (retry < MaxRetryRemoveSeverConnection && !token.IsCancellationRequested)
483498
{

src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;

src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ public override async Task CloseClientConnections(CancellationToken token)
129129
}
130130
}
131131

132-
protected override Task<ConnectionContext> CreateConnection(string target = null)
132+
protected override Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default)
133133
{
134-
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target);
134+
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target, cancellationToken);
135135
}
136136

137137
protected override Task DisposeConnection(ConnectionContext connection)

test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Collections.Concurrent;
6+
using System.Threading;
67
using System.Threading.Tasks;
78

89
using Microsoft.AspNetCore.Connections;
@@ -83,7 +84,7 @@ public override async Task<bool> SafeWriteAsync(ServiceMessage serviceMessage)
8384
return result;
8485
}
8586

86-
protected override async Task<ConnectionContext> CreateConnection(string target = null)
87+
protected override async Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default)
8788
{
8889
TestConnectionContext = await base.CreateConnection() as TestConnectionContext;
8990

0 commit comments

Comments
 (0)