Add JDBC partitioned read support #4959

Draft · wants to merge 2 commits into base: main
106 changes: 102 additions & 4 deletions scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala
@@ -22,9 +22,14 @@ import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.io._
import com.spotify.scio.util.Functions
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.io.jdbc.JdbcIO.{PreparedStatementSetter, StatementPreparator}
import org.apache.beam.sdk.io.jdbc.JdbcIO.{
PreparedStatementSetter,
ReadWithPartitions,
StatementPreparator
}
import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO}
import org.joda.time.Duration
import org.apache.beam.sdk.values.{TypeDescriptor, TypeDescriptors}
import org.joda.time.{DateTime, Duration}

import java.sql.{PreparedStatement, ResultSet, SQLException}
import javax.sql.DataSource
@@ -154,7 +159,100 @@ final case class JdbcSelect[T: Coder](opts: JdbcConnectionOptions, query: String
}

override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] =
throw new UnsupportedOperationException("jdbc.Select is read-only")
throw new UnsupportedOperationException("JdbcSelect is read-only")

override def tap(params: ReadP): Tap[Nothing] =
EmptyTap
}

object JdbcPartitionedRead {

object PartitionColumn {

// Supported types from JdbcUtil.PRESET_HELPERS
def long(
name: String,
upperBound: Option[Long] = None,
lowerBound: Option[Long] = None
): PartitionColumn[java.lang.Long] = new PartitionColumn(
TypeDescriptors.longs(),
name,
upperBound.map(Long.box),
lowerBound.map(Long.box)
)

def dateTime(
name: String,
upperBound: Option[DateTime] = None,
lowerBound: Option[DateTime] = None
): PartitionColumn[DateTime] = new PartitionColumn(
TypeDescriptor.of(classOf[DateTime]),
name,
upperBound,
lowerBound
)
}

case class PartitionColumn[T] private (
typeDescriptor: TypeDescriptor[T],
name: String,
upperBound: Option[T],
lowerBound: Option[T]
)

object ReadParam {
val DefaultNumPartitions: Int = 200
val DefaultDataSourceProviderFn: () => DataSource = null
def defaultConfigOverride[S, T]: ReadWithPartitions[S, T] => ReadWithPartitions[S, T] = identity
}

final case class ReadParam[T, S](
partitionColumn: JdbcPartitionedRead.PartitionColumn[S],
rowMapper: ResultSet => T,
numPartitions: Int = ReadParam.DefaultNumPartitions,
dataSourceProviderFn: () => DataSource = ReadParam.DefaultDataSourceProviderFn,
configOverride: ReadWithPartitions[T, S] => ReadWithPartitions[T, S] =
ReadParam.defaultConfigOverride[T, S]
)
}

final case class JdbcPartitionedRead[T: Coder, S](
opts: JdbcConnectionOptions,
table: String
) extends JdbcIO[T] {
override type ReadP = JdbcPartitionedRead.ReadParam[T, S]
override type WriteP = Nothing
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]
override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(opts, table)})"
override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = {
val coder = CoderMaterializer.beam(sc, Coder[T])
val transform = BJdbcIO
.readWithPartitions[T, S](params.partitionColumn.typeDescriptor)
.withPartitionColumn(params.partitionColumn.name)
.pipe(r => params.partitionColumn.lowerBound.fold(r)(r.withLowerBound))
.pipe(r => params.partitionColumn.upperBound.fold(r)(r.withUpperBound))
.pipe { r =>
if (params.numPartitions != JdbcPartitionedRead.ReadParam.DefaultNumPartitions) {
r.withNumPartitions(params.numPartitions)
} else {
r
}
}
.withCoder(coder)
.withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(opts))
.withTable(table)
.withRowMapper(params.rowMapper(_))
.pipe { r =>
Option(params.dataSourceProviderFn)
.map(fn => Functions.serializableFn[Void, DataSource](_ => fn()))
.fold(r)(r.withDataSourceProviderFn)
}

sc.applyTransform(params.configOverride(transform))
}

override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] =
throw new UnsupportedOperationException("JdbcPartitionRead is read-only")

override def tap(params: ReadP): Tap[Nothing] =
EmptyTap
@@ -168,7 +266,7 @@ final case class JdbcWrite[T](opts: JdbcConnectionOptions, statement: String) ex
override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(opts, statement)})"

override protected def read(sc: ScioContext, params: ReadP): SCollection[T] =
throw new UnsupportedOperationException("jdbc.Write is write-only")
throw new UnsupportedOperationException("JdbcWrite is write-only")

override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = {
val transform = BJdbcIO
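
For context, a minimal usage sketch of the new partitioned read API. This is not code from the PR: the connection values, table, and column names are hypothetical, and the `JdbcConnectionOptions` field names are assumed from the existing case class.

```scala
import com.spotify.scio.ScioContext
import com.spotify.scio.jdbc._
import com.spotify.scio.jdbc.JdbcPartitionedRead.PartitionColumn

// Hypothetical connection options; the Postgres driver is illustrative only.
val connOpts = JdbcConnectionOptions(
  username = "user",
  password = Some("secret"),
  connectionUrl = "jdbc:postgresql://localhost:5432/mydb",
  driverClass = classOf[org.postgresql.Driver]
)

val sc = ScioContext()

// Read "users" in up to 10 partitions, split on the numeric primary key.
// With no bounds given, Beam first runs a MIN(id)/MAX(id) query to derive them.
sc.jdbcPartitionedRead(
  connOpts,
  table = "users",
  partitionColumn = PartitionColumn.long("id"),
  numPartitions = 10
)(rs => rs.getString("name"))

sc.run()
```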
scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/ScioContextSyntax.scala
@@ -19,11 +19,17 @@ package com.spotify.scio.jdbc.syntax

import com.spotify.scio.ScioContext
import com.spotify.scio.coders.Coder
import com.spotify.scio.jdbc.sharded.{JdbcShardedReadOptions, JdbcShardedSelect}
import com.spotify.scio.jdbc.{JdbcConnectionOptions, JdbcIO, JdbcReadOptions, JdbcSelect}
import com.spotify.scio.jdbc.JdbcIO.ReadParam
import com.spotify.scio.jdbc.JdbcPartitionedRead.PartitionColumn
import com.spotify.scio.jdbc.sharded.{JdbcShardedReadOptions, JdbcShardedSelect, Shard}
import com.spotify.scio.jdbc.{
JdbcConnectionOptions,
JdbcIO,
JdbcPartitionedRead,
JdbcReadOptions,
JdbcSelect
}
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.io.jdbc.JdbcIO.Read
import org.apache.beam.sdk.io.jdbc.JdbcIO.{Read, ReadWithPartitions}

import java.sql.{PreparedStatement, ResultSet}
import javax.sql.DataSource
@@ -66,11 +72,11 @@ final class JdbcScioContextOps(private val self: ScioContext) extends AnyVal {
def jdbcSelect[T: ClassTag: Coder](
connectionOptions: JdbcConnectionOptions,
query: String,
statementPreparator: PreparedStatement => Unit = ReadParam.DefaultStatementPreparator,
fetchSize: Int = ReadParam.BeamDefaultFetchSize,
outputParallelization: Boolean = ReadParam.DefaultOutputParallelization,
dataSourceProviderFn: () => DataSource = ReadParam.DefaultDataSourceProviderFn,
configOverride: Read[T] => Read[T] = ReadParam.defaultConfigOverride[T]
statementPreparator: PreparedStatement => Unit = JdbcIO.ReadParam.DefaultStatementPreparator,
fetchSize: Int = JdbcIO.ReadParam.BeamDefaultFetchSize,
outputParallelization: Boolean = JdbcIO.ReadParam.DefaultOutputParallelization,
dataSourceProviderFn: () => DataSource = JdbcIO.ReadParam.DefaultDataSourceProviderFn,
configOverride: Read[T] => Read[T] = JdbcIO.ReadParam.defaultConfigOverride[T]
)(rowMapper: ResultSet => T): SCollection[T] =
self.read(JdbcSelect(connectionOptions, query))(
JdbcIO.ReadParam(
@@ -111,6 +117,25 @@
readOptions: JdbcShardedReadOptions[T, S]
): SCollection[T] = self.read(JdbcShardedSelect(readOptions))

def jdbcPartitionedRead[T: Coder, S](
connectionOptions: JdbcConnectionOptions,
table: String,
partitionColumn: PartitionColumn[S],
numPartitions: Int = JdbcPartitionedRead.ReadParam.DefaultNumPartitions,
dataSourceProviderFn: () => DataSource = JdbcIO.ReadParam.DefaultDataSourceProviderFn,
configOverride: ReadWithPartitions[T, S] => ReadWithPartitions[T, S] =
JdbcPartitionedRead.ReadParam.defaultConfigOverride[T, S]
)(rowMapper: ResultSet => T): SCollection[T] = {
val params = JdbcPartitionedRead.ReadParam(
partitionColumn,
rowMapper,
numPartitions,
dataSourceProviderFn,
configOverride
)
self.read(JdbcPartitionedRead[T, S](connectionOptions, table))(params)
}

}
trait ScioContextSyntax {
implicit def jdbcScioContextOps(sc: ScioContext): JdbcScioContextOps = new JdbcScioContextOps(sc)
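
And a sketch of using `configOverride` to reach Beam settings not surfaced directly, assuming the Beam version in use exposes `withFetchSize` on `ReadWithPartitions` (`connOpts` and `sc` as in the previous sketch):

```scala
import org.apache.beam.sdk.io.jdbc.JdbcIO.ReadWithPartitions

// Pin T and S explicitly so the override's type parameters match the read.
sc.jdbcPartitionedRead[String, java.lang.Long](
  connOpts,
  table = "events",
  partitionColumn = PartitionColumn.long("event_id"),
  configOverride = (r: ReadWithPartitions[String, java.lang.Long]) => r.withFetchSize(10000)
)(rs => rs.getString(1))
```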
72 changes: 60 additions & 12 deletions scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcTest.scala
@@ -17,23 +17,26 @@

package com.spotify.scio.jdbc

import java.sql.ResultSet

import com.spotify.scio._
import org.apache.beam.sdk.io.{jdbc => beam}
import com.spotify.scio.io.TextIO
import com.spotify.scio.jdbc.JdbcPartitionedRead.PartitionColumn
import com.spotify.scio.testing._
import org.apache.beam.sdk.io.{jdbc => beam}
import org.apache.beam.sdk.values.TypeDescriptors

import java.sql.ResultSet

object JdbcJob {

val query = "SELECT <this> FROM <this>"
val statement = "INSERT INTO <this> VALUES( ?, ? ..?)"
val Query = "SELECT <this> FROM <this>"
val Statement = "INSERT INTO <this> VALUES( ?, ? ..?)"
def main(cmdlineArgs: Array[String]): Unit = {
val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](cmdlineArgs)
val sc = ScioContext(opts)
val connectionOpts = getConnectionOptions(opts)
sc.jdbcSelect[String](connectionOpts, query)((rs: ResultSet) => rs.getString(1))
sc.jdbcSelect[String](connectionOpts, Query)((rs: ResultSet) => rs.getString(1))
.map(_ + "J")
.saveAsJdbc(connectionOpts, statement) { (_, _) => }
.saveAsJdbc(connectionOpts, Statement) { (_, _) => }
sc.run()
()
}
@@ -52,6 +55,26 @@ object JdbcJob {
)
}

object JdbcPartitionedJob {

val Table = "table"
val IdColumn = PartitionColumn.long("id")

val OutputPath = "output"

def main(cmdlineArgs: Array[String]): Unit = {
val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](cmdlineArgs)
val sc = ScioContext(opts)
val connectionOpts = JdbcJob.getConnectionOptions(opts)

sc.jdbcPartitionedRead(connectionOpts, Table, IdColumn)((rs: ResultSet) => rs.getString(1))
.map(_ + "J")
.saveAsTextFile(OutputPath)
sc.run()
()
}
}

class JdbcTest extends PipelineSpec {
def testJdbc(xs: String*): Unit = {
val args = Seq(
@@ -65,8 +88,8 @@

JobTest[JdbcJob.type]
.args(args: _*)
.input(JdbcIO[String](connectionOpts, JdbcJob.query), Seq("a", "b", "c"))
.output(JdbcIO[String](connectionOpts, JdbcJob.statement))(coll =>
.input(JdbcIO[String](connectionOpts, JdbcJob.Query), Seq("a", "b", "c"))
.output(JdbcIO[String](connectionOpts, JdbcJob.Statement))(coll =>
coll should containInAnyOrder(xs)
)
.run()
@@ -81,7 +104,7 @@
an[AssertionError] should be thrownBy { testJdbc("aJ", "bJ", "cJ", "dJ") }
}

it should "connnect via JDBC without a password" in {
it should "identify JDBC IOs from connection options and query" in {
val args = Seq(
"--cloudSqlUsername=john",
"--cloudSqlDb=mydb",
@@ -94,8 +117,33 @@

JobTest[JdbcJob.type]
.args(args: _*)
.input(JdbcIO[String](connectionOpts, JdbcJob.query), Seq("a", "b", "c"))
.output(JdbcIO[String](connectionOpts, JdbcJob.statement))(coll =>
.input(JdbcIO[String](connectionOpts, JdbcJob.Query), Seq("a", "b", "c"))
.output(JdbcIO[String](connectionOpts, JdbcJob.Statement))(coll =>
coll should containInAnyOrder(expected)
)
.run()
}

it should "identify JDBC partitioned read from connection options and table" in {
val args = Seq(
"--cloudSqlUsername=john",
"--cloudSqlDb=mydb",
"--cloudSqlInstanceConnectionName=project-id:zone:db-instance-name"
)
val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args.toArray)
val connectionOpts = JdbcJob.getConnectionOptions(opts)

val expected = Seq("aJ", "bJ", "cJ")

JdbcPartitionedJob.IdColumn.typeDescriptor shouldBe TypeDescriptors.longs()

JobTest[JdbcPartitionedJob.type]
.args(args: _*)
.input(
JdbcIO[String](connectionOpts, JdbcPartitionedJob.Table),
Seq("a", "b", "c")
)
.output(TextIO(JdbcPartitionedJob.OutputPath))(coll =>
coll should containInAnyOrder(expected)
)
.run()
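
Timestamp columns follow the same pattern; with explicit bounds the automatic MIN/MAX query is skipped. A sketch, with column name, bounds, and payload field all illustrative:

```scala
import org.joda.time.DateTime

// Partition a timestamp column over a fixed one-year window.
val createdAt = PartitionColumn.dateTime(
  "created_at",
  upperBound = Some(DateTime.parse("2023-01-01T00:00:00Z")),
  lowerBound = Some(DateTime.parse("2022-01-01T00:00:00Z"))
)

sc.jdbcPartitionedRead(connOpts, table = "events", partitionColumn = createdAt)(
  rs => rs.getString("payload")
)
```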