diff --git a/.gitignore b/.gitignore index f174fb1..aad7dc7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ nunit-agent* test-output.log TestResults.xml TestResult.xml -Tests/coverage.xml test.sh *.VisualState.xml .vscode @@ -121,7 +120,7 @@ projects/Unit*/TestResult.xml .*.sw? # tests -Tests/coverage.* +coverage* # docs docs/temp/ diff --git a/RabbitMQ.AMQP.Client/IConnection.cs b/RabbitMQ.AMQP.Client/IConnection.cs index 5f252aa..20343ac 100644 --- a/RabbitMQ.AMQP.Client/IConnection.cs +++ b/RabbitMQ.AMQP.Client/IConnection.cs @@ -2,7 +2,16 @@ namespace RabbitMQ.AMQP.Client; -public class ConnectionException(string? message) : Exception(message); +public class ConnectionException : Exception +{ + public ConnectionException(string message) : base(message) + { + } + + public ConnectionException(string message, Exception innerException) : base(message, innerException) + { + } +} public interface IConnection : ILifeCycle { @@ -12,6 +21,5 @@ public interface IConnection : ILifeCycle IConsumerBuilder ConsumerBuilder(); - public ReadOnlyCollection GetPublishers(); } diff --git a/RabbitMQ.AMQP.Client/IConnectionSettings.cs b/RabbitMQ.AMQP.Client/IConnectionSettings.cs index ccf68ff..2e95427 100644 --- a/RabbitMQ.AMQP.Client/IConnectionSettings.cs +++ b/RabbitMQ.AMQP.Client/IConnectionSettings.cs @@ -1,19 +1,50 @@ +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; + namespace RabbitMQ.AMQP.Client; -public interface IConnectionSettings +public interface IConnectionSettings : IEquatable { - string Host(); - - int Port(); - - string VirtualHost(); - - - string User(); - - string Password(); - - string Scheme(); + string Host { get; } + int Port { get; } + string VirtualHost { get; } + string User { get; } + string Password { get; } + string Scheme { get; } + string ConnectionName { get; } + string Path { get; } + bool UseSsl { get; } + ITlsSettings? TlsSettings { get; } +} - string ConnectionName(); +/// +/// Contains the TLS/SSL settings for a connection. +/// +public interface ITlsSettings +{ + /// + /// Client certificates to use for mutual authentication. + /// + X509CertificateCollection ClientCertificates { get; } + + /// + /// Supported protocols to use. + /// + SslProtocols Protocols { get; } + + /// + /// Specifies whether certificate revocation should be performed during handshake. + /// + bool CheckCertificateRevocation { get; } + + /// + /// Gets or sets a certificate validation callback to validate remote certificate. + /// + RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get; } + + /// + /// Gets or sets a local certificate selection callback to select the certificate which should be used for authentication. + /// + LocalCertificateSelectionCallback? LocalCertificateSelectionCallback { get; } } diff --git a/RabbitMQ.AMQP.Client/Impl/AbstractLifeCycle.cs b/RabbitMQ.AMQP.Client/Impl/AbstractLifeCycle.cs index bbb55e3..91dac94 100644 --- a/RabbitMQ.AMQP.Client/Impl/AbstractLifeCycle.cs +++ b/RabbitMQ.AMQP.Client/Impl/AbstractLifeCycle.cs @@ -33,7 +33,7 @@ protected void OnNewStatus(State newState, Error? error) return; } - var oldStatus = State; + State oldStatus = State; State = newState; ChangeState?.Invoke(this, oldStatus, newState, error); } diff --git a/RabbitMQ.AMQP.Client/Impl/AmqpConnection.cs b/RabbitMQ.AMQP.Client/Impl/AmqpConnection.cs index d6557f5..cbe768a 100644 --- a/RabbitMQ.AMQP.Client/Impl/AmqpConnection.cs +++ b/RabbitMQ.AMQP.Client/Impl/AmqpConnection.cs @@ -39,7 +39,6 @@ public class AmqpConnection : AbstractLifeCycle, IConnection private const string ConnectionNotRecoveredMessage = "Connection not recovered"; private readonly SemaphoreSlim _semaphoreClose = new(1, 1); - // The native AMQP.Net Lite connection private Connection? _nativeConnection; @@ -71,7 +70,6 @@ private void ChangeConsumersStatus(State state, Error? error) } } - private async Task ReconnectEntities() { await ReconnectPublishers().ConfigureAwait(false); @@ -102,7 +100,6 @@ private async Task ReconnectConsumers() // TODO: Implement the semaphore to avoid multiple connections // private readonly SemaphoreSlim _semaphore = new(1, 1); - /// /// Publishers contains all the publishers created by the connection. /// Each connection can have multiple publishers. @@ -113,7 +110,6 @@ private async Task ReconnectConsumers() internal ConcurrentDictionary Consumers { get; } = new(); - public ReadOnlyCollection GetPublishers() { return Publishers.Values.ToList().AsReadOnly(); @@ -179,14 +175,18 @@ public IConsumerBuilder ConsumerBuilder() return new AmqpConsumerBuilder(this); } - protected override Task OpenAsync() + protected override async Task OpenAsync() { - EnsureConnection(); - return base.OpenAsync(); + await EnsureConnection() + .ConfigureAwait(false); + await base.OpenAsync() + .ConfigureAwait(false); } - private void EnsureConnection() + private async Task EnsureConnection() { + // TODO: do this! + // await _semaphore.WaitAsync(); try { if (_nativeConnection is { IsClosed: false }) @@ -196,22 +196,53 @@ private void EnsureConnection() var open = new Open { - HostName = $"vhost:{_connectionSettings.VirtualHost()}", + HostName = $"vhost:{_connectionSettings.VirtualHost}", Properties = new Fields() { - [new Symbol("connection_name")] = _connectionSettings.ConnectionName(), + [new Symbol("connection_name")] = _connectionSettings.ConnectionName, } }; - var manualReset = new ManualResetEvent(false); - _nativeConnection = new Connection(_connectionSettings.Address, null, open, (connection, open1) => + void onOpened(Amqp.IConnection connection, Open open1) { - manualReset.Set(); Trace.WriteLine(TraceLevel.Verbose, $"Connection opened. Info: {ToString()}"); OnNewStatus(State.Open, null); - }); + } + + var cf = new ConnectionFactory(); + + if (_connectionSettings.UseSsl && _connectionSettings.TlsSettings is not null) + { + cf.SSL.Protocols = _connectionSettings.TlsSettings.Protocols; + cf.SSL.CheckCertificateRevocation = _connectionSettings.TlsSettings.CheckCertificateRevocation; + + if (_connectionSettings.TlsSettings.ClientCertificates.Count > 0) + { + cf.SSL.ClientCertificates = _connectionSettings.TlsSettings.ClientCertificates; + } + + if (_connectionSettings.TlsSettings.LocalCertificateSelectionCallback is not null) + { + cf.SSL.LocalCertificateSelectionCallback = _connectionSettings.TlsSettings.LocalCertificateSelectionCallback; + } + + if (_connectionSettings.TlsSettings.RemoteCertificateValidationCallback is not null) + { + cf.SSL.RemoteCertificateValidationCallback = _connectionSettings.TlsSettings.RemoteCertificateValidationCallback; + } + } + + try + { + _nativeConnection = await cf.CreateAsync(_connectionSettings.Address, open: open, onOpened: onOpened) + .ConfigureAwait(false); + } + catch (Exception ex) + { + throw new ConnectionException( + $"Connection failed. Info: {ToString()}", ex); + } - manualReset.WaitOne(TimeSpan.FromSeconds(5)); if (_nativeConnection.IsClosed) { throw new ConnectionException( @@ -294,7 +325,8 @@ await Task.Run(async () => await Task.Delay(TimeSpan.FromMilliseconds(next)) .ConfigureAwait(false); - EnsureConnection(); + await EnsureConnection() + .ConfigureAwait(false); connected = true; } catch (Exception e) diff --git a/RabbitMQ.AMQP.Client/Impl/ConnectionSettings.cs b/RabbitMQ.AMQP.Client/Impl/ConnectionSettings.cs index 0bc1c9f..2a9c18d 100644 --- a/RabbitMQ.AMQP.Client/Impl/ConnectionSettings.cs +++ b/RabbitMQ.AMQP.Client/Impl/ConnectionSettings.cs @@ -1,4 +1,7 @@ -using Amqp; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using Amqp; namespace RabbitMQ.AMQP.Client.Impl; @@ -6,7 +9,7 @@ public class ConnectionSettingBuilder { // TODO: maybe add the event "LifeCycle" to the builder private string _host = "localhost"; - private int _port = 5672; + private int _port = -1; // Note: -1 means use the defalt for the scheme private string _user = "guest"; private string _password = "guest"; private string _scheme = "AMQP"; @@ -92,102 +95,117 @@ public ConnectionSettings Build() // public class ConnectionSettings : IConnectionSettings { - internal Address Address { get; } - + private readonly Address _address; private readonly string _connectionName = ""; private readonly string _virtualHost = "/"; + private readonly ITlsSettings? _tlsSettings; - - public ConnectionSettings(string address) + public ConnectionSettings(string address, ITlsSettings? tlsSettings = null) { - Address = new Address(address); + _address = new Address(address); + _tlsSettings = tlsSettings; + + if (_address.UseSsl && _tlsSettings == null) + { + _tlsSettings = new TlsSettings(); + } } public ConnectionSettings(string host, int port, - string user, - string password, - string virtualHost, string scheme, string connectionName) + string user, string password, + string virtualHost, string scheme, string connectionName, + ITlsSettings? tlsSettings = null) { - Address = new Address(host, port, user, password, "/", scheme); + _address = new Address(host: host, port: port, + user: user, password: password, + path: "/", scheme: scheme); _connectionName = connectionName; _virtualHost = virtualHost; - } + _tlsSettings = tlsSettings; - public string Host() - { - return Address.Host; + if (_address.UseSsl && _tlsSettings == null) + { + _tlsSettings = new TlsSettings(); + } } + public string Host => _address.Host; + public int Port => _address.Port; + public string VirtualHost => _virtualHost; + public string User => _address.User; + public string Password => _address.Password; + public string Scheme => _address.Scheme; + public string ConnectionName => _connectionName; + public string Path => _address.Path; + public bool UseSsl => _address.UseSsl; - public int Port() - { - return Address.Port; - } - + public ITlsSettings? TlsSettings => _tlsSettings; - public string VirtualHost() + public override string ToString() { - return _virtualHost; + return + $"Address" + + $"host='{_address.Host}', " + + $"port={_address.Port}, VirtualHost='{_virtualHost}', path='{_address.Path}', " + + $"username='{_address.User}', ConnectionName='{_connectionName}'"; } - public string User() + public override bool Equals(object? obj) { - return Address.User; - } + if (obj is null) + { + return false; + } + if (obj is ConnectionSettings address) + { + return _address.Host == address._address.Host && + _address.Port == address._address.Port && + _address.Path == address._address.Path && + _address.User == address._address.User && + _address.Password == address._address.Password && + _address.Scheme == address._address.Scheme; + } - public string Password() - { - return Address.Password; + return false; } - - public string Scheme() + protected bool Equals(ConnectionSettings other) { - return Address.Scheme; - } + if (other is null) + { + return false; + } - public string ConnectionName() - { - return _connectionName; + return _address.Equals(other._address); } - public override string ToString() + public override int GetHashCode() { - var i = - $"Address" + - $"host='{Address.Host}', " + - $"port={Address.Port}, VirtualHost='{_virtualHost}', path='{Address.Path}', " + - $"username='{Address.User}', ConnectionName='{_connectionName}'"; - return i; + return _address.GetHashCode(); } - - public override bool Equals(object? obj) + public bool Equals(IConnectionSettings? other) { - if (obj == null || GetType() != obj.GetType()) + if (other is null) { return false; } - var address = (ConnectionSettings)obj; - return Address.Host == address.Address.Host && - Address.Port == address.Address.Port && - Address.Path == address.Address.Path && - Address.User == address.Address.User && - Address.Password == address.Address.Password && - Address.Scheme == address.Address.Scheme; - } + if (other is IConnectionSettings connectionSettings) + { + return _address.Host == connectionSettings.Host && + _address.Port == connectionSettings.Port && + _address.Path == connectionSettings.Path && + _address.User == connectionSettings.User && + _address.Password == connectionSettings.Password && + _address.Scheme == connectionSettings.Scheme; + } - protected bool Equals(ConnectionSettings other) - { - return Address.Equals(other.Address); + return false; } - public override int GetHashCode() - { - return Address.GetHashCode(); - } + internal Address Address => _address; public RecoveryConfiguration RecoveryConfiguration { get; set; } = RecoveryConfiguration.Create(); } @@ -240,7 +258,6 @@ public IBackOffDelayPolicy GetBackOffDelayPolicy() return _backOffDelayPolicy; } - public IRecoveryConfiguration Topology(bool activated) { _topology = activated; @@ -309,3 +326,44 @@ public override string ToString() return $"BackOffDelayPolicy{{ Attempt={_attempt}, TotalAttempt={_totalAttempt}, IsActive={IsActive} }}"; } } + +public class TlsSettings : ITlsSettings +{ + internal const SslProtocols DefaultSslProtocols = SslProtocols.None; + + private readonly SslProtocols _protocols; + private readonly X509CertificateCollection _clientCertificates; + private readonly bool _checkCertificateRevocation = false; + private readonly RemoteCertificateValidationCallback? _remoteCertificateValidationCallback; + private readonly LocalCertificateSelectionCallback? _localCertificateSelectionCallback; + + public TlsSettings() : this(DefaultSslProtocols) + { + } + + public TlsSettings(SslProtocols protocols) + { + _protocols = protocols; + _clientCertificates = new X509CertificateCollection(); + _remoteCertificateValidationCallback = trustEverythingCertValidationCallback; + _localCertificateSelectionCallback = null; + } + + public SslProtocols Protocols => _protocols; + + public X509CertificateCollection ClientCertificates => _clientCertificates; + + public bool CheckCertificateRevocation => _checkCertificateRevocation; + + public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback + => _remoteCertificateValidationCallback; + + public LocalCertificateSelectionCallback? LocalCertificateSelectionCallback + => _localCertificateSelectionCallback; + + private static bool trustEverythingCertValidationCallback(object sender, X509Certificate? certificate, + X509Chain? chain, SslPolicyErrors sslPolicyErrors) + { + return true; + } +} diff --git a/Tests/ConnectionTests.cs b/Tests/ConnectionTests.cs index ece5de3..61b3978 100644 --- a/Tests/ConnectionTests.cs +++ b/Tests/ConnectionTests.cs @@ -1,5 +1,4 @@ -using System.Net.Sockets; -using RabbitMQ.AMQP.Client; +using RabbitMQ.AMQP.Client; using RabbitMQ.AMQP.Client.Impl; namespace Tests; @@ -14,12 +13,12 @@ public void ValidateAddress() { ConnectionSettings connectionSettings = new("localhost", 5672, "guest-user", "guest-password", "vhost_1", "amqp1", "connection_name"); - Assert.Equal("localhost", connectionSettings.Host()); - Assert.Equal(5672, connectionSettings.Port()); - Assert.Equal("guest-user", connectionSettings.User()); - Assert.Equal("guest-password", connectionSettings.Password()); - Assert.Equal("vhost_1", connectionSettings.VirtualHost()); - Assert.Equal("amqp1", connectionSettings.Scheme()); + Assert.Equal("localhost", connectionSettings.Host); + Assert.Equal(5672, connectionSettings.Port); + Assert.Equal("guest-user", connectionSettings.User); + Assert.Equal("guest-password", connectionSettings.Password); + Assert.Equal("vhost_1", connectionSettings.VirtualHost); + Assert.Equal("amqp1", connectionSettings.Scheme); ConnectionSettings second = new("localhost", 5672, "guest-user", "guest-password", "path/", "amqp1", "connection_name"); @@ -35,21 +34,62 @@ public void ValidateAddress() [Fact] public void ValidateAddressBuilder() { - var address = ConnectionSettingBuilder.Create() + ConnectionSettings connectionSettings = ConnectionSettingBuilder.Create() .Host("localhost") - .Port(5672) .VirtualHost("v1") .User("guest-t") .Password("guest-w") - .Scheme("amqp1") + .Scheme("AMQP") .Build(); - Assert.Equal("localhost", address.Host()); - Assert.Equal(5672, address.Port()); - Assert.Equal("guest-t", address.User()); - Assert.Equal("guest-w", address.Password()); - Assert.Equal("v1", address.VirtualHost()); - Assert.Equal("amqp1", address.Scheme()); + Assert.Equal("localhost", connectionSettings.Host); + Assert.Equal(5672, connectionSettings.Port); + Assert.Equal("guest-t", connectionSettings.User); + Assert.Equal("guest-w", connectionSettings.Password); + Assert.Equal("v1", connectionSettings.VirtualHost); + Assert.Equal("AMQP", connectionSettings.Scheme); + } + + [Fact] + public void ValidateBuilderWithSslOptions() + { + ConnectionSettings connectionSettings = ConnectionSettingBuilder.Create() + .Host("localhost") + .VirtualHost("v1") + .User("guest-t") + .Password("guest-w") + .Scheme("amqps") + .Build(); + + Assert.True(connectionSettings.UseSsl); + Assert.Equal("localhost", connectionSettings.Host); + Assert.Equal(5671, connectionSettings.Port); + Assert.Equal("guest-t", connectionSettings.User); + Assert.Equal("guest-w", connectionSettings.Password); + Assert.Equal("v1", connectionSettings.VirtualHost); + Assert.Equal("amqps", connectionSettings.Scheme); + } + + [Fact] + public async Task ConnectUsingTlsAndUserPassword() + { + ConnectionSettings connectionSettings = ConnectionSettingBuilder.Create() + .Host("localhost") + .Scheme("amqps") + .Build(); + + Assert.True(connectionSettings.UseSsl); + Assert.Equal("localhost", connectionSettings.Host); + Assert.Equal(5671, connectionSettings.Port); + Assert.Equal("guest", connectionSettings.User); + Assert.Equal("guest", connectionSettings.Password); + Assert.Equal("/", connectionSettings.VirtualHost); + Assert.Equal("amqps", connectionSettings.Scheme); + + IConnection connection = await AmqpConnection.CreateAsync(connectionSettings); + Assert.Equal(State.Open, connection.State); + await connection.CloseAsync(); + Assert.Equal(State.Closed, connection.State); } [Fact] @@ -58,7 +98,8 @@ public async Task RaiseErrorsIfTheParametersAreNotValid() await Assert.ThrowsAsync(async () => await AmqpConnection.CreateAsync(ConnectionSettingBuilder.Create().VirtualHost("wrong_vhost").Build())); - await Assert.ThrowsAnyAsync(async () => + // TODO check inner exception is a SocketException + await Assert.ThrowsAnyAsync(async () => await AmqpConnection.CreateAsync(ConnectionSettingBuilder.Create().Host("wrong_host").Build())); await Assert.ThrowsAsync(async () => @@ -67,17 +108,18 @@ await Assert.ThrowsAsync(async () => await Assert.ThrowsAsync(async () => await AmqpConnection.CreateAsync(ConnectionSettingBuilder.Create().User("wrong_user").Build())); - await Assert.ThrowsAnyAsync(async () => + // TODO check inner exception is a SocketException + await Assert.ThrowsAnyAsync(async () => await AmqpConnection.CreateAsync(ConnectionSettingBuilder.Create().Port(1234).Build())); } [Fact] public async Task ThrowAmqpClosedExceptionWhenItemIsClosed() { - var connection = await AmqpConnection.CreateAsync(ConnectionSettingBuilder.Create().Build()); - var management = connection.Management(); + IConnection connection = await AmqpConnection.CreateAsync(ConnectionSettingBuilder.Create().Build()); + IManagement management = connection.Management(); await management.Queue().Name("ThrowAmqpClosedExceptionWhenItemIsClosed").Declare(); - var publisher = connection.PublisherBuilder().Queue("ThrowAmqpClosedExceptionWhenItemIsClosed").Build(); + IPublisher publisher = connection.PublisherBuilder().Queue("ThrowAmqpClosedExceptionWhenItemIsClosed").Build(); await publisher.CloseAsync(); await Assert.ThrowsAsync(async () => await publisher.Publish(new AmqpMessage("Hello wold!"), (message, descriptor) =>