Skip to content

Commit

Permalink
Change subchannel BalancerAddress when attributes change
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK committed Aug 15, 2023
1 parent 5a8e2ba commit 2ed833e
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 56 deletions.
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Balancer/BalancerAddress.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -30,7 +30,7 @@ namespace Grpc.Net.Client.Balancer;
/// </summary>
public sealed class BalancerAddress
{
private BalancerAttributes? _attributes;
internal BalancerAttributes? _attributes;

/// <summary>
/// Initializes a new instance of the <see cref="BalancerAddress"/> class with the specified <see cref="DnsEndPoint"/>.
Expand Down
57 changes: 45 additions & 12 deletions src/Grpc.Net.Client/Balancer/BalancerAttributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,22 @@ public sealed class BalancerAttributes : IDictionary<string, object?>, IReadOnly
/// <summary>
/// Gets a read-only collection of metadata attributes.
/// </summary>
public static readonly BalancerAttributes Empty = new BalancerAttributes(new ReadOnlyDictionary<string, object?>(new Dictionary<string, object?>()));
public static readonly BalancerAttributes Empty = new BalancerAttributes(new Dictionary<string, object?>(), readOnly: true);

private readonly IDictionary<string, object?> _attributes;
internal readonly Dictionary<string, object?> _attributes;
private readonly bool _readOnly;

/// <summary>
/// Initializes a new instance of the <see cref="BalancerAttributes"/> class.
/// </summary>
public BalancerAttributes() : this(new Dictionary<string, object?>())
public BalancerAttributes() : this(new Dictionary<string, object?>(), readOnly: false)
{
}

private BalancerAttributes(IDictionary<string, object?> attributes)
private BalancerAttributes(Dictionary<string, object?> attributes, bool readOnly)
{
_attributes = attributes;
_readOnly = readOnly;
}

object? IDictionary<string, object?>.this[string key]
Expand All @@ -62,28 +64,49 @@ private BalancerAttributes(IDictionary<string, object?> attributes)
}
set
{
ValidateReadOnly();
_attributes[key] = value;
}
}

