Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Add connection pooling for CIO client #4407

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.engine.cio

import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.ktor.utils.io.*
import kotlinx.coroutines.*
import kotlinx.coroutines.debug.junit5.*
import org.junit.jupiter.api.*
import org.junit.jupiter.api.Assertions.*
import kotlin.time.Duration.Companion.milliseconds

@CoroutinesTimeout(5 * 1000)
class ConnectionFactoryPoolingTest {

private lateinit var selector: SelectorManager
private lateinit var factory: ConnectionFactory
private lateinit var server: ServerSocket
private lateinit var serverJob: Job
private lateinit var serverAddress: InetSocketAddress

init {
}

@BeforeEach
fun setUp() {
runBlocking {
selector = SelectorManager(Dispatchers.Default)
factory = ConnectionFactory(
selector = selector, connectionsLimit = 10, addressConnectionsLimit = 5, keepAliveTime = 1000
)

// Set up local server
server = aSocket(selector).tcp().bind(InetSocketAddress("localhost", 0))
serverAddress = server.localAddress as InetSocketAddress
serverJob = CoroutineScope(Dispatchers.IO).launch {

while (isActive) {
val socket = server.accept()
launch {
val input = socket.openReadChannel()
val output = socket.openWriteChannel()
try {
while (isActive) {
val line = input.readUTF8Line() ?: break
output.writeStringUtf8("$line\n")
output.flush()
}
} finally {
socket.close()
}
}
}
}
}
}

@AfterEach
fun tearDown() {
runBlocking {
serverJob.cancelAndJoin()

server.close()
selector.close()
}
}

@Test
fun `connect should create new connection when pool is empty`() = runBlocking {
val socket = factory.connect(serverAddress)


assertNotNull(socket)
assertTrue(socket.isActive)
}

@Test
fun `connect should reuse connection from pool`() = runBlocking {
val socket1 = factory.connect(serverAddress)

factory.release(socket1)
val socket2 = factory.connect(serverAddress)

assertEquals(socket1, socket2)
}

@Test
fun `connect should create new connection when pooled connection is expired`() = runBlocking {
val socket1 = factory.connect(serverAddress)

factory.release(socket1)
delay(1100.milliseconds) // Wait for the connection to expire
val socket2 = factory.connect(serverAddress)

assertNotEquals(socket1, socket2)
}

@Test
fun `connect should respect address connections limit`() = runBlocking {
val sockets = List(5) { factory.connect(serverAddress) }
assertThrows<TimeoutCancellationException> {
withTimeout(100) {
factory.connect(serverAddress)
}
}

sockets.forEach { factory.release(it) }
}
}
Original file line number Diff line number Diff line change
@@ -1,35 +1,56 @@
/*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.engine.cio

import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.ktor.util.collections.*
import io.ktor.util.date.*
import io.ktor.util.logging.*
import io.ktor.utils.io.*
import io.ktor.utils.io.locks.*
import kotlinx.coroutines.sync.*

private val LOG = KtorSimpleLogger("io.ktor.client.engine.cio.ConnectionFactory")


@OptIn(InternalAPI::class)
internal class ConnectionFactory(
private val selector: SelectorManager,
connectionsLimit: Int,
private val addressConnectionsLimit: Int
) {
private val connectionsLimit: Int,
private val addressConnectionsLimit: Int,
private val keepAliveTime: Long = 30_000, // Default keep-alive time in milliseconds
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use Duration here?

private val maxPoolSize: Int = 100 // Maximum pool size
) : SynchronizedObject() {

private val limit = Semaphore(connectionsLimit)
private val addressLimit = ConcurrentMap<InetSocketAddress, Semaphore>()
private val addressLimit = ConcurrentMap<SocketAddress, Semaphore>()
private val connectionPool = mutableMapOf<SocketAddress, MutableList<PooledConnection>>()

suspend fun connect(
address: InetSocketAddress,
configuration: SocketOptions.TCPClientSocketOptions.() -> Unit = {}
): Socket {

LOG.trace { "Attempting to connect to address: $address" }
// Try to get a connection from the pool
val pooledConnection = getPooledConnection(address)
if (pooledConnection != null) {
LOG.trace { "Reusing pooled connection for address: $address" }
return pooledConnection.socket // Return the socket from the pooled connection
}

// If no pooled connection is available, create a new one
limit.acquire()
LOG.trace { "No pooled connection available, creating new connection to address: $address" }
return try {
val addressSemaphore = addressLimit.computeIfAbsent(address) { Semaphore(addressConnectionsLimit) }
addressSemaphore.acquire()

try {
aSocket(selector).tcp().connect(address, configuration)
val socket = aSocket(selector).tcp().connect(address, configuration)
LOG.trace { "Successfully connected to address: $address" }
socket
} catch (cause: Throwable) {
// a failure or cancellation
LOG.error("Failed to connect to address: $address", cause)
addressSemaphore.release()
throw cause
}
Expand All @@ -39,8 +60,63 @@ internal class ConnectionFactory(
}
}

fun release(address: InetSocketAddress) {
addressLimit[address]!!.release()
limit.release()
fun release(socket: Socket) {
LOG.trace { "Releasing connection for address: ${socket.remoteAddress}" }

if (socket.isClosed) {
Copy link
Member

@osipxd osipxd Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This flag seems out of sync with SocketChannel.isOpen on JVM. As a result, we put closed connections to the pool and get a lot of test failures with java.nio.channels.ClosedChannelException

LOG.warn("Attempted to release a closed connection for address: ${socket.remoteAddress}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On JVM sun.nio.ch.SocketChannelImpl.getRemoteAddress calls ensureOpen under the hood, so here will be a crash

return
}

val pooledConnection = PooledConnection(socket, GMTDate().timestamp)
synchronized(this) {
val pool = connectionPool.getOrPut(socket.remoteAddress) { mutableListOf() }
pool.add(pooledConnection)

// If pool size exceeds maxPoolSize, remove and close excess connections
while (pool.size > maxPoolSize) {
LOG.trace { "Pool size exceeded maxPoolSize, removing oldest connection" }
pool.removeAt(0).close()
}

limit.release() // Release the semaphore to allow new connections
addressLimit[socket.remoteAddress]?.release() // Release the address semaphore
}
LOG.trace { "Connection released for address: ${socket.remoteAddress}" }
}

private fun getPooledConnection(address: InetSocketAddress): PooledConnection? {
LOG.trace { "Checking for pooled connection for address: $address" }
val connections = connectionPool[address] ?: return null
val currentTime = GMTDate().timestamp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering why we don't use kotlinx.datetime? Are we waiting for a stable version?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found the issue: KTOR-2721 Migrate from GMTDate to kotlinx.datetime.Instant


synchronized(this) {
connections.removeAll {
if (currentTime - it.lastUsed > keepAliveTime) {
LOG.trace { "Removing expired connection for address: $address" }
it.close()
return@removeAll true
}
false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also remove closed connections here? Is it possible to have a closed connection in the pool?

}

if (connections.isNotEmpty()) {
val connection = connections.removeAt(0)
LOG.trace { "Pooled connection found for address: $address" }
return connection
}
}
LOG.trace { "No pooled connection available for address: $address" }
return null
}
}

private class PooledConnection(
val socket: Socket,
var lastUsed: Long
) {
fun close() {
LOG.trace { "Closing socket for address: ${socket.remoteAddress}" }
socket.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ internal class Endpoint(
deliveryPoint.send(task)
}

@OptIn(InternalAPI::class)
private suspend fun makeDedicatedRequest(
request: HttpRequestData,
callContext: CoroutineContext
Expand All @@ -118,7 +117,7 @@ internal class Endpoint(
} catch (cause: Throwable) {
LOGGER.debug("An error occurred while closing connection", cause)
} finally {
releaseConnection()
releaseConnection(connection.socket)
}
}

Expand Down Expand Up @@ -193,7 +192,7 @@ internal class Endpoint(
coroutineContext
)

pipeline.pipelineContext.invokeOnCompletion { releaseConnection() }
pipeline.pipelineContext.invokeOnCompletion { releaseConnection(connection.socket) }
}

@Suppress("UNUSED_EXPRESSION")
Expand Down Expand Up @@ -248,7 +247,7 @@ internal class Endpoint(
} catch (_: Throwable) {
}

connectionFactory.release(address)
connectionFactory.release(socket)
throw cause
}
}
Expand Down Expand Up @@ -288,9 +287,8 @@ internal class Endpoint(
return connectTimeout to socketTimeout
}

private fun releaseConnection() {
val address = connectionAddress ?: return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove the field connectionAddress?

connectionFactory.release(address)
private fun releaseConnection(socket: Socket) {
connectionFactory.release(socket)
connections.decrementAndGet()
}

Expand Down