Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ extension HTTPClient {
try HTTPClientRequest.Prepared(
currentRequest,
dnsOverride: configuration.dnsOverride,
localAddress: configuration.localAddress,
tracing: self.configuration.tracing
)
let response = try await {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ extension HTTPClientRequest.Prepared {
init(
_ request: HTTPClientRequest,
dnsOverride: [String: String] = [:],
localAddress: String? = nil,
tracing: HTTPClient.TracingConfiguration? = nil
) throws {
guard !request.url.isEmpty, let url = URL(string: request.url) else {
Expand All @@ -73,7 +74,12 @@ extension HTTPClientRequest.Prepared {

self.init(
url: url,
poolKey: .init(url: deconstructedURL, tlsConfiguration: request.tlsConfiguration, dnsOverride: dnsOverride),
poolKey: .init(
url: deconstructedURL,
tlsConfiguration: request.tlsConfiguration,
dnsOverride: dnsOverride,
localAddress: request.localAddress ?? localAddress
),
requestFramingMetadata: metadata,
head: .init(
version: .http1_1,
Expand Down Expand Up @@ -140,6 +146,7 @@ extension HTTPClientRequest {
newRequest.method = method
newRequest.headers = headers
newRequest.body = body
newRequest.localAddress = self.localAddress
return newRequest
}
}
8 changes: 8 additions & 0 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,20 @@ public struct HTTPClientRequest: Sendable {
/// Request-specific TLS configuration, defaults to no request-specific TLS configuration.
public var tlsConfiguration: TLSConfiguration?

/// The local IP address to bind this request's connection to.
///
/// When set, overrides ``HTTPClient/Configuration/localAddress`` for this request.
/// The value should be an IP address string (e.g. `"192.168.1.10"` or `"::1"`).
/// Defaults to `nil` (use client configuration default).
public var localAddress: String?

public init(url: String) {
self.url = url
self.method = .GET
self.headers = .init()
self.body = .none
self.tlsConfiguration = nil
self.localAddress = nil
}
}

Expand Down
26 changes: 20 additions & 6 deletions Sources/AsyncHTTPClient/ConnectionPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,20 @@ enum ConnectionPool {
var connectionTarget: ConnectionTarget
private var tlsConfiguration: BestEffortHashableTLSConfiguration?
var serverNameIndicatorOverride: String?
var localAddress: String?

init(
scheme: Scheme,
connectionTarget: ConnectionTarget,
tlsConfiguration: BestEffortHashableTLSConfiguration? = nil,
serverNameIndicatorOverride: String?
serverNameIndicatorOverride: String?,
localAddress: String? = nil
) {
self.scheme = scheme
self.connectionTarget = connectionTarget
self.tlsConfiguration = tlsConfiguration
self.serverNameIndicatorOverride = serverNameIndicatorOverride
self.localAddress = localAddress
}

var description: String {
Expand All @@ -75,8 +78,12 @@ enum ConnectionPool {
case .unixSocket(let socketPath):
hostDescription = socketPath
}
return
var result =
"\(self.scheme)://\(hostDescription)\(self.serverNameIndicatorOverride.map { " SNI: \($0)" } ?? "") TLS-hash: \(hash)"
if let addr = self.localAddress {
result += " bind: \(addr)"
}
return result
}
}
}
Expand All @@ -97,23 +104,30 @@ extension DeconstructedURL {
}

extension ConnectionPool.Key {
init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?, dnsOverride: [String: String]) {
init(
url: DeconstructedURL,
tlsConfiguration: TLSConfiguration?,
dnsOverride: [String: String],
localAddress: String? = nil
) {
let (connectionTarget, serverNameIndicatorOverride) = url.applyDNSOverride(dnsOverride)
self.init(
scheme: url.scheme,
connectionTarget: connectionTarget,
tlsConfiguration: tlsConfiguration.map {
BestEffortHashableTLSConfiguration(wrapping: $0)
},
serverNameIndicatorOverride: serverNameIndicatorOverride
serverNameIndicatorOverride: serverNameIndicatorOverride,
localAddress: localAddress
)
}

init(_ request: HTTPClient.Request, dnsOverride: [String: String] = [:]) {
init(_ request: HTTPClient.Request, dnsOverride: [String: String] = [:], localAddress: String? = nil) {
self.init(
url: request.deconstructedURL,
tlsConfiguration: request.tlsConfiguration,
dnsOverride: dnsOverride
dnsOverride: dnsOverride,
localAddress: localAddress
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,19 @@ extension HTTPConnectionPool.ConnectionFactory {
promise: EventLoopPromise<NegotiatedProtocol>
) {
precondition(!self.key.scheme.usesTLS, "Unexpected scheme")
return self.makePlainBootstrap(
requester: requester,
connectionID: connectionID,
deadline: deadline,
eventLoop: eventLoop
).connect(target: self.key.connectionTarget).map {
.http1_1($0)
}.cascade(to: promise)
do {
let bootstrap = try self.makePlainBootstrap(
requester: requester,
connectionID: connectionID,
deadline: deadline,
eventLoop: eventLoop
)
bootstrap.connect(target: self.key.connectionTarget).map {
.http1_1($0)
}.cascade(to: promise)
} catch {
promise.fail(error)
}
}

private func makeHTTPProxyChannel<Requester: HTTPConnectionRequester>(
Expand All @@ -267,12 +272,18 @@ extension HTTPConnectionPool.ConnectionFactory {
// A proxy connection starts with a plain text connection to the proxy server. After
// the connection has been established with the proxy server, the connection might be
// upgraded to TLS before we send our first request.
let bootstrap = self.makePlainBootstrap(
requester: requester,
connectionID: connectionID,
deadline: deadline,
eventLoop: eventLoop
)
let bootstrap: NIOClientTCPBootstrapProtocol
do {
bootstrap = try self.makePlainBootstrap(
requester: requester,
connectionID: connectionID,
deadline: deadline,
eventLoop: eventLoop
)
} catch {
promise.fail(error)
return
}
bootstrap.connect(host: proxy.host, port: proxy.port).whenComplete { result in
switch result {
case .success(let channel):
Expand Down Expand Up @@ -321,12 +332,18 @@ extension HTTPConnectionPool.ConnectionFactory {
// A proxy connection starts with a plain text connection to the proxy server. After
// the connection has been established with the proxy server, the connection might be
// upgraded to TLS before we send our first request.
let bootstrap = self.makePlainBootstrap(
requester: requester,
connectionID: connectionID,
deadline: deadline,
eventLoop: eventLoop
)
let bootstrap: NIOClientTCPBootstrapProtocol
do {
bootstrap = try self.makePlainBootstrap(
requester: requester,
connectionID: connectionID,
deadline: deadline,
eventLoop: eventLoop
)
} catch {
promise.fail(error)
return
}
bootstrap.connect(host: proxy.host, port: proxy.port).whenComplete { result in
switch result {
case .success(let channel):
Expand Down Expand Up @@ -421,12 +438,16 @@ extension HTTPConnectionPool.ConnectionFactory {
connectionID: HTTPConnectionPool.Connection.ID,
deadline: NIODeadline,
eventLoop: EventLoop
) -> NIOClientTCPBootstrapProtocol {
) throws -> NIOClientTCPBootstrapProtocol {
if let localAddress = self.key.localAddress, !localAddress.isIPAddress {
throw HTTPClientError.invalidLocalAddress
}

#if canImport(Network)
if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *),
let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop)
{
return
var bootstrap =
tsBootstrap
.channelOption(
NIOTSChannelOptions.waitForActivity,
Expand All @@ -448,14 +469,32 @@ extension HTTPConnectionPool.ConnectionFactory {
return channel.eventLoop.makeFailedFuture(error)
}
}
if let localAddress = self.key.localAddress {
bootstrap = bootstrap.configureNWParameters { params in
params.requiredLocalEndpoint = NWEndpoint.hostPort(
host: NWEndpoint.Host(localAddress),
port: .any
)
}
}
return bootstrap
}
#endif

if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) {
return
var bootstrap =
nioBootstrap
.connectTimeout(deadline - NIODeadline.now())
.enableMPTCP(clientConfiguration.enableMultipath)
if let localAddress = self.key.localAddress {
do {
let socketAddress = try SocketAddress(ipAddress: localAddress, port: 0)
bootstrap = bootstrap.bind(to: socketAddress)
} catch {
throw HTTPClientError.invalidLocalAddress
}
}
return bootstrap
}

preconditionFailure("No matching bootstrap found")
Expand Down Expand Up @@ -523,6 +562,10 @@ extension HTTPConnectionPool.ConnectionFactory {
eventLoop: EventLoop,
logger: Logger
) -> EventLoopFuture<NIOClientTCPBootstrapProtocol> {
if let localAddress = self.key.localAddress, !localAddress.isIPAddress {
return eventLoop.makeFailedFuture(HTTPClientError.invalidLocalAddress)
}

var tlsConfig = self.tlsConfiguration
switch self.clientConfiguration.httpVersion.configuration {
case .automatic:
Expand All @@ -538,13 +581,14 @@ extension HTTPConnectionPool.ConnectionFactory {
#if canImport(Network)
if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), eventLoop is QoSEventLoop {
// create NIOClientTCPBootstrap with NIOTS TLS provider
let localAddr = self.key.localAddress
let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions(
on: eventLoop,
serverNameIndicatorOverride: key.serverNameIndicatorOverride
).map {
options -> NIOClientTCPBootstrapProtocol in

NIOTSConnectionBootstrap(group: eventLoop) // validated above
var bootstrap = NIOTSConnectionBootstrap(group: eventLoop) // validated above
.channelOption(
NIOTSChannelOptions.waitForActivity,
value: self.clientConfiguration.networkFrameworkWaitForConnectivity
Expand All @@ -569,7 +613,16 @@ extension HTTPConnectionPool.ConnectionFactory {
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
} as NIOClientTCPBootstrapProtocol
}
if let localAddress = localAddr {
bootstrap = bootstrap.configureNWParameters { params in
params.requiredLocalEndpoint = NWEndpoint.hostPort(
host: NWEndpoint.Host(localAddress),
port: .any
)
}
}
return bootstrap as NIOClientTCPBootstrapProtocol
}
return bootstrapFuture
}
Expand All @@ -581,10 +634,20 @@ extension HTTPConnectionPool.ConnectionFactory {
logger: logger
)

return eventLoop.submit {
ClientBootstrap(group: eventLoop)
return eventLoop.submit { [key] () throws -> NIOClientTCPBootstrapProtocol in
var bootstrap = ClientBootstrap(group: eventLoop)
.connectTimeout(deadline - NIODeadline.now())
.enableMPTCP(clientConfiguration.enableMultipath)
if let localAddress = key.localAddress {
do {
let socketAddress = try SocketAddress(ipAddress: localAddress, port: 0)
bootstrap = bootstrap.bind(to: socketAddress)
} catch {
throw HTTPClientError.invalidLocalAddress
}
}
return
bootstrap
Comment on lines +649 to +650
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

swift-format 🙄

.channelInitializer { channel in
sslContextFuture.flatMap { sslContext -> EventLoopFuture<Void> in
do {
Expand Down
9 changes: 7 additions & 2 deletions Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ struct RequestOptions {
var idleWriteTimeout: TimeAmount?
/// DNS overrides.
var dnsOverride: [String: String]
/// The local IP address to bind outgoing connections to.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment could be a bit more. when and why should you use it?

var localAddress: String?

init(
idleReadTimeout: TimeAmount?,
idleWriteTimeout: TimeAmount?,
dnsOverride: [String: String]
dnsOverride: [String: String],
localAddress: String? = nil
) {
self.idleReadTimeout = idleReadTimeout
self.idleWriteTimeout = idleWriteTimeout
self.dnsOverride = dnsOverride
self.localAddress = localAddress
}
}

Expand All @@ -38,7 +42,8 @@ extension RequestOptions {
RequestOptions(
idleReadTimeout: configuration.timeout.read,
idleWriteTimeout: configuration.timeout.write,
dnsOverride: configuration.dnsOverride
dnsOverride: configuration.dnsOverride,
localAddress: configuration.localAddress
)
}
}
Loading