Skip to content

Commit

Permalink
feat: Programatic connection factory options (#587)
Browse files Browse the repository at this point in the history
* Possibility to amend the ConnectionFactoryOptions
* config in dialect section since it's not supported by H2
  • Loading branch information
patriknw authored Aug 6, 2024
1 parent 7e01539 commit cf04b9f
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 16 deletions.
12 changes: 12 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,12 @@ akka.persistence.r2dbc {
# This timeout is handled by the database server.
# This timeout should be less than `close-calls-exceeding`.
statement-timeout = off

# Possibility to programatically amend the ConnectionFactoryOptions.
# Enable by specifying the fully qualified class name of a
# `akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider`.
# The class can optionally have a constructor with an ActorSystem parameter.
options-provider = ""
// #connection-settings-postgres
// #connection-settings-yugabyte
}
Expand Down Expand Up @@ -413,6 +419,12 @@ akka.persistence.r2dbc {
# Used to encode tags to and from db. Tags must not contain this separator.
tag-separator = ","

# Possibility to programatically amend the ConnectionFactoryOptions.
# Enable by specifying the fully qualified class name of a
# `akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider`.
# The class can optionally have a constructor with an ActorSystem parameter.
options-provider = ""

// #connection-settings-sqlserver
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.Failure
import scala.util.Success

import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactoryOptions

import akka.annotation.InternalApi
import akka.annotation.InternalStableApi
Expand All @@ -29,9 +34,24 @@ object ConnectionFactoryProvider extends ExtensionId[ConnectionFactoryProvider]

// Java API
def get(system: ActorSystem[_]): ConnectionFactoryProvider = apply(system)

trait ConnectionFactoryOptionsProvider {
def buildOptions(
builder: ConnectionFactoryOptions.Builder,
connectionFactoryConfig: Config): ConnectionFactoryOptions
}

private object DefaultConnectionFactoryOptionsProvider extends ConnectionFactoryOptionsProvider {
override def buildOptions(
builder: ConnectionFactoryOptions.Builder,
connectionFactoryConfig: Config): ConnectionFactoryOptions =
builder.build()
}
}

class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension {
import ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import ConnectionFactoryProvider.DefaultConnectionFactoryOptionsProvider

import R2dbcExecutor.PublisherOps
private val sessions = new ConcurrentHashMap[String, ConnectionPool]
Expand All @@ -51,7 +71,8 @@ class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension {
configLocation,
configLocation => {
val settings = connectionFactorySettingsFor(configLocation)
val connectionFactory = settings.dialect.createConnectionFactory(settings.config)
val optionsProvider = connectionFactoryOptionsProvider(settings)
val connectionFactory = settings.dialect.createConnectionFactory(settings.config, optionsProvider)
createConnectionPoolFactory(settings.poolSettings, connectionFactory)
})
.asInstanceOf[ConnectionFactory]
Expand All @@ -72,6 +93,22 @@ class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension {
}
}

private def connectionFactoryOptionsProvider(
settings: ConnectionFactorySettings): ConnectionFactoryOptionsProvider = {
settings.optionsProvider match {
case "" => DefaultConnectionFactoryOptionsProvider
case fqcn =>
system.dynamicAccess.createInstanceFor[ConnectionFactoryOptionsProvider](fqcn, Nil) match {
case Success(provider) => provider
case Failure(_) =>
system.dynamicAccess
.createInstanceFor[ConnectionFactoryOptionsProvider](fqcn, List(classOf[ActorSystem[_]] -> system))
.get

}
}
}

/**
* INTERNAL API
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import com.typesafe.config.Config
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider

/**
* INTERNAL API
*/
Expand Down Expand Up @@ -44,7 +46,10 @@ private[r2dbc] object ConnectionFactorySettings {
// for backwards compatibility/convenience
val poolSettings = new ConnectionPoolSettings(config)

ConnectionFactorySettings(dialect, config, poolSettings)
// H2 dialect doesn't support options-provider
val optionsProvider = if (dialect == H2Dialect) "" else config.getString("options-provider")

ConnectionFactorySettings(dialect, config, poolSettings, optionsProvider)
}

}
Expand All @@ -56,4 +61,5 @@ private[r2dbc] object ConnectionFactorySettings {
private[r2dbc] case class ConnectionFactorySettings(
dialect: Dialect,
config: Config,
poolSettings: ConnectionPoolSettings)
poolSettings: ConnectionPoolSettings,
optionsProvider: String)
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import akka.annotation.InternalStableApi
import akka.persistence.r2dbc.R2dbcSettings
import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactory

import scala.concurrent.ExecutionContext

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider

/**
* INTERNAL API
*/
Expand All @@ -31,7 +32,7 @@ private[r2dbc] trait Dialect {

def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext

def createConnectionFactory(config: Config): ConnectionFactory
def createConnectionFactory(config: Config, optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory

def createJournalDao(executorProvider: R2dbcExecutorProvider): JournalDao

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.Locale

import scala.concurrent.ExecutionContext

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider
import akka.persistence.r2dbc.internal.codec.IdentityAdapter
import akka.persistence.r2dbc.internal.codec.QueryAdapter
Expand All @@ -46,7 +47,9 @@ private[r2dbc] object H2Dialect extends Dialect {
res
}

override def createConnectionFactory(config: Config): ConnectionFactory = {
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = {
// starting point for both url and regular configs,
// to allow url to override anything but provide sane defaults
val builder = H2ConnectionConfiguration.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import io.r2dbc.spi.ConnectionFactories
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.ConnectionFactoryOptions

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand Down Expand Up @@ -66,7 +67,9 @@ private[r2dbc] object PostgresDialect extends Dialect {
}
}

override def createConnectionFactory(config: Config): ConnectionFactory = {
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = {
val settings = new PostgresConnectionFactorySettings(config)
val builder =
settings.urlOption match {
Expand Down Expand Up @@ -115,7 +118,8 @@ private[r2dbc] object PostgresDialect extends Dialect {
builder.option(PostgresqlConnectionFactoryProvider.SSL_PASSWORD, settings.sslPassword)
}

ConnectionFactories.get(builder.build())
val options = optionsProvider.buildOptions(builder, config)
ConnectionFactories.get(options)
}

override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import akka.persistence.r2dbc.internal.SnapshotDao
import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactory

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand All @@ -27,8 +28,10 @@ private[r2dbc] object YugabyteDialect extends Dialect {

override def name: String = "yugabyte"

override def createConnectionFactory(config: Config): ConnectionFactory =
PostgresDialect.createConnectionFactory(config)
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory =
PostgresDialect.createConnectionFactory(config, optionsProvider)

override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext =
system.executionContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import io.r2dbc.spi.ConnectionFactories
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.ConnectionFactoryOptions

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand Down Expand Up @@ -59,7 +60,9 @@ private[r2dbc] object SqlServerDialect extends Dialect {
res
}

override def createConnectionFactory(config: Config): ConnectionFactory = {
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = {

val settings = new SqlServerConnectionFactorySettings(config)
val builder =
Expand All @@ -79,11 +82,13 @@ private[r2dbc] object SqlServerDialect extends Dialect {
.option(ConnectionFactoryOptions.DATABASE, settings.database)
.option(ConnectionFactoryOptions.CONNECT_TIMEOUT, JDuration.ofMillis(settings.connectTimeout.toMillis))
}
ConnectionFactories.get(
builder
//the option below is necessary to avoid https://github.com/r2dbc/r2dbc-mssql/issues/276
.option(MssqlConnectionFactoryProvider.PREFER_CURSORED_EXECUTION, false)
.build())

builder
//the option below is necessary to avoid https://github.com/r2dbc/r2dbc-mssql/issues/276
.option(MssqlConnectionFactoryProvider.PREFER_CURSORED_EXECUTION, false)

val options = optionsProvider.buildOptions(builder, config)
ConnectionFactories.get(options)
}

override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext =
Expand Down
49 changes: 49 additions & 0 deletions core/src/test/scala/akka/persistence/r2dbc/R2dbcSettingsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ class R2dbcSettingsSpec extends AnyWordSpec with TestSuite with Matchers {
connectionFactorySettings.sslMode shouldBe "verify-full"
SSLMode.fromValue(connectionFactorySettings.sslMode) shouldBe SSLMode.VERIFY_FULL
}

"support options-provider" in {
val config = ConfigFactory
.parseString("akka.persistence.r2dbc.connection-factory.options-provider=my.OptProvider")
.withFallback(ConfigFactory.load("application-postgres.conf"))
val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc"))
settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider"
}
}

"data-partition settings" should {
Expand Down Expand Up @@ -287,5 +295,46 @@ class R2dbcSettingsSpec extends AnyWordSpec with TestSuite with Matchers {
settings.connectionFactorSliceRanges(0) should be(0 until 1024)
}

"support options-provider" in {
val config = ConfigFactory
.parseString("""
akka.persistence.r2dbc.postgres.options-provider=my.OptProvider
akka.persistence.r2dbc.data-partition {
number-of-partitions = 2
number-of-databases = 2
}
akka.persistence.r2dbc.connection-factory-0-0 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-0-0.host = hostA
akka.persistence.r2dbc.connection-factory-1-1 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-1-1.host = hostB
""")
.withFallback(ConfigFactory.load("application-postgres.conf"))
.resolve()
val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc"))
settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider"
settings.connectionFactorySettings(1).optionsProvider shouldBe "my.OptProvider"
}

"support options-provider per db" in {
val config = ConfigFactory
.parseString("""
akka.persistence.r2dbc.data-partition {
number-of-partitions = 2
number-of-databases = 2
}
akka.persistence.r2dbc.connection-factory-0-0 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-0-0.host = hostA
akka.persistence.r2dbc.connection-factory-0-0.options-provider=my.OptProvider0
akka.persistence.r2dbc.connection-factory-1-1 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-1-1.host = hostB
akka.persistence.r2dbc.connection-factory-1-1.options-provider=my.OptProvider1
""")
.withFallback(ConfigFactory.load("application-postgres.conf"))
.resolve()
val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc"))
settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider0"
settings.connectionFactorySettings(1023).optionsProvider shouldBe "my.OptProvider1"
}

}
}

0 comments on commit cf04b9f

Please sign in to comment.