ICollection<string> IDictionary<string, object?>.Keys => _attributes.Keys;
ICollection<object?> IDictionary<string, object?>.Values => _attributes.Values;
int ICollection<KeyValuePair<string, object?>>.Count => _attributes.Count;
bool ICollection<KeyValuePair<string, object?>>.IsReadOnly => _attributes.IsReadOnly;
bool ICollection<KeyValuePair<string, object?>>.IsReadOnly => _readOnly ? true : ((ICollection<KeyValuePair<string, object?>>)_attributes).IsReadOnly;
IEnumerable<string> IReadOnlyDictionary<string, object?>.Keys => _attributes.Keys;
IEnumerable<object?> IReadOnlyDictionary<string, object?>.Values => _attributes.Values;
int IReadOnlyCollection<KeyValuePair<string, object?>>.Count => _attributes.Count;
object? IReadOnlyDictionary<string, object?>.this[string key] => _attributes[key];
void IDictionary<string, object?>.Add(string key, object? value) => _attributes.Add(key, value);
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) => _attributes.Add(item);
void ICollection<KeyValuePair<string, object?>>.Clear() => _attributes.Clear();
void IDictionary<string, object?>.Add(string key, object? value)
{
ValidateReadOnly();
_attributes.Add(key, value);
}
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item)
{
ValidateReadOnly();
((ICollection<KeyValuePair<string, object?>>)_attributes).Add(item);
}
void ICollection<KeyValuePair<string, object?>>.Clear()
{
ValidateReadOnly();
_attributes.Clear();
}
bool ICollection<KeyValuePair<string, object?>>.Contains(KeyValuePair<string, object?> item) => _attributes.Contains(item);
bool IDictionary<string, object?>.ContainsKey(string key) => _attributes.ContainsKey(key);
void ICollection<KeyValuePair<string, object?>>.CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) => _attributes.CopyTo(array, arrayIndex);
void ICollection<KeyValuePair<string, object?>>.CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) => ((ICollection<KeyValuePair<string, object?>>)_attributes).CopyTo(array, arrayIndex);
IEnumerator<KeyValuePair<string, object?>> IEnumerable<KeyValuePair<string, object?>>.GetEnumerator() => _attributes.GetEnumerator();
IEnumerator System.Collections.IEnumerable.GetEnumerator() => ((System.Collections.IEnumerable)_attributes).GetEnumerator();
bool IDictionary<string, object?>.Remove(string key) => _attributes.Remove(key);
bool ICollection<KeyValuePair<string, object?>>.Remove(KeyValuePair<string, object?> item) => _attributes.Remove(item);
bool IDictionary<string, object?>.Remove(string key)
{
ValidateReadOnly();
return _attributes.Remove(key);
}
bool ICollection<KeyValuePair<string, object?>>.Remove(KeyValuePair<string, object?> item)
{
ValidateReadOnly();
return ((ICollection<KeyValuePair<string, object?>>)_attributes).Remove(item);
}
bool IDictionary<string, object?>.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value);
bool IReadOnlyDictionary<string, object?>.ContainsKey(string key) => _attributes.ContainsKey(key);
bool IReadOnlyDictionary<string, object?>.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value);
Expand Down Expand Up @@ -121,6 +144,7 @@ public bool TryGetValue<TValue>(BalancerAttributesKey<TValue> key, [MaybeNullWhe
/// <param name="value">The value.</param>
public void Set<TValue>(BalancerAttributesKey<TValue> key, TValue value)
{
ValidateReadOnly();
_attributes[key.Key] = value;
}

Expand All @@ -135,10 +159,19 @@ public void Set<TValue>(BalancerAttributesKey<TValue> key, TValue value)
/// </returns>
public bool Remove<TValue>(BalancerAttributesKey<TValue> key)
{
ValidateReadOnly();
return _attributes.Remove(key.Key);
}

internal string DebuggerToString()
private void ValidateReadOnly()
{
if (_readOnly)
{
throw new NotSupportedException("Collection is read-only.");
}
}

private string DebuggerToString()
{
return $"Count = {_attributes.Count}";
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -44,6 +44,46 @@ public bool Equals(BalancerAddress? x, BalancerAddress? y)
return false;
}

var xAttributes = x._attributes?._attributes;
var yAttributes = y._attributes?._attributes;
if (!AttributesEqual(xAttributes, yAttributes))
{
return false;
}

return true;
}

private bool AttributesEqual(Dictionary<string, object?>? x, Dictionary<string, object?>? y)
{
if (x == null && y == null)
{
return true;
}

if (x == null || y == null)
{
return false;
}

if (x.Count != y.Count)
{
return false;
}

foreach (var kvp in x)
{
if (!y.TryGetValue(kvp.Key, out var value))
{
return false;
}

if (!Equals(kvp.Value, value))
{
return false;
}
}

return true;
}

Expand Down
5 changes: 3 additions & 2 deletions src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -17,6 +17,7 @@
#endregion

#if SUPPORT_LOAD_BALANCING
using System.Net;
using Grpc.Shared;

namespace Grpc.Net.Client.Balancer.Internal;
Expand All @@ -28,7 +29,7 @@ namespace Grpc.Net.Client.Balancer.Internal;
/// </summary>
internal interface ISubchannelTransport : IDisposable
{
BalancerAddress? CurrentAddress { get; }
DnsEndPoint? CurrentEndPoint { get; }
TimeSpan? ConnectTimeout { get; }

#if NET5_0_OR_GREATER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ namespace Grpc.Net.Client.Balancer.Internal;
internal class PassiveSubchannelTransport : ISubchannelTransport, IDisposable
{
private readonly Subchannel _subchannel;
private BalancerAddress? _currentAddress;
private DnsEndPoint? _currentEndPoint;

public PassiveSubchannelTransport(Subchannel subchannel)
{
_subchannel = subchannel;
}

public BalancerAddress? CurrentAddress => _currentAddress;
public DnsEndPoint? CurrentEndPoint => _currentEndPoint;
public TimeSpan? ConnectTimeout { get; }

public void Disconnect()
{
_currentAddress = null;
_currentEndPoint = null;
_subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected.");
}

Expand All @@ -60,12 +60,12 @@ public void Disconnect()
TryConnectAsync(ConnectContext context)
{
Debug.Assert(_subchannel._addresses.Count == 1);
Debug.Assert(CurrentAddress == null);
Debug.Assert(CurrentEndPoint == null);

var currentAddress = _subchannel._addresses[0];

_subchannel.UpdateConnectivityState(ConnectivityState.Connecting, "Passively connecting.");
_currentAddress = currentAddress;
_currentEndPoint = currentAddress.EndPoint;
_subchannel.UpdateConnectivityState(ConnectivityState.Ready, "Passively connected.");

#if !NETSTANDARD2_0 && !NET462
Expand All @@ -77,7 +77,7 @@ public void Disconnect()

public void Dispose()
{
_currentAddress = null;
_currentEndPoint = null;
}

#if NET5_0_OR_GREATER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi

private int _lastEndPointIndex;
internal Socket? _initialSocket;
private BalancerAddress? _initialSocketAddress;
private DnsEndPoint? _initialSocketEndPoint;
private List<ReadOnlyMemory<byte>>? _initialSocketData;
private DateTime? _initialSocketCreatedTime;
private bool _disposed;
private BalancerAddress? _currentAddress;
private DnsEndPoint? _currentEndPoint;

public SocketConnectivitySubchannelTransport(
Subchannel subchannel,
Expand All @@ -88,7 +88,7 @@ public SocketConnectivitySubchannelTransport(
}

private object Lock => _subchannel.Lock;
public BalancerAddress? CurrentAddress => _currentAddress;
public DnsEndPoint? CurrentEndPoint => _currentEndPoint;
public TimeSpan? ConnectTimeout { get; }

// For testing. Take a copy under lock for thread-safety.
Expand Down Expand Up @@ -127,16 +127,16 @@ private void DisconnectUnsynchronized()

_initialSocket?.Dispose();
_initialSocket = null;
_initialSocketAddress = null;
_initialSocketEndPoint = null;
_initialSocketData = null;
_initialSocketCreatedTime = null;
_lastEndPointIndex = 0;
_currentAddress = null;
_currentEndPoint = null;
}

public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
{
Debug.Assert(CurrentAddress == null);
Debug.Assert(CurrentEndPoint == null);

// Addresses could change while connecting. Make a copy of the subchannel's addresses.
var addresses = _subchannel.GetAddresses();
Expand All @@ -162,10 +162,10 @@ public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)

lock (Lock)
{
_currentAddress = currentAddress;
_currentEndPoint = currentAddress.EndPoint;
_lastEndPointIndex = currentIndex;
_initialSocket = socket;
_initialSocketAddress = currentAddress;
_initialSocketEndPoint = currentAddress.EndPoint;
_initialSocketData = null;
_initialSocketCreatedTime = DateTime.UtcNow;

Expand Down Expand Up @@ -240,20 +240,28 @@ private void OnCheckSocketConnection(object? state)
try
{
Socket? socket;
BalancerAddress? socketAddress;
DnsEndPoint? socketEndpoint;
var closeSocket = false;
Exception? checkException = null;

lock (Lock)
{
socket = _initialSocket;
socketAddress = _initialSocketAddress;
socketEndpoint = _initialSocketEndPoint;

if (socket != null)
{
CompatibilityHelpers.Assert(socketAddress != null);
CompatibilityHelpers.Assert(socketEndpoint != null);

closeSocket = ShouldCloseSocket(socket, socketAddress, ref _initialSocketData, out checkException);
var address = _subchannel.GetAddress(socketEndpoint);
if (address != null)
{
closeSocket = ShouldCloseSocket(socket, address, ref _initialSocketData, out checkException);
}
else
{
closeSocket = true;
}
}
}

Expand Down Expand Up @@ -296,27 +304,27 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
SocketConnectivitySubchannelTransportLog.CreatingStream(_logger, _subchannel.Id, address);

Socket? socket = null;
BalancerAddress? socketAddress = null;
DnsEndPoint? socketEndPoint = null;
List<ReadOnlyMemory<byte>>? socketData = null;
DateTime? socketCreatedTime = null;
lock (Lock)
{
if (_initialSocket != null)
{
var socketAddressMatch = Equals(_initialSocketAddress, address);
var socketEndPointMatch = Equals(_initialSocketEndPoint, address.EndPoint);

socket = _initialSocket;
socketAddress = _initialSocketAddress;
socketEndPoint = _initialSocketEndPoint;
socketData = _initialSocketData;
socketCreatedTime = _initialSocketCreatedTime;
_initialSocket = null;
_initialSocketAddress = null;
_initialSocketEndPoint = null;
_initialSocketData = null;
_initialSocketCreatedTime = null;

// Double check the address matches the socket address and only use socket on match.
// Not sure if this is possible in practice, but better safe than sorry.
if (!socketAddressMatch)
if (!socketEndPointMatch)
{
socket.Dispose();
socket = null;
Expand Down
Loading

0 comments on commit 2ed833e

Please sign in to comment.