diff --git a/src/Grpc.Net.Client/Balancer/BalancerAddress.cs b/src/Grpc.Net.Client/Balancer/BalancerAddress.cs index 0c6322cb1..826b7deac 100644 --- a/src/Grpc.Net.Client/Balancer/BalancerAddress.cs +++ b/src/Grpc.Net.Client/Balancer/BalancerAddress.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -30,7 +30,7 @@ namespace Grpc.Net.Client.Balancer; /// public sealed class BalancerAddress { - private BalancerAttributes? _attributes; + internal BalancerAttributes? _attributes; /// /// Initializes a new instance of the class with the specified . diff --git a/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs b/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs index 4d7069b19..7bb8526b3 100644 --- a/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs +++ b/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs @@ -38,20 +38,22 @@ public sealed class BalancerAttributes : IDictionary, IReadOnly /// /// Gets a read-only collection of metadata attributes. /// - public static readonly BalancerAttributes Empty = new BalancerAttributes(new ReadOnlyDictionary(new Dictionary())); + public static readonly BalancerAttributes Empty = new BalancerAttributes(new Dictionary(), readOnly: true); - private readonly IDictionary _attributes; + internal readonly Dictionary _attributes; + private readonly bool _readOnly; /// /// Initializes a new instance of the class. /// - public BalancerAttributes() : this(new Dictionary()) + public BalancerAttributes() : this(new Dictionary(), readOnly: false) { } - private BalancerAttributes(IDictionary attributes) + private BalancerAttributes(Dictionary attributes, bool readOnly) { _attributes = attributes; + _readOnly = readOnly; } object? IDictionary.this[string key] @@ -62,6 +64,7 @@ private BalancerAttributes(IDictionary attributes) } set { + ValidateReadOnly(); _attributes[key] = value; } } @@ -69,21 +72,41 @@ private BalancerAttributes(IDictionary attributes) ICollection IDictionary.Keys => _attributes.Keys; ICollection IDictionary.Values => _attributes.Values; int ICollection>.Count => _attributes.Count; - bool ICollection>.IsReadOnly => _attributes.IsReadOnly; + bool ICollection>.IsReadOnly => _readOnly ? true : ((ICollection>)_attributes).IsReadOnly; IEnumerable IReadOnlyDictionary.Keys => _attributes.Keys; IEnumerable IReadOnlyDictionary.Values => _attributes.Values; int IReadOnlyCollection>.Count => _attributes.Count; object? IReadOnlyDictionary.this[string key] => _attributes[key]; - void IDictionary.Add(string key, object? value) => _attributes.Add(key, value); - void ICollection>.Add(KeyValuePair item) => _attributes.Add(item); - void ICollection>.Clear() => _attributes.Clear(); + void IDictionary.Add(string key, object? value) + { + ValidateReadOnly(); + _attributes.Add(key, value); + } + void ICollection>.Add(KeyValuePair item) + { + ValidateReadOnly(); + ((ICollection>)_attributes).Add(item); + } + void ICollection>.Clear() + { + ValidateReadOnly(); + _attributes.Clear(); + } bool ICollection>.Contains(KeyValuePair item) => _attributes.Contains(item); bool IDictionary.ContainsKey(string key) => _attributes.ContainsKey(key); - void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => _attributes.CopyTo(array, arrayIndex); + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)_attributes).CopyTo(array, arrayIndex); IEnumerator> IEnumerable>.GetEnumerator() => _attributes.GetEnumerator(); IEnumerator System.Collections.IEnumerable.GetEnumerator() => ((System.Collections.IEnumerable)_attributes).GetEnumerator(); - bool IDictionary.Remove(string key) => _attributes.Remove(key); - bool ICollection>.Remove(KeyValuePair item) => _attributes.Remove(item); + bool IDictionary.Remove(string key) + { + ValidateReadOnly(); + return _attributes.Remove(key); + } + bool ICollection>.Remove(KeyValuePair item) + { + ValidateReadOnly(); + return ((ICollection>)_attributes).Remove(item); + } bool IDictionary.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value); bool IReadOnlyDictionary.ContainsKey(string key) => _attributes.ContainsKey(key); bool IReadOnlyDictionary.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value); @@ -121,6 +144,7 @@ public bool TryGetValue(BalancerAttributesKey key, [MaybeNullWhe /// The value. public void Set(BalancerAttributesKey key, TValue value) { + ValidateReadOnly(); _attributes[key.Key] = value; } @@ -135,10 +159,19 @@ public void Set(BalancerAttributesKey key, TValue value) /// public bool Remove(BalancerAttributesKey key) { + ValidateReadOnly(); return _attributes.Remove(key.Key); } - internal string DebuggerToString() + private void ValidateReadOnly() + { + if (_readOnly) + { + throw new NotSupportedException("Collection is read-only."); + } + } + + private string DebuggerToString() { return $"Count = {_attributes.Count}"; } diff --git a/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs b/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs index a1d0f2fcf..00d737ad1 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -44,6 +44,46 @@ public bool Equals(BalancerAddress? x, BalancerAddress? y) return false; } + var xAttributes = x._attributes?._attributes; + var yAttributes = y._attributes?._attributes; + if (!AttributesEqual(xAttributes, yAttributes)) + { + return false; + } + + return true; + } + + private bool AttributesEqual(Dictionary? x, Dictionary? y) + { + if (x == null && y == null) + { + return true; + } + + if (x == null || y == null) + { + return false; + } + + if (x.Count != y.Count) + { + return false; + } + + foreach (var kvp in x) + { + if (!y.TryGetValue(kvp.Key, out var value)) + { + return false; + } + + if (!Equals(kvp.Value, value)) + { + return false; + } + } + return true; } diff --git a/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs b/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs index ac9f468c5..f2c182649 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -17,6 +17,7 @@ #endregion #if SUPPORT_LOAD_BALANCING +using System.Net; using Grpc.Shared; namespace Grpc.Net.Client.Balancer.Internal; @@ -28,7 +29,7 @@ namespace Grpc.Net.Client.Balancer.Internal; /// internal interface ISubchannelTransport : IDisposable { - BalancerAddress? CurrentAddress { get; } + DnsEndPoint? CurrentEndPoint { get; } TimeSpan? ConnectTimeout { get; } #if NET5_0_OR_GREATER diff --git a/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs b/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs index f7bd0b047..980926d3f 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs @@ -35,19 +35,19 @@ namespace Grpc.Net.Client.Balancer.Internal; internal class PassiveSubchannelTransport : ISubchannelTransport, IDisposable { private readonly Subchannel _subchannel; - private BalancerAddress? _currentAddress; + private DnsEndPoint? _currentEndPoint; public PassiveSubchannelTransport(Subchannel subchannel) { _subchannel = subchannel; } - public BalancerAddress? CurrentAddress => _currentAddress; + public DnsEndPoint? CurrentEndPoint => _currentEndPoint; public TimeSpan? ConnectTimeout { get; } public void Disconnect() { - _currentAddress = null; + _currentEndPoint = null; _subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected."); } @@ -60,12 +60,12 @@ public void Disconnect() TryConnectAsync(ConnectContext context) { Debug.Assert(_subchannel._addresses.Count == 1); - Debug.Assert(CurrentAddress == null); + Debug.Assert(CurrentEndPoint == null); var currentAddress = _subchannel._addresses[0]; _subchannel.UpdateConnectivityState(ConnectivityState.Connecting, "Passively connecting."); - _currentAddress = currentAddress; + _currentEndPoint = currentAddress.EndPoint; _subchannel.UpdateConnectivityState(ConnectivityState.Ready, "Passively connected."); #if !NETSTANDARD2_0 && !NET462 @@ -77,7 +77,7 @@ public void Disconnect() public void Dispose() { - _currentAddress = null; + _currentEndPoint = null; } #if NET5_0_OR_GREATER diff --git a/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs b/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs index 4a28c8b50..71ede2624 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs @@ -63,11 +63,11 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi private int _lastEndPointIndex; internal Socket? _initialSocket; - private BalancerAddress? _initialSocketAddress; + private DnsEndPoint? _initialSocketEndPoint; private List>? _initialSocketData; private DateTime? _initialSocketCreatedTime; private bool _disposed; - private BalancerAddress? _currentAddress; + private DnsEndPoint? _currentEndPoint; public SocketConnectivitySubchannelTransport( Subchannel subchannel, @@ -88,7 +88,7 @@ public SocketConnectivitySubchannelTransport( } private object Lock => _subchannel.Lock; - public BalancerAddress? CurrentAddress => _currentAddress; + public DnsEndPoint? CurrentEndPoint => _currentEndPoint; public TimeSpan? ConnectTimeout { get; } // For testing. Take a copy under lock for thread-safety. @@ -127,16 +127,16 @@ private void DisconnectUnsynchronized() _initialSocket?.Dispose(); _initialSocket = null; - _initialSocketAddress = null; + _initialSocketEndPoint = null; _initialSocketData = null; _initialSocketCreatedTime = null; _lastEndPointIndex = 0; - _currentAddress = null; + _currentEndPoint = null; } public async ValueTask TryConnectAsync(ConnectContext context) { - Debug.Assert(CurrentAddress == null); + Debug.Assert(CurrentEndPoint == null); // Addresses could change while connecting. Make a copy of the subchannel's addresses. var addresses = _subchannel.GetAddresses(); @@ -162,10 +162,10 @@ public async ValueTask TryConnectAsync(ConnectContext context) lock (Lock) { - _currentAddress = currentAddress; + _currentEndPoint = currentAddress.EndPoint; _lastEndPointIndex = currentIndex; _initialSocket = socket; - _initialSocketAddress = currentAddress; + _initialSocketEndPoint = currentAddress.EndPoint; _initialSocketData = null; _initialSocketCreatedTime = DateTime.UtcNow; @@ -240,20 +240,28 @@ private void OnCheckSocketConnection(object? state) try { Socket? socket; - BalancerAddress? socketAddress; + DnsEndPoint? socketEndpoint; var closeSocket = false; Exception? checkException = null; lock (Lock) { socket = _initialSocket; - socketAddress = _initialSocketAddress; + socketEndpoint = _initialSocketEndPoint; if (socket != null) { - CompatibilityHelpers.Assert(socketAddress != null); + CompatibilityHelpers.Assert(socketEndpoint != null); - closeSocket = ShouldCloseSocket(socket, socketAddress, ref _initialSocketData, out checkException); + var address = _subchannel.GetAddress(socketEndpoint); + if (address != null) + { + closeSocket = ShouldCloseSocket(socket, address, ref _initialSocketData, out checkException); + } + else + { + closeSocket = true; + } } } @@ -296,27 +304,27 @@ public async ValueTask GetStreamAsync(BalancerAddress address, Cancellat SocketConnectivitySubchannelTransportLog.CreatingStream(_logger, _subchannel.Id, address); Socket? socket = null; - BalancerAddress? socketAddress = null; + DnsEndPoint? socketEndPoint = null; List>? socketData = null; DateTime? socketCreatedTime = null; lock (Lock) { if (_initialSocket != null) { - var socketAddressMatch = Equals(_initialSocketAddress, address); + var socketEndPointMatch = Equals(_initialSocketEndPoint, address.EndPoint); socket = _initialSocket; - socketAddress = _initialSocketAddress; + socketEndPoint = _initialSocketEndPoint; socketData = _initialSocketData; socketCreatedTime = _initialSocketCreatedTime; _initialSocket = null; - _initialSocketAddress = null; + _initialSocketEndPoint = null; _initialSocketData = null; _initialSocketCreatedTime = null; // Double check the address matches the socket address and only use socket on match. // Not sure if this is possible in practice, but better safe than sorry. - if (!socketAddressMatch) + if (!socketEndPointMatch) { socket.Dispose(); socket = null; diff --git a/src/Grpc.Net.Client/Balancer/Subchannel.cs b/src/Grpc.Net.Client/Balancer/Subchannel.cs index 9bc5c1905..f21cd07c8 100644 --- a/src/Grpc.Net.Client/Balancer/Subchannel.cs +++ b/src/Grpc.Net.Client/Balancer/Subchannel.cs @@ -69,7 +69,7 @@ public sealed class Subchannel : IDisposable /// /// Gets the current connected address. /// - public BalancerAddress? CurrentAddress => _transport.CurrentAddress; + public BalancerAddress? CurrentAddress => GetAddress(_transport.CurrentEndPoint); /// /// Gets the metadata attributes. @@ -180,10 +180,13 @@ public void UpdateAddresses(IReadOnlyList addresses) case ConnectivityState.Ready: // Transport uses the subchannel lock but take copy in an abundance of caution. var currentAddress = CurrentAddress; - if (currentAddress != null && !_addresses.Contains(currentAddress)) + if (currentAddress != null) { - SubchannelLog.ConnectedAddressNotInUpdatedAddresses(_logger, Id, currentAddress); - requireReconnect = true; + if (!HasMatchingEndpoint(_addresses, currentAddress)) + { + SubchannelLog.ConnectedAddressNotInUpdatedAddresses(_logger, Id, currentAddress); + requireReconnect = true; + } } break; case ConnectivityState.Shutdown: @@ -409,6 +412,38 @@ internal void RaiseStateChanged(ConnectivityState state, Status status) } } + internal BalancerAddress? GetAddress(DnsEndPoint? endpoint) + { + if (endpoint != null) + { + lock (Lock) + { + foreach (var address in _addresses) + { + if (address.EndPoint.Equals(endpoint)) + { + return address; + } + } + } + } + + return null; + } + + private static bool HasMatchingEndpoint(List addresses, BalancerAddress currentAddress) + { + foreach (var a in addresses) + { + if (a.EndPoint.Equals(currentAddress.EndPoint)) + { + return true; + } + } + + return false; + } + /// public override string ToString() { diff --git a/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs b/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs index feca9e55e..0cef0e200 100644 --- a/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs +++ b/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -145,6 +145,14 @@ public override void UpdateChannelState(ChannelState state) // remaining in this collection at the end will be disposed. currentSubchannels.RemoveAt(i.Value); + // Check if address attributes have changed. If they have then update the subchannel address. + // The new subchannel address has the same endpoint so the connection isn't impacted. + if (!BalancerAddressEqualityComparer.Instance.Equals(address, newOrCurrentSubchannel.Address)) + { + newOrCurrentSubchannel.Subchannel.UpdateAddresses(new[] { address }); + newOrCurrentSubchannel = new AddressSubchannel(newOrCurrentSubchannel.Subchannel, address); + } + SubchannelLog.SubchannelPreserved(_logger, newOrCurrentSubchannel.Subchannel.Id, address); } else diff --git a/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs b/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs index 745ede8c8..692a875a4 100644 --- a/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs +++ b/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs @@ -328,7 +328,8 @@ public async Task HasSubchannels_ResolverRefresh_MatchingSubchannelUnchanged() resolver.UpdateAddresses(new List { new BalancerAddress("localhost", 80), - new BalancerAddress("localhost", 81) + new BalancerAddress("localhost", 81), + new BalancerAddress("localhost", 82) }); // Act @@ -340,31 +341,44 @@ public async Task HasSubchannels_ResolverRefresh_MatchingSubchannelUnchanged() await connectTask.DefaultTimeout(); var subchannels = channel.ConnectionManager.GetSubchannels(); - Assert.AreEqual(2, subchannels.Count); + Assert.AreEqual(3, subchannels.Count); Assert.AreEqual(1, subchannels[0]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 80), subchannels[0]._addresses[0].EndPoint); Assert.AreEqual(1, subchannels[1]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 81), subchannels[1]._addresses[0].EndPoint); + Assert.AreEqual(1, subchannels[2]._addresses.Count); + Assert.AreEqual(new DnsEndPoint("localhost", 82), subchannels[2]._addresses[0].EndPoint); - // Preserved because port 81 is in both refresh results - var preservedSubchannel = subchannels[1]; + // Preserved because port 81, 82 is in both refresh results + var preservedSubchannel1 = subchannels[1]; + var preservedSubchannel2 = subchannels[2]; + + var address2 = new BalancerAddress("localhost", 82); + address2.Attributes.Set(new BalancerAttributesKey("test"), 1); resolver.UpdateAddresses(new List { new BalancerAddress("localhost", 81), - new BalancerAddress("localhost", 82) + address2, + new BalancerAddress("localhost", 83) }); subchannels = channel.ConnectionManager.GetSubchannels(); - Assert.AreEqual(2, subchannels.Count); + Assert.AreEqual(3, subchannels.Count); Assert.AreEqual(1, subchannels[0]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 81), subchannels[0]._addresses[0].EndPoint); Assert.AreEqual(1, subchannels[1]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 82), subchannels[1]._addresses[0].EndPoint); + Assert.AreEqual(1, subchannels[2]._addresses.Count); + Assert.AreEqual(new DnsEndPoint("localhost", 83), subchannels[2]._addresses[0].EndPoint); + + Assert.AreSame(preservedSubchannel1, subchannels[0]); + Assert.AreSame(preservedSubchannel2, subchannels[1]); - Assert.AreSame(preservedSubchannel, subchannels[0]); + // Test that the channel's address was updated with new attribute with new attributes. + Assert.AreSame(preservedSubchannel2.CurrentAddress, address2); } } #endif diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs b/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs index f831c8682..dbfde0599 100644 --- a/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs +++ b/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -38,7 +38,7 @@ internal class TestSubchannelTransport : ISubchannelTransport public Subchannel Subchannel { get; } - public BalancerAddress? CurrentAddress { get; private set; } + public DnsEndPoint? CurrentEndPoint { get; private set; } public TimeSpan? ConnectTimeout => _factory.ConnectTimeout; public Task TryConnectTask => _connectTcs.Task; @@ -68,7 +68,7 @@ public ValueTask GetStreamAsync(BalancerAddress address, CancellationTok public void Disconnect() { - CurrentAddress = null; + CurrentEndPoint = null; Subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected."); } @@ -82,7 +82,7 @@ public async { var (newState, connectResult) = await (_onTryConnect?.Invoke(context.CancellationToken) ?? Task.FromResult(new TryConnectResult(ConnectivityState.Ready))); - CurrentAddress = Subchannel._addresses[0]; + CurrentEndPoint = Subchannel._addresses[0].EndPoint; var newStatus = newState == ConnectivityState.TransientFailure ? new Status(StatusCode.Internal, "") : Status.DefaultSuccess; Subchannel.UpdateConnectivityState(newState, newStatus);