-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP Add connection pooling for CIO client #4407
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
/* | ||
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. | ||
*/ | ||
|
||
package io.ktor.client.engine.cio | ||
|
||
import io.ktor.network.selector.* | ||
import io.ktor.network.sockets.* | ||
import io.ktor.utils.io.* | ||
import kotlinx.coroutines.* | ||
import kotlinx.coroutines.debug.junit5.* | ||
import org.junit.jupiter.api.* | ||
import org.junit.jupiter.api.Assertions.* | ||
import kotlin.time.Duration.Companion.milliseconds | ||
|
||
@CoroutinesTimeout(5 * 1000) | ||
class ConnectionFactoryPoolingTest { | ||
|
||
private lateinit var selector: SelectorManager | ||
private lateinit var factory: ConnectionFactory | ||
private lateinit var server: ServerSocket | ||
private lateinit var serverJob: Job | ||
private lateinit var serverAddress: InetSocketAddress | ||
|
||
init { | ||
} | ||
|
||
@BeforeEach | ||
fun setUp() { | ||
runBlocking { | ||
selector = SelectorManager(Dispatchers.Default) | ||
factory = ConnectionFactory( | ||
selector = selector, connectionsLimit = 10, addressConnectionsLimit = 5, keepAliveTime = 1000 | ||
) | ||
|
||
// Set up local server | ||
server = aSocket(selector).tcp().bind(InetSocketAddress("localhost", 0)) | ||
serverAddress = server.localAddress as InetSocketAddress | ||
serverJob = CoroutineScope(Dispatchers.IO).launch { | ||
|
||
while (isActive) { | ||
val socket = server.accept() | ||
launch { | ||
val input = socket.openReadChannel() | ||
val output = socket.openWriteChannel() | ||
try { | ||
while (isActive) { | ||
val line = input.readUTF8Line() ?: break | ||
output.writeStringUtf8("$line\n") | ||
output.flush() | ||
} | ||
} finally { | ||
socket.close() | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
@AfterEach | ||
fun tearDown() { | ||
runBlocking { | ||
serverJob.cancelAndJoin() | ||
|
||
server.close() | ||
selector.close() | ||
} | ||
} | ||
|
||
@Test | ||
fun `connect should create new connection when pool is empty`() = runBlocking { | ||
val socket = factory.connect(serverAddress) | ||
|
||
|
||
assertNotNull(socket) | ||
assertTrue(socket.isActive) | ||
} | ||
|
||
@Test | ||
fun `connect should reuse connection from pool`() = runBlocking { | ||
val socket1 = factory.connect(serverAddress) | ||
|
||
factory.release(socket1) | ||
val socket2 = factory.connect(serverAddress) | ||
|
||
assertEquals(socket1, socket2) | ||
} | ||
|
||
@Test | ||
fun `connect should create new connection when pooled connection is expired`() = runBlocking { | ||
val socket1 = factory.connect(serverAddress) | ||
|
||
factory.release(socket1) | ||
delay(1100.milliseconds) // Wait for the connection to expire | ||
val socket2 = factory.connect(serverAddress) | ||
|
||
assertNotEquals(socket1, socket2) | ||
} | ||
|
||
@Test | ||
fun `connect should respect address connections limit`() = runBlocking { | ||
val sockets = List(5) { factory.connect(serverAddress) } | ||
assertThrows<TimeoutCancellationException> { | ||
withTimeout(100) { | ||
factory.connect(serverAddress) | ||
} | ||
} | ||
|
||
sockets.forEach { factory.release(it) } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,56 @@ | ||
/* | ||
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. | ||
*/ | ||
|
||
package io.ktor.client.engine.cio | ||
|
||
import io.ktor.network.selector.* | ||
import io.ktor.network.sockets.* | ||
import io.ktor.util.collections.* | ||
import io.ktor.util.date.* | ||
import io.ktor.util.logging.* | ||
import io.ktor.utils.io.* | ||
import io.ktor.utils.io.locks.* | ||
import kotlinx.coroutines.sync.* | ||
|
||
private val LOG = KtorSimpleLogger("io.ktor.client.engine.cio.ConnectionFactory") | ||
|
||
|
||
@OptIn(InternalAPI::class) | ||
internal class ConnectionFactory( | ||
private val selector: SelectorManager, | ||
connectionsLimit: Int, | ||
private val addressConnectionsLimit: Int | ||
) { | ||
private val connectionsLimit: Int, | ||
private val addressConnectionsLimit: Int, | ||
private val keepAliveTime: Long = 30_000, // Default keep-alive time in milliseconds | ||
private val maxPoolSize: Int = 100 // Maximum pool size | ||
) : SynchronizedObject() { | ||
|
||
private val limit = Semaphore(connectionsLimit) | ||
private val addressLimit = ConcurrentMap<InetSocketAddress, Semaphore>() | ||
private val addressLimit = ConcurrentMap<SocketAddress, Semaphore>() | ||
private val connectionPool = mutableMapOf<SocketAddress, MutableList<PooledConnection>>() | ||
|
||
suspend fun connect( | ||
address: InetSocketAddress, | ||
configuration: SocketOptions.TCPClientSocketOptions.() -> Unit = {} | ||
): Socket { | ||
|
||
LOG.trace { "Attempting to connect to address: $address" } | ||
// Try to get a connection from the pool | ||
val pooledConnection = getPooledConnection(address) | ||
if (pooledConnection != null) { | ||
LOG.trace { "Reusing pooled connection for address: $address" } | ||
return pooledConnection.socket // Return the socket from the pooled connection | ||
} | ||
|
||
// If no pooled connection is available, create a new one | ||
limit.acquire() | ||
LOG.trace { "No pooled connection available, creating new connection to address: $address" } | ||
return try { | ||
val addressSemaphore = addressLimit.computeIfAbsent(address) { Semaphore(addressConnectionsLimit) } | ||
addressSemaphore.acquire() | ||
|
||
try { | ||
aSocket(selector).tcp().connect(address, configuration) | ||
val socket = aSocket(selector).tcp().connect(address, configuration) | ||
LOG.trace { "Successfully connected to address: $address" } | ||
socket | ||
} catch (cause: Throwable) { | ||
// a failure or cancellation | ||
LOG.error("Failed to connect to address: $address", cause) | ||
addressSemaphore.release() | ||
throw cause | ||
} | ||
|
@@ -39,8 +60,63 @@ internal class ConnectionFactory( | |
} | ||
} | ||
|
||
fun release(address: InetSocketAddress) { | ||
addressLimit[address]!!.release() | ||
limit.release() | ||
fun release(socket: Socket) { | ||
LOG.trace { "Releasing connection for address: ${socket.remoteAddress}" } | ||
|
||
if (socket.isClosed) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This flag seems out of sync with |
||
LOG.warn("Attempted to release a closed connection for address: ${socket.remoteAddress}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On JVM |
||
return | ||
} | ||
|
||
val pooledConnection = PooledConnection(socket, GMTDate().timestamp) | ||
synchronized(this) { | ||
val pool = connectionPool.getOrPut(socket.remoteAddress) { mutableListOf() } | ||
pool.add(pooledConnection) | ||
|
||
// If pool size exceeds maxPoolSize, remove and close excess connections | ||
while (pool.size > maxPoolSize) { | ||
LOG.trace { "Pool size exceeded maxPoolSize, removing oldest connection" } | ||
pool.removeAt(0).close() | ||
} | ||
|
||
limit.release() // Release the semaphore to allow new connections | ||
addressLimit[socket.remoteAddress]?.release() // Release the address semaphore | ||
} | ||
LOG.trace { "Connection released for address: ${socket.remoteAddress}" } | ||
} | ||
|
||
private fun getPooledConnection(address: InetSocketAddress): PooledConnection? { | ||
LOG.trace { "Checking for pooled connection for address: $address" } | ||
val connections = connectionPool[address] ?: return null | ||
val currentTime = GMTDate().timestamp | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wondering why we don't use kotlinx.datetime? Are we waiting for a stable version? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Found the issue: KTOR-2721 Migrate from GMTDate to kotlinx.datetime.Instant |
||
|
||
synchronized(this) { | ||
connections.removeAll { | ||
if (currentTime - it.lastUsed > keepAliveTime) { | ||
LOG.trace { "Removing expired connection for address: $address" } | ||
it.close() | ||
return@removeAll true | ||
} | ||
false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also remove closed connections here? Is it possible to have a closed connection in the pool? |
||
} | ||
|
||
if (connections.isNotEmpty()) { | ||
val connection = connections.removeAt(0) | ||
LOG.trace { "Pooled connection found for address: $address" } | ||
return connection | ||
} | ||
} | ||
LOG.trace { "No pooled connection available for address: $address" } | ||
return null | ||
} | ||
} | ||
|
||
private class PooledConnection( | ||
val socket: Socket, | ||
var lastUsed: Long | ||
) { | ||
fun close() { | ||
LOG.trace { "Closing socket for address: ${socket.remoteAddress}" } | ||
socket.close() | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,7 +94,6 @@ internal class Endpoint( | |
deliveryPoint.send(task) | ||
} | ||
|
||
@OptIn(InternalAPI::class) | ||
private suspend fun makeDedicatedRequest( | ||
request: HttpRequestData, | ||
callContext: CoroutineContext | ||
|
@@ -118,7 +117,7 @@ internal class Endpoint( | |
} catch (cause: Throwable) { | ||
LOGGER.debug("An error occurred while closing connection", cause) | ||
} finally { | ||
releaseConnection() | ||
releaseConnection(connection.socket) | ||
} | ||
} | ||
|
||
|
@@ -193,7 +192,7 @@ internal class Endpoint( | |
coroutineContext | ||
) | ||
|
||
pipeline.pipelineContext.invokeOnCompletion { releaseConnection() } | ||
pipeline.pipelineContext.invokeOnCompletion { releaseConnection(connection.socket) } | ||
} | ||
|
||
@Suppress("UNUSED_EXPRESSION") | ||
|
@@ -248,7 +247,7 @@ internal class Endpoint( | |
} catch (_: Throwable) { | ||
} | ||
|
||
connectionFactory.release(address) | ||
connectionFactory.release(socket) | ||
throw cause | ||
} | ||
} | ||
|
@@ -288,9 +287,8 @@ internal class Endpoint( | |
return connectTimeout to socketTimeout | ||
} | ||
|
||
private fun releaseConnection() { | ||
val address = connectionAddress ?: return | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we remove the field |
||
connectionFactory.release(address) | ||
private fun releaseConnection(socket: Socket) { | ||
connectionFactory.release(socket) | ||
connections.decrementAndGet() | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we use
Duration
here?