diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 3b565f1b71..8731ea2071 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -743,7 +743,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { @usableFromInline internal var _channelOptions: ChannelOptions.Storage private var connectTimeout: TimeAmount = TimeAmount.seconds(10) - private var resolver: Optional + private var resolver: Optional private var bindTarget: Optional private var enableMPTCP: Bool @@ -838,12 +838,22 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { return self } + /// Specifies the `NIOStreamingResolver` to use or `nil` if the default should be used. + /// + /// - parameters: + /// - resolver: The resolver that will be used during the connection attempt. + public func resolver(_ resolver: NIOStreamingResolver?) -> Self { + self.resolver = resolver + return self + } + /// Specifies the `Resolver` to use or `nil` if the default should be used. /// /// - parameters: /// - resolver: The resolver that will be used during the connection attempt. + @available(*, deprecated) public func resolver(_ resolver: Resolver?) -> Self { - self.resolver = resolver + self.resolver = resolver.map(NIOResolverToStreamingResolver.init(resolver:)) return self } @@ -897,9 +907,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// - returns: An `EventLoopFuture` to deliver the `Channel` when connected. public func connect(host: String, port: Int) -> EventLoopFuture { let loop = self.group.next() - let resolver = self.resolver ?? GetaddrinfoResolver(loop: loop, - aiSocktype: .stream, - aiProtocol: .tcp) + let resolver = self.resolver ?? GetaddrinfoResolver(aiSocktype: .stream, aiProtocol: .tcp) let connector = HappyEyeballsConnector(resolver: resolver, loop: loop, host: host, @@ -1226,7 +1234,6 @@ extension ClientBootstrap { postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture ) async throws -> PostRegistrationTransformationResult { let resolver = self.resolver ?? GetaddrinfoResolver( - loop: eventLoop, aiSocktype: .stream, aiProtocol: .tcp ) diff --git a/Sources/NIOPosix/GetaddrinfoResolver.swift b/Sources/NIOPosix/GetaddrinfoResolver.swift index 633e91d0f2..caab775ad7 100644 --- a/Sources/NIOPosix/GetaddrinfoResolver.swift +++ b/Sources/NIOPosix/GetaddrinfoResolver.swift @@ -48,54 +48,46 @@ import struct WinSDK.SOCKADDR_IN6 let offloadQueueTSV = ThreadSpecificVariable() -internal class GetaddrinfoResolver: Resolver { - private let v4Future: EventLoopPromise<[SocketAddress]> - private let v6Future: EventLoopPromise<[SocketAddress]> +internal class GetaddrinfoResolver: NIOStreamingResolver { private let aiSocktype: NIOBSDSocket.SocketType private let aiProtocol: NIOBSDSocket.OptionLevel /// Create a new resolver. /// /// - parameters: - /// - loop: The `EventLoop` whose thread this resolver will block. /// - aiSocktype: The sock type to use as hint when calling getaddrinfo. /// - aiProtocol: the protocol to use as hint when calling getaddrinfo. - init(loop: EventLoop, aiSocktype: NIOBSDSocket.SocketType, - aiProtocol: NIOBSDSocket.OptionLevel) { - self.v4Future = loop.makePromise() - self.v6Future = loop.makePromise() + init(aiSocktype: NIOBSDSocket.SocketType, aiProtocol: NIOBSDSocket.OptionLevel) { self.aiSocktype = aiSocktype self.aiProtocol = aiProtocol } - /// Initiate a DNS A query for a given host. - /// - /// Due to the nature of `getaddrinfo`, we only actually call the function once, in the AAAA query. - /// That means this just returns the future for the A results, which in practice will always have been - /// satisfied by the time this function is called. + /// Start a name resolution for a given name. /// /// - parameters: - /// - host: The hostname to do an A lookup on. - /// - port: The port we'll be connecting to. - /// - returns: An `EventLoopFuture` that fires with the result of the lookup. - func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { - return v4Future.futureResult + /// - name: The name to resolve. + /// - destinationPort: The port we'll be connecting to. + /// - session: The resolution session object associated with the resolution. + func resolve(name: String, destinationPort: Int, session: NIONameResolutionSession) { + self.offloadQueue().async { + self.resolveBlocking(host: name, port: destinationPort, session: session) + } } - /// Initiate a DNS AAAA query for a given host. + /// Cancel an outstanding name resolution. /// - /// Due to the nature of `getaddrinfo`, we only actually call the function once, in this function. - /// That means this function call actually blocks: sorry! + /// This method is called whenever a name resolution that hasn't completed no longer has its + /// results needed. The resolver should, if possible, abort any outstanding queries and clean + /// up their state. + /// + /// This method is not guaranteed to terminate the outstanding queries. + /// + /// In the getaddrinfo case this is a no-op, as the resolver blocks. /// /// - parameters: - /// - host: The hostname to do an AAAA lookup on. - /// - port: The port we'll be connecting to. - /// - returns: An `EventLoopFuture` that fires with the result of the lookup. - func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { - self.offloadQueue().async { - self.resolveBlocking(host: host, port: port) - } - return v6Future.futureResult + /// - session: The resolution session object associated with the resolution. + func cancel(_ session: NIONameResolutionSession) { + return } private func offloadQueue() -> DispatchQueue { @@ -113,21 +105,13 @@ internal class GetaddrinfoResolver: Resolver { } } - /// Cancel all outstanding DNS queries. - /// - /// This method is called whenever queries that have not completed no longer have their - /// results needed. The resolver should, if possible, abort any outstanding queries and - /// clean up their state. - /// - /// In the getaddrinfo case this is a no-op, as the resolver blocks. - func cancelQueries() { } - /// Perform the DNS queries and record the result. /// /// - parameters: /// - host: The hostname to do the DNS queries on. /// - port: The port we'll be connecting to. - private func resolveBlocking(host: String, port: Int) { + /// - session: The resolution session object associated with the resolution. + private func resolveBlocking(host: String, port: Int, session: NIONameResolutionSession) { #if os(Windows) host.withCString(encodedAs: UTF16.self) { wszHost in String(port).withCString(encodedAs: UTF16.self) { wszPort in @@ -139,15 +123,15 @@ internal class GetaddrinfoResolver: Resolver { let iResult = GetAddrInfoW(wszHost, wszPort, &aiHints, &pResult) guard iResult == 0 else { - self.fail(SocketAddressError.unknown(host: host, port: port)) + self.fail(SocketAddressError.unknown(host: host, port: port), session: session) return } if let pResult = pResult { - self.parseAndPublishResults(pResult, host: host) + self.parseAndPublishResults(pResult, host: host, session: session) FreeAddrInfoW(pResult) } else { - self.fail(SocketAddressError.unsupported) + self.fail(SocketAddressError.unsupported, session: session) } } } @@ -158,16 +142,16 @@ internal class GetaddrinfoResolver: Resolver { hint.ai_socktype = self.aiSocktype.rawValue hint.ai_protocol = self.aiProtocol.rawValue guard getaddrinfo(host, String(port), &hint, &info) == 0 else { - self.fail(SocketAddressError.unknown(host: host, port: port)) + self.fail(SocketAddressError.unknown(host: host, port: port), session: session) return } if let info = info { - self.parseAndPublishResults(info, host: host) + self.parseAndPublishResults(info, host: host, session: session) freeaddrinfo(info) } else { /* this is odd, getaddrinfo returned NULL */ - self.fail(SocketAddressError.unsupported) + self.fail(SocketAddressError.unsupported, session: session) } #endif } @@ -177,15 +161,15 @@ internal class GetaddrinfoResolver: Resolver { /// - parameters: /// - info: The pointer to the first of the `addrinfo` structures in the list. /// - host: The hostname we resolved. + /// - session: The resolution session object associated with the resolution. #if os(Windows) internal typealias CAddrInfo = ADDRINFOW #else internal typealias CAddrInfo = addrinfo #endif - private func parseAndPublishResults(_ info: UnsafeMutablePointer, host: String) { - var v4Results: [SocketAddress] = [] - var v6Results: [SocketAddress] = [] + private func parseAndPublishResults(_ info: UnsafeMutablePointer, host: String, session: NIONameResolutionSession) { + var results: [SocketAddress] = [] var info: UnsafeMutablePointer = info while true { @@ -193,12 +177,12 @@ internal class GetaddrinfoResolver: Resolver { switch NIOBSDSocket.AddressFamily(rawValue: info.pointee.ai_family) { case .inet: // Force-unwrap must be safe, or libc did the wrong thing. - v4Results.append(.init(addressBytes!.load(as: sockaddr_in.self), host: host)) + results.append(.init(addressBytes!.load(as: sockaddr_in.self), host: host)) case .inet6: // Force-unwrap must be safe, or libc did the wrong thing. - v6Results.append(.init(addressBytes!.load(as: sockaddr_in6.self), host: host)) + results.append(.init(addressBytes!.load(as: sockaddr_in6.self), host: host)) default: - self.fail(SocketAddressError.unsupported) + self.fail(SocketAddressError.unsupported, session: session) return } @@ -209,16 +193,16 @@ internal class GetaddrinfoResolver: Resolver { info = nextInfo } - v6Future.succeed(v6Results) - v4Future.succeed(v4Results) + session.deliverResults(results) + session.resolutionComplete(.success(())) } /// Record an error and fail the lookup process. /// /// - parameters: /// - error: The error encountered during lookup. - private func fail(_ error: Error) { - self.v6Future.fail(error) - self.v4Future.fail(error) + /// - session: The resolution session object associated with the resolution. + private func fail(_ error: Error, session: NIONameResolutionSession) { + session.resolutionComplete(.failure(error)) } } diff --git a/Sources/NIOPosix/HappyEyeballs.swift b/Sources/NIOPosix/HappyEyeballs.swift index f58435d77f..87f107541a 100644 --- a/Sources/NIOPosix/HappyEyeballs.swift +++ b/Sources/NIOPosix/HappyEyeballs.swift @@ -55,11 +55,16 @@ public struct NIOConnectionError: Error { /// The port SwiftNIO was trying to connect to. public let port: Int + /// The error we encountered doing the name resolution, if any. + public fileprivate(set) var resolutionError: Error? = nil + /// The error we encountered doing the DNS A lookup, if any. - public fileprivate(set) var dnsAError: Error? = nil + @available(*, deprecated, renamed: "resolutionError") + public var dnsAError: Error? { resolutionError } /// The error we encountered doing the DNS AAAA lookup, if any. - public fileprivate(set) var dnsAAAAError: Error? = nil + @available(*, deprecated, renamed: "resolutionError") + public var dnsAAAAError: Error? { resolutionError } /// The errors we encountered during the connection attempts. public fileprivate(set) var connectionErrors: [SingleConnectionFailure] = [] @@ -72,7 +77,7 @@ public struct NIOConnectionError: Error { /// A simple iterator that manages iterating over the possible targets. /// -/// This iterator knows how to merge together the A and AAAA records in a sensible way: +/// This iterator knows how to merge together IPv4 and IPv6 addresses in a sensible way: /// specifically, it keeps track of what the last address family it emitted was, and emits the /// address of the opposite family next. private struct TargetIterator: IteratorProtocol { @@ -84,39 +89,44 @@ private struct TargetIterator: IteratorProtocol { } private var previousAddressFamily: AddressFamily = .v4 - private var aQueryResults: [SocketAddress] = [] - private var aaaaQueryResults: [SocketAddress] = [] - - mutating func aResultsAvailable(_ results: [SocketAddress]) { - aQueryResults.append(contentsOf: results) - } - - mutating func aaaaResultsAvailable(_ results: [SocketAddress]) { - aaaaQueryResults.append(contentsOf: results) + private var v4Results: [SocketAddress] = [] + private var v6Results: [SocketAddress] = [] + + mutating func resultsAvailable(_ results: [SocketAddress]) { + for result in results { + switch result.protocol { + case .inet: + v4Results.append(result) + case .inet6: + v6Results.append(result) + default: + break + } + } } mutating func next() -> Element? { switch previousAddressFamily { case .v4: - return popAAAA() ?? popA() + return popV6() ?? popV4() case .v6: - return popA() ?? popAAAA() + return popV4() ?? popV6() } } - private mutating func popA() -> SocketAddress? { - if aQueryResults.count > 0 { + private mutating func popV4() -> SocketAddress? { + if v4Results.count > 0 { previousAddressFamily = .v4 - return aQueryResults.removeFirst() + return v4Results.removeFirst() } return nil } - private mutating func popAAAA() -> SocketAddress? { - if aaaaQueryResults.count > 0 { + private mutating func popV6() -> SocketAddress? { + if v6Results.count > 0 { previousAddressFamily = .v6 - return aaaaQueryResults.removeFirst() + return v6Results.removeFirst() } return nil @@ -150,21 +160,17 @@ internal final class HappyEyeballsConnector { /// Initial state. No work outstanding. case idle - /// All name queries are currently outstanding. + /// No results have been returned yet. case resolving - /// The A query has resolved, but the AAAA query is outstanding and the + /// The resolver has returned IPv4 results, but no IPv6 results yet, and the /// resolution delay has not yet elapsed. - case aResolvedWaiting - - /// The A query has resolved and the resolution delay has elapsed. We can - /// begin connecting immediately, but should not give up if we run out of - /// targets until the AAAA result returns. - case aResolvedConnecting + case resolvedWaiting - /// The AAAA query has resolved. We can begin connecting immediately, but - /// should not give up if we run out of targets until the AAAA result returns. - case aaaaResolved + /// The resolver has returned IPv6 results, or the resolution delay has elapsed. + /// We can begin connecting immediately, but should not give up if we run out of + /// targets until the resolver completes. + case resolvedConnecting /// All DNS results are in. We can make connection attempts until we run out /// of targets. @@ -179,13 +185,16 @@ internal final class HappyEyeballsConnector { /// Begin DNS resolution. case resolve - /// The A record lookup completed. - case resolverACompleted + /// The resolver received IPv4 results. + case resolverIPv4ResultsAvailable + + /// The resolver received IPv6 results. + case resolverIPv6ResultsAvailable - /// The AAAA record lookup completed. - case resolverAAAACompleted + /// The name resolution completed. + case resolverCompleted - /// The delay between the A result and the AAAA result has elapsed. + /// The delay between the IPv4 result and the IPv6 result has elapsed. case resolutionDelayElapsed /// The delay between starting one connection and the next has elapsed. @@ -206,7 +215,7 @@ internal final class HappyEyeballsConnector { } /// The DNS resolver provided by the user. - private let resolver: Resolver + private let resolver: NIOStreamingResolver /// The event loop this connector will run on. private let loop: EventLoop @@ -228,12 +237,14 @@ internal final class HappyEyeballsConnector { /// The channel builder callback takes an event loop and a protocol family as arguments. private let channelBuilderCallback: (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)> - /// The amount of time to wait for an AAAA response to come in after a A response is + /// The amount of time to wait for an IPv6 response to come in after a IPv4 response is /// received. By default this is 50ms. private let resolutionDelay: TimeAmount + + private var resolutionSession: Optional /// A reference to the task that will execute after the resolution delay expires, if - /// one is scheduled. This is held to ensure that we can cancel this task if the AAAA + /// one is scheduled. This is held to ensure that we can cancel this task if the IPv6 /// response comes in before the resolution delay expires. private var resolutionTask: Optional> @@ -268,17 +279,11 @@ internal final class HappyEyeballsConnector { /// and throw away all pending connection attempts that are no longer needed. private var pendingConnections: [EventLoopFuture<(Channel, ChannelBuilderResult)>] = [] - /// The number of DNS resolutions that have returned. - /// - /// This is used to keep track of whether we need to cancel the outstanding resolutions - /// during cleanup. - private var dnsResolutions: Int = 0 - /// An object that holds any errors we encountered. private var error: NIOConnectionError @inlinable - init(resolver: Resolver, + init(resolver: NIOStreamingResolver, loop: EventLoop, host: String, port: Int, @@ -292,6 +297,7 @@ internal final class HappyEyeballsConnector { self.port = port self.connectTimeout = connectTimeout self.channelBuilderCallback = channelBuilderCallback + self.resolutionSession = nil self.resolutionTask = nil self.connectionTask = nil self.timeoutTask = nil @@ -309,7 +315,7 @@ internal final class HappyEyeballsConnector { @inlinable convenience init( - resolver: Resolver, + resolver: NIOStreamingResolver, loop: EventLoop, host: String, port: Int, @@ -360,71 +366,64 @@ internal final class HappyEyeballsConnector { // Only one valid transition from idle: to start resolving. case (.idle, .resolve): state = .resolving - beginDNSResolution() + beginResolutionSession() - // In the resolving state, we can exit three ways: either the A query returns, - // the AAAA does, or the overall connect timeout fires. - case (.resolving, .resolverACompleted): - state = .aResolvedWaiting + // In the resolving state, we can exit four ways: either IPv4 results are available, IPv6 + // results are available, the resoler completes, or the overall connect timeout fires. + case (.resolving, .resolverIPv4ResultsAvailable): + state = .resolvedWaiting beginResolutionDelay() - case (.resolving, .resolverAAAACompleted): - state = .aaaaResolved + case (.resolving, .resolverIPv6ResultsAvailable): + state = .resolvedConnecting + beginConnecting() + case (.resolving, .resolverCompleted): + state = .allResolved beginConnecting() case (.resolving, .connectTimeoutElapsed): state = .complete timedOut() - // In the aResolvedWaiting state, we can exit three ways: the AAAA query returns, - // the resolution delay elapses, or the overall connect timeout fires. - case (.aResolvedWaiting, .resolverAAAACompleted): + // In the resolvedWaiting state, a number of inputs are valid: More IPv4 results can be + // available, IPv6 results can be available, the resolver can complete, the resolution + // delay can elapse, and the overall connect timeout can fire. + case (.resolvedWaiting, .resolverIPv4ResultsAvailable): + break + case (.resolvedWaiting, .resolverIPv6ResultsAvailable): + state = .resolvedConnecting + beginConnecting() + case (.resolvedWaiting, .resolverCompleted): state = .allResolved beginConnecting() - case (.aResolvedWaiting, .resolutionDelayElapsed): - state = .aResolvedConnecting + case (.resolvedWaiting, .resolutionDelayElapsed): + state = .resolvedConnecting beginConnecting() - case (.aResolvedWaiting, .connectTimeoutElapsed): + case (.resolvedWaiting, .connectTimeoutElapsed): state = .complete timedOut() - // In the aResolvedConnecting state, a number of inputs are valid: the AAAA result can - // return, the connectionDelay can elapse, the overall connection timeout can fire, - // a connection can succeed, a connection can fail, and we can run out of targets. - case (.aResolvedConnecting, .resolverAAAACompleted): - state = .allResolved + // In the resolvedConnecting state, a number of inputs are valid: More IPv4 or IPv6 results + // can be available, the resolver can complete, the connectionDelay can elapse, the overall + // connection timeout can fire, a connection can succeed, a connection can fail, and we can + // run out of targets. + case (.resolvedConnecting, .resolverIPv4ResultsAvailable): connectToNewTargets() - case (.aResolvedConnecting, .connectDelayElapsed): - connectionDelayElapsed() - case (.aResolvedConnecting, .connectTimeoutElapsed): - state = .complete - timedOut() - case (.aResolvedConnecting, .connectSuccess): - state = .complete - connectSuccess() - case (.aResolvedConnecting, .connectFailed): - connectFailed() - case (.aResolvedConnecting, .noTargetsRemaining): - // We are still waiting for the AAAA query, so we - // do nothing. - break - - // In the aaaaResolved state, a number of inputs are valid: the A result can return, - // the connectionDelay can elapse, the overall connection timeout can fire, a connection - // can succeed, a connection can fail, and we can run out of targets. - case (.aaaaResolved, .resolverACompleted): + case (.resolvedConnecting, .resolverIPv6ResultsAvailable): + connectToNewTargets() + case (.resolvedConnecting, .resolverCompleted): state = .allResolved connectToNewTargets() - case (.aaaaResolved, .connectDelayElapsed): + case (.resolvedConnecting, .connectDelayElapsed): connectionDelayElapsed() - case (.aaaaResolved, .connectTimeoutElapsed): + case (.resolvedConnecting, .connectTimeoutElapsed): state = .complete timedOut() - case (.aaaaResolved, .connectSuccess): + case (.resolvedConnecting, .connectSuccess): state = .complete connectSuccess() - case (.aaaaResolved, .connectFailed): + case (.resolvedConnecting, .connectFailed): connectFailed() - case (.aaaaResolved, .noTargetsRemaining): - // We are still waiting for the A query, so we + case (.resolvedConnecting, .noTargetsRemaining): + // We are still waiting for the IPv6 results, so we // do nothing. break @@ -450,8 +449,9 @@ internal final class HappyEyeballsConnector { // notifications, and can also get late scheduled task callbacks. We want to just quietly // ignore these, as our transition into the complete state should have already sent // cleanup messages to all of these things. - case (.complete, .resolverACompleted), - (.complete, .resolverAAAACompleted), + case (.complete, .resolverIPv4ResultsAvailable), + (.complete, .resolverIPv6ResultsAvailable), + (.complete, .resolverCompleted), (.complete, .connectSuccess), (.complete, .connectFailed), (.complete, .connectDelayElapsed), @@ -464,26 +464,35 @@ internal final class HappyEyeballsConnector { } /// Fire off a pair of DNS queries. - private func beginDNSResolution() { - // Per RFC 8305 Section 3, we need to send A and AAAA queries. - // The two queries SHOULD be made as soon after one another as possible, - // with the AAAA query made first and immediately followed by the A - // query. - // - // We hop back to `self.loop` because there's no guarantee the resolver runs - // on our event loop. - let aaaaLookup = self.resolver.initiateAAAAQuery(host: self.host, port: self.port).hop(to: self.loop) - self.whenAAAALookupComplete(future: aaaaLookup) - - let aLookup = self.resolver.initiateAQuery(host: self.host, port: self.port).hop(to: self.loop) - self.whenALookupComplete(future: aLookup) + private func beginResolutionSession() { + let resolutionSession = NIONameResolutionSession( + resultsHandler: { [self] results in + if self.loop.inEventLoop { + self.resolverDeliverResults(results) + } else { + self.loop.execute { + self.resolverDeliverResults(results) + } + } + }, completionHandler: { result in + if self.loop.inEventLoop { + self.resolutionComplete(result: result) + } else { + self.loop.execute { + self.resolutionComplete(result: result) + } + } + }, cancelledBy: self.resolutionPromise.futureResult + ) + self.resolutionSession = resolutionSession + resolver.resolve(name: self.host, destinationPort: self.port, session: resolutionSession) } - - /// Called when the A query has completed before the AAAA query. + + /// Called when IPv4 results are available before IPv6 results. /// /// Happy Eyeballs 2 prefers to connect over IPv6 if it's possible to do so. This means that - /// if the A lookup completes first we want to wait a small amount of time before we begin our - /// connection attempts, in the hope that the AAAA lookup will complete. + /// if only IPv4 results are available we want to wait a small amount of time before we begin + /// our connection attempts, in the hope that IPv6 results will be returned. /// /// This method sets off a scheduled task for the resolution delay. private func beginResolutionDelay() { @@ -621,8 +630,9 @@ internal final class HappyEyeballsConnector { private func cleanUp() { assert(self.state == .complete, "Clean up in invalid state \(self.state)") - if dnsResolutions < 2 { - resolver.cancelQueries() + if let resolutionSession = self.resolutionSession { + self.resolver.cancel(resolutionSession) + self.resolutionSession = nil } if let resolutionTask = self.resolutionTask { @@ -646,35 +656,32 @@ internal final class HappyEyeballsConnector { connection.whenSuccess { (channel, _) in channel.close(promise: nil) } } } + + private func resolverDeliverResults(_ results: [SocketAddress]) { + self.targets.resultsAvailable(results) + + let protocols = results.map(\.protocol) + if protocols.contains(.inet6) { + self.resolutionTask?.cancel() + self.resolutionTask = nil - /// A future callback that fires when a DNS A lookup completes. - private func whenALookupComplete(future: EventLoopFuture<[SocketAddress]>) { - future.map { results in - self.targets.aResultsAvailable(results) - }.recover { err in - self.error.dnsAError = err - }.whenComplete { (_: Result) in - self.dnsResolutions += 1 - self.processInput(.resolverACompleted) + self.processInput(.resolverIPv6ResultsAvailable) + } else if protocols.contains(.inet) { + self.processInput(.resolverIPv4ResultsAvailable) } } + + private func resolutionComplete(result: Result) { + if case .failure(let error) = result { + self.error.resolutionError = error + } - /// A future callback that fires when a DNS AAAA lookup completes. - private func whenAAAALookupComplete(future: EventLoopFuture<[SocketAddress]>) { - future.map { results in - self.targets.aaaaResultsAvailable(results) - }.recover { err in - self.error.dnsAAAAError = err - }.whenComplete { (_: Result) in - // It's possible that we were waiting to time out here, so if we were we should - // cancel that. - self.resolutionTask?.cancel() - self.resolutionTask = nil + self.resolutionSession = nil - self.dnsResolutions += 1 + self.resolutionTask?.cancel() + self.resolutionTask = nil - self.processInput(.resolverAAAACompleted) - } + self.processInput(.resolverCompleted) } /// A future callback that fires when the resolution delay completes. diff --git a/Sources/NIOPosix/NIOStreamingResolver.swift b/Sources/NIOPosix/NIOStreamingResolver.swift new file mode 100644 index 0000000000..9266024ded --- /dev/null +++ b/Sources/NIOPosix/NIOStreamingResolver.swift @@ -0,0 +1,240 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOConcurrencyHelpers + +/// A protocol that covers an object that performs name resolution. +/// +/// In general the rules for the resolver are relatively broad: there are no specific requirements on how +/// it operates. However, the rest of the code currently assumes that it obeys RFC 6724, particularly section 6 on +/// ordering returned addresses. That is, when possible, the IPv6 and IPv4 responses should be ordered by the destination +/// address ordering rules from that RFC. This specification is widely implemented by getaddrinfo +/// implementations, so any implementation based on getaddrinfo will work just fine. Other implementations +/// may need also to implement these sorting rules for the moment. +public protocol NIOStreamingResolver { + /// Start a name resolution for a given name. + /// + /// - parameters: + /// - name: The name to resolve. + /// - destinationPort: The port we'll be connecting to. + /// - session: The resolution session object associated with the resolution. + func resolve(name: String, destinationPort: Int, session: NIONameResolutionSession) + + /// Cancel an outstanding name resolution. + /// + /// This method is called whenever a name resolution that hasn't completed no longer has its + /// results needed. The resolver should, if possible, abort any outstanding queries and clean + /// up their state. + /// + /// This method is not guaranteed to terminate the outstanding queries. + /// + /// - parameters: + /// - session: The resolution session object associated with the resolution. + func cancel(_ session: NIONameResolutionSession) +} + +/// An object used by a resolver to deliver results and inform about the completion of a name +/// resolution. +/// +/// A resolution session object is associated with a single name resolution. +public final class NIONameResolutionSession: @unchecked Sendable { + private let lock: NIOLock + private var resultsHandler: ResultsHandler? + private var completionHandler: CompletionHandler? + + /// Create a `NIONameResolutionSession`. + /// + /// The `resultsHandler` and `completionHandler` closures are retained by the resolution session + /// until one of the following happens: + /// - All references to the resolution session are dropped. + /// - The `resolutionComplete` method is called. + /// - The `cancelledBy` future completes. + /// + /// - parameters: + /// - resultsHandler: A closure that will be fired when new results are delivered. + /// - completionHandler: A close that will be fired when the name resolution completes. + /// - cancelledBy: A future that will be completed when the resolution session is no longer + /// needed. + public init( + resultsHandler: @Sendable @escaping ([SocketAddress]) -> Void, + completionHandler: @Sendable @escaping (Result) -> Void, + cancelledBy: EventLoopFuture + ) { + self.lock = NIOLock() + self.resultsHandler = resultsHandler + self.completionHandler = completionHandler + cancelledBy.whenComplete { _ in + _ = self.invalidateAndReturnCompletionHandler() + } + } + + /// Create a `NIONameResolutionSession`. + /// + /// The `resultsHandler` and `completionHandler` closures are retained by the resolution session + /// until one of the following happens: + /// - All references to the resolution session are dropped. + /// - The `resolutionComplete` method is called. + /// + /// - parameters: + /// - resultsHandler: A closure that will be fired when new results are delivered. + /// - completionHandler: A close that will be fired when the name resolution completes. + public init( + resultsHandler: @Sendable @escaping ([SocketAddress]) -> Void, + completionHandler: @Sendable @escaping (Result) -> Void + ) { + self.lock = NIOLock() + self.resultsHandler = resultsHandler + self.completionHandler = completionHandler + } + + /// Deliver results for a name resolution. + /// + /// This method may be called any number of times until the name resolution completes. Calling + /// this method with an empty array is allowed. + /// + /// - parameters: + /// - results: Zero or more socket addresses. + public func deliverResults(_ results: [SocketAddress]) { + let resultsHandler = self.lock.withLock { self.resultsHandler } + resultsHandler?(results) + } + + /// Signal the completion of a name resolution. + /// + /// Calling this method invalidates the resolution session. That is, all handlers are released + /// and any future call to `deliverResults` or `resolutionComplete` will be silently ignored. + /// + /// - parameters: + /// - result: A `Result` value indicating whether the name resolution was successful or not. + public func resolutionComplete(_ result: Result) { + let completionHandler = self.invalidateAndReturnCompletionHandler() + completionHandler?(result) + } + + private func invalidateAndReturnCompletionHandler() -> CompletionHandler? { + self.lock.withLock { + let completionHandler = self.completionHandler + self.completionHandler = nil + self.resultsHandler = nil + return completionHandler + } + } + + private typealias ResultsHandler = ([SocketAddress]) -> Void + private typealias CompletionHandler = (Result) -> Void +} + +extension NIOStreamingResolver { + /// Start a non-cancellable name resolution that delivers all its result on completion. + /// + /// Results are accumulated until the name resolution completes. + /// + /// - parameters: + /// - name: The name to resolve. + /// - destinationPort: The port we'll be connecting to. + /// - eventLoop: The session associated with the resolution. + /// - returns: A future that will receive the name resolution results. + public func resolve( + name: String, destinationPort: Int, on eventLoop: EventLoop + ) -> EventLoopFuture<[SocketAddress]> { + let box = NIOLockedValueBox([] as [SocketAddress]) + let promise = eventLoop.makePromise(of: [SocketAddress].self) + + let session = NIONameResolutionSession { results in + box.withLockedValue { + $0.append(contentsOf: results) + } + } completionHandler: { result in + switch result { + case .success: + let results = box.withLockedValue({ $0 }) + promise.succeed(results) + case .failure(let error): + promise.fail(error) + } + } + + resolve(name: name, destinationPort: destinationPort, session: session) + return promise.futureResult + } +} + +/// A protocol that covers an object that does DNS lookups. +/// +/// In general the rules for the resolver are relatively broad: there are no specific requirements on how +/// it operates. However, the rest of the code assumes that it obeys RFC 6724, particularly section 6 on +/// ordering returned addresses. That is, the IPv6 and IPv4 responses should be ordered by the destination +/// address ordering rules from that RFC. This specification is widely implemented by getaddrinfo +/// implementations, so any implementation based on getaddrinfo will work just fine. In the future, a custom +/// resolver will need also to implement these sorting rules. +@available(*, deprecated, message: "Use NIOStreamingResoler instead.") +public protocol Resolver { + /// Initiate a DNS A query for a given host. + /// + /// - parameters: + /// - host: The hostname to do an A lookup on. + /// - port: The port we'll be connecting to. + /// - returns: An `EventLoopFuture` that fires with the result of the lookup. + func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> + + /// Initiate a DNS AAAA query for a given host. + /// + /// - parameters: + /// - host: The hostname to do an AAAA lookup on. + /// - port: The port we'll be connecting to. + /// - returns: An `EventLoopFuture` that fires with the result of the lookup. + func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> + + /// Cancel all outstanding DNS queries. + /// + /// This method is called whenever queries that have not completed no longer have their + /// results needed. The resolver should, if possible, abort any outstanding queries and + /// clean up their state. + /// + /// This method is not guaranteed to terminate the outstanding queries. + func cancelQueries() +} + +@available(*, deprecated) +internal struct NIOResolverToStreamingResolver: NIOStreamingResolver { + var resolver: Resolver + + func resolve(name: String, destinationPort: Int, session: NIONameResolutionSession) { + func deliverResults(_ results: [SocketAddress]) { + if !results.isEmpty { + session.deliverResults(results) + } + } + + let futures = [ + resolver.initiateAAAAQuery(host: name, port: destinationPort).map(deliverResults), + resolver.initiateAQuery(host: name, port: destinationPort).map(deliverResults), + ] + let eventLoop = futures[0].eventLoop + + EventLoopFuture.whenAllComplete(futures, on: eventLoop).whenSuccess { results in + for result in results { + if case .failure = result { + return session.resolutionComplete(result) + } + } + session.resolutionComplete(.success(())) + } + } + + func cancel(_ session: NIONameResolutionSession) { + resolver.cancelQueries() + } +} diff --git a/Sources/NIOPosix/Resolver.swift b/Sources/NIOPosix/Resolver.swift deleted file mode 100644 index 9c89c82cd1..0000000000 --- a/Sources/NIOPosix/Resolver.swift +++ /dev/null @@ -1,50 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import NIOCore - -/// A protocol that covers an object that does DNS lookups. -/// -/// In general the rules for the resolver are relatively broad: there are no specific requirements on how -/// it operates. However, the rest of the code assumes that it obeys RFC 6724, particularly section 6 on -/// ordering returned addresses. That is, the IPv6 and IPv4 responses should be ordered by the destination -/// address ordering rules from that RFC. This specification is widely implemented by getaddrinfo -/// implementations, so any implementation based on getaddrinfo will work just fine. In the future, a custom -/// resolver will need also to implement these sorting rules. -public protocol Resolver { - /// Initiate a DNS A query for a given host. - /// - /// - parameters: - /// - host: The hostname to do an A lookup on. - /// - port: The port we'll be connecting to. - /// - returns: An `EventLoopFuture` that fires with the result of the lookup. - func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> - - /// Initiate a DNS AAAA query for a given host. - /// - /// - parameters: - /// - host: The hostname to do an AAAA lookup on. - /// - port: The port we'll be connecting to. - /// - returns: An `EventLoopFuture` that fires with the result of the lookup. - func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> - - /// Cancel all outstanding DNS queries. - /// - /// This method is called whenever queries that have not completed no longer have their - /// results needed. The resolver should, if possible, abort any outstanding queries and - /// clean up their state. - /// - /// This method is not guaranteed to terminate the outstanding queries. - func cancelQueries() -} diff --git a/Tests/NIOPosixTests/BootstrapTest.swift b/Tests/NIOPosixTests/BootstrapTest.swift index 4fe103baee..211b39ac08 100644 --- a/Tests/NIOPosixTests/BootstrapTest.swift +++ b/Tests/NIOPosixTests/BootstrapTest.swift @@ -648,9 +648,9 @@ class BootstrapTest: XCTestCase { // Some platforms don't define "localhost" for IPv6, so check that // and use "ip6-localhost" instead. if !isIPv4 { - let hostResolver = GetaddrinfoResolver(loop: group.next(), aiSocktype: .stream, aiProtocol: .tcp) - let hostv6 = try! hostResolver.initiateAAAAQuery(host: "localhost", port: 8088).wait() - if hostv6.isEmpty { + let hostResolver = GetaddrinfoResolver(aiSocktype: .stream, aiProtocol: .tcp) + let hostv6 = try! hostResolver.resolve(name: "localhost", destinationPort: 8088, on: group.next()).wait() + if !hostv6.map(\.protocol).contains(.inet6) { localhost = "ip6-localhost" } } diff --git a/Tests/NIOPosixTests/GetAddrInfoResolverTest.swift b/Tests/NIOPosixTests/GetAddrInfoResolverTest.swift index 19c47763eb..bfb2fc9cf4 100644 --- a/Tests/NIOPosixTests/GetAddrInfoResolverTest.swift +++ b/Tests/NIOPosixTests/GetAddrInfoResolverTest.swift @@ -24,15 +24,12 @@ class GetaddrinfoResolverTest: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - let resolver = GetaddrinfoResolver(loop: group.next(), aiSocktype: .stream, aiProtocol: .tcp) - let v4Future = resolver.initiateAQuery(host: "127.0.0.1", port: 12345) - let v6Future = resolver.initiateAAAAQuery(host: "127.0.0.1", port: 12345) - - let addressV4 = try v4Future.wait() - let addressV6 = try v6Future.wait() - XCTAssertEqual(1, addressV4.count) - XCTAssertEqual(try SocketAddress(ipAddress: "127.0.0.1", port: 12345), addressV4[0]) - XCTAssertTrue(addressV6.isEmpty) + let resolver = GetaddrinfoResolver(aiSocktype: .stream, aiProtocol: .tcp) + let future = resolver.resolve(name: "127.0.0.1", destinationPort: 12345, on: group.next()) + + let results = try future.wait() + XCTAssertEqual(1, results.count) + XCTAssertEqual(try SocketAddress(ipAddress: "127.0.0.1", port: 12345), results[0]) } func testResolveNoDuplicatesV6() throws { @@ -41,14 +38,11 @@ class GetaddrinfoResolverTest: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - let resolver = GetaddrinfoResolver(loop: group.next(), aiSocktype: .stream, aiProtocol: .tcp) - let v4Future = resolver.initiateAQuery(host: "::1", port: 12345) - let v6Future = resolver.initiateAAAAQuery(host: "::1", port: 12345) + let resolver = GetaddrinfoResolver(aiSocktype: .stream, aiProtocol: .tcp) + let future = resolver.resolve(name: "::1", destinationPort: 12345, on: group.next()) - let addressV4 = try v4Future.wait() - let addressV6 = try v6Future.wait() - XCTAssertEqual(1, addressV6.count) - XCTAssertEqual(try SocketAddress(ipAddress: "::1", port: 12345), addressV6[0]) - XCTAssertTrue(addressV4.isEmpty) + let results = try future.wait() + XCTAssertEqual(1, results.count) + XCTAssertEqual(try SocketAddress(ipAddress: "::1", port: 12345), results[0]) } } diff --git a/Tests/NIOPosixTests/HappyEyeballsTest.swift b/Tests/NIOPosixTests/HappyEyeballsTest.swift index 88709d10d4..c9404ad63b 100644 --- a/Tests/NIOPosixTests/HappyEyeballsTest.swift +++ b/Tests/NIOPosixTests/HappyEyeballsTest.swift @@ -236,7 +236,7 @@ private func buildEyeballer( ) -> (eyeballer: HappyEyeballsConnector, resolver: DummyResolver, loop: EmbeddedEventLoop) { let loop = EmbeddedEventLoop() let resolver = DummyResolver(loop: loop) - let eyeballer = HappyEyeballsConnector(resolver: resolver, + let eyeballer = HappyEyeballsConnector(resolver: NIOResolverToStreamingResolver(resolver: resolver), loop: loop, host: host, port: port, @@ -286,12 +286,13 @@ public final class HappyEyeballsTest : XCTestCase { let target = try targetFuture.wait() XCTAssertEqual(target!, "fe80::1") - // We should have had queries for AAAA and A. let expectedQueries: [DummyResolver.Event] = [ .aaaa(host: "example.com", port: 80), .a(host: "example.com", port: 80) ] - XCTAssertEqual(resolver.events, expectedQueries) + // We should have had queries for AAAA and A. We should then have had a cancel, because the + // connection succeeds before the resolver completes. + XCTAssertEqual(resolver.events, expectedQueries + [.cancel]) } func testTimeOutDuringDNSResolution() throws { @@ -430,8 +431,9 @@ public final class HappyEyeballsTest : XCTestCase { let target = try targetFuture.wait() XCTAssertEqual(target!, "fe80::1") - // We should have had queries for AAAA and A, with no cancel. - XCTAssertEqual(resolver.events, expectedQueries) + // We should have had queries for AAAA and A. We should then have had a cancel, because the + // connection succeeds before the resolver completes. + XCTAssertEqual(resolver.events, expectedQueries + [.cancel]) } func testAQueryReturningFirstThenAAAAErrors() throws { @@ -525,8 +527,7 @@ public final class HappyEyeballsTest : XCTestCase { if let error = channelFuture.getError() as? NIOConnectionError { XCTAssertEqual(error.host, "example.com") XCTAssertEqual(error.port, 80) - XCTAssertNil(error.dnsAError) - XCTAssertNil(error.dnsAAAAError) + XCTAssertNil(error.resolutionError) XCTAssertEqual(error.connectionErrors.count, 0) } else { XCTFail("Got \(String(describing: channelFuture.getError()))") @@ -557,8 +558,7 @@ public final class HappyEyeballsTest : XCTestCase { if let error = channelFuture.getError() as? NIOConnectionError { XCTAssertEqual(error.host, "example.com") XCTAssertEqual(error.port, 80) - XCTAssertEqual(error.dnsAError as? DummyError ?? DummyError(), v4Error) - XCTAssertEqual(error.dnsAAAAError as? DummyError ?? DummyError(), v6Error) + XCTAssertEqual(error.resolutionError as? DummyError ?? DummyError(), v6Error) XCTAssertEqual(error.connectionErrors.count, 0) } else { XCTFail("Got \(String(describing: channelFuture.getError()))") @@ -692,8 +692,7 @@ public final class HappyEyeballsTest : XCTestCase { if let error = channelFuture.getError() as? NIOConnectionError { XCTAssertEqual(error.host, "example.com") XCTAssertEqual(error.port, 80) - XCTAssertNil(error.dnsAError) - XCTAssertNil(error.dnsAAAAError) + XCTAssertNil(error.resolutionError) XCTAssertEqual(error.connectionErrors.count, 20) for (idx, error) in error.connectionErrors.enumerated() { @@ -1240,7 +1239,7 @@ public final class HappyEyeballsTest : XCTestCase { // Tests a regression where the happy eyeballs connector would update its state on the event // loop of the future returned by the resolver (which may be different to its own). Prior // to the fix this test would trigger TSAN warnings. - let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } @@ -1254,10 +1253,8 @@ public final class HappyEyeballsTest : XCTestCase { } // Run the resolver and connection on different event loops. - let resolverLoop = group.next() let connectionLoop = group.next() - XCTAssertNotIdentical(resolverLoop, connectionLoop) - let resolver = GetaddrinfoResolver(loop: resolverLoop, aiSocktype: .stream, aiProtocol: .tcp) + let resolver = GetaddrinfoResolver(aiSocktype: .stream, aiProtocol: .tcp) let client = try ClientBootstrap(group: connectionLoop) .resolver(resolver) .connect(host: "localhost", port: server.localAddress!.port!) diff --git a/Tests/NIOPosixTests/TestUtils.swift b/Tests/NIOPosixTests/TestUtils.swift index 0c38f94d8f..10fe885284 100644 --- a/Tests/NIOPosixTests/TestUtils.swift +++ b/Tests/NIOPosixTests/TestUtils.swift @@ -357,16 +357,13 @@ func resolverDebugInformation(eventLoop: EventLoop, host: String, previouslyRece return __testOnly_addressDescription(sa.address) } } - let res = GetaddrinfoResolver(loop: eventLoop, aiSocktype: .stream, aiProtocol: .tcp) - let ipv6Results = try assertNoThrowWithValue(res.initiateAAAAQuery(host: host, port: 0).wait()).map(printSocketAddress) - let ipv4Results = try assertNoThrowWithValue(res.initiateAQuery(host: host, port: 0).wait()).map(printSocketAddress) + let res = GetaddrinfoResolver(aiSocktype: .stream, aiProtocol: .tcp) + let results = try assertNoThrowWithValue(res.resolve(name: host, destinationPort: 0, on: eventLoop).wait()).map(printSocketAddress) return """ when trying to resolve '\(host)' we've got the following results: - previous try: \(printSocketAddress(previouslyReceivedResult)) - - all results: - IPv4: \(ipv4Results) - IPv6: \(ipv6Results) + - all results: \(results) """ }