[SPARK-50162][SS][TESTS] Add tests for loading snapshot with given version for transformWithState operator state and state data source reader

### What changes were proposed in this pull request?
Add tests for loading snapshot with given version for transformWithState operator state and state data source reader

### Why are the changes needed?
To add test coverage for the snapshotStartBatchId integration between the transformWithState (TWS) operator and the state data source reader.
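For reference, the snapshot-related reader options exercised by the new tests look roughly like the sketch below (a minimal sketch based on the test code in this diff; the checkpoint path and option values are placeholders, and an active `spark` session is assumed):

```scala
// Sketch: load transformWithState state for one partition, reconstructed by
// starting from the snapshot written at the given batch and replaying changelogs.
val stateSnapshotDf = spark.read
  .format("statestore")
  .option("snapshotPartitionId", 4)     // partition whose snapshot is read
  .option("snapshotStartBatchId", 1)    // snapshot version to start from
  .option("stateVarName", "countState") // state variable registered by the processor
  .load("/path/to/checkpoint")          // streaming query checkpoint location
```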

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Test-only change.

Added unit tests

```
===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.datasources.v2.state.StateDataSourceTransformWithStateSuite, threads: ForkJoinPool.commonPool-worker-6 (daemon=true), ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 (daemon=true), ForkJoinPool.commonPool-worker-7 (daemon=true), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), ForkJoinPool.commonPool-worker-2 (daemon=true), rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPo...
[info] Run completed in 2 minutes, 5 seconds.
[info] Total number of tests run: 23
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 23, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48710 from anishshri-db/task/SPARK-50162.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Nov 6, 2024
1 parent 667cff7 commit 52df0cd
Showing 2 changed files with 160 additions and 10 deletions.
@@ -457,19 +457,19 @@ class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
testSnapshotPartitionId()
}

test("snapshotStatBatchId on limit state") {
test("snapshotStartBatchId on limit state") {
testSnapshotOnLimitState("hdfs")
}

test("snapshotStatBatchId on aggregation state") {
test("snapshotStartBatchId on aggregation state") {
testSnapshotOnAggregateState("hdfs")
}

test("snapshotStatBatchId on deduplication state") {
test("snapshotStartBatchId on deduplication state") {
testSnapshotOnDeduplicateState("hdfs")
}

test("snapshotStatBatchId on join state") {
test("snapshotStartBatchId on join state") {
testSnapshotOnJoinState("hdfs", 1)
testSnapshotOnJoinState("hdfs", 2)
}
@@ -550,19 +550,19 @@ StateDataSourceReadSuite {
testSnapshotPartitionId()
}

test("snapshotStatBatchId on limit state") {
test("snapshotStartBatchId on limit state") {
testSnapshotOnLimitState("rocksdb")
}

test("snapshotStatBatchId on aggregation state") {
test("snapshotStartBatchId on aggregation state") {
testSnapshotOnAggregateState("rocksdb")
}

test("snapshotStatBatchId on deduplication state") {
test("snapshotStartBatchId on deduplication state") {
testSnapshotOnDeduplicateState("rocksdb")
}

test("snapshotStatBatchId on join state") {
test("snapshotStartBatchId on join state") {
testSnapshotOnJoinState("rocksdb", 1)
testSnapshotOnJoinState("rocksdb", 2)
}
@@ -16,15 +16,20 @@
*/
package org.apache.spark.sql.execution.datasources.v2.state

import java.io.File
import java.time.Duration

import org.apache.hadoop.conf.Configuration

import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
- import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
- import org.apache.spark.sql.functions.{explode, timestamp_seconds}
+ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass}
+ import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.Utils

/** Stateful processor of single value state var with non-primitive type */
class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -997,4 +1002,149 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
}
}
}

/**
* Note that we cannot use the golden files approach for transformWithState. The new schema
* format keeps track of the schema file path as an absolute path which cannot be used with
* the getResource model used in other similar tests. Hence, we force the snapshot creation
* for the given versions and ensure that the state data is loaded from the given start
* snapshot version.
*/
testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") {
class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] {
@transient protected var _countState: ValueState[Long] = _

override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong,
TTLConfig.NONE)
}

override def handleInputRows(
key: Int,
inputRows: Iterator[(Int, Long)],
timerValues: TimerValues): Iterator[(Int, Long)] = {
val count = _countState.getOption().getOrElse(0L)
var totalSum = 0L
inputRows.foreach { entry =>
totalSum += entry._2
}
_countState.update(count + totalSum)
Iterator((key, count + totalSum))
}
}

withTempDir { tmpDir =>
withSQLConf(
SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
val inputData = MemoryStream[(Int, Long)]
val query = inputData
.toDS()
.groupByKey(_._1)
.transformWithState(new AggregationStatefulProcessor(),
TimeMode.None(),
OutputMode.Append())
testStream(query)(
StartStream(checkpointLocation = tmpDir.getCanonicalPath),
AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L)),
ProcessAllAvailable(),
AddData(inputData, (5, 1L), (6, 2L), (7, 3L), (8, 4L)),
ProcessAllAvailable(),
AddData(inputData, (9, 1L), (10, 2L), (11, 3L), (12, 4L)),
ProcessAllAvailable(),
AddData(inputData, (13, 1L), (14, 2L), (15, 3L), (16, 4L)),
ProcessAllAvailable(),
AddData(inputData, (17, 1L), (18, 2L), (19, 3L), (20, 4L)),
ProcessAllAvailable(),
// Ensure that we get a chance to upload created snapshots
Execute { _ => Thread.sleep(5000) },
StopStream
)
}

// Create a file manager for the state store with opId=0 and partition=4
val dfsRootDir = new File(tmpDir.getAbsolutePath + "/state/0/4")
val fileManager = new RocksDBFileManager(
dfsRootDir.getAbsolutePath, Utils.createTempDir(), new Configuration,
CompressionCodec.LZ4)

// Read the changelog for one of the partitions at version 3 and
// ensure that we have two entries
// For this test - keys 9 and 12 are written at version 3 for partition 4
val changelogReader = fileManager.getChangelogReader(3, true)
val entries = changelogReader.toSeq
assert(entries.size == 2)
val retainEntry = entries.head

// Retain one of the entries and delete the changelog file
val changelogFilePath = dfsRootDir.getAbsolutePath + "/3.changelog"
Utils.deleteRecursively(new File(changelogFilePath))

// Write the retained entry back to the changelog
val changelogWriter = fileManager.getChangeLogWriter(3, true)
changelogWriter.put(retainEntry._2, retainEntry._3)
changelogWriter.commit()

// Ensure that we have only one entry in the changelog for version 3
// For this test - key 9 is retained and key 12 is deleted
val changelogReader1 = fileManager.getChangelogReader(3, true)
val entries1 = changelogReader1.toSeq
assert(entries1.size == 1)

// Ensure that the state matches for the partition that is not modified and does not match for
// the other partition
Seq(1, 4).foreach { partition =>
val stateSnapshotDf = spark
.read
.format("statestore")
.option("snapshotPartitionId", partition)
.option("snapshotStartBatchId", 1)
.option("stateVarName", "countState")
.load(tmpDir.getCanonicalPath)

val stateDf = spark
.read
.format("statestore")
.option("stateVarName", "countState")
.load(tmpDir.getCanonicalPath)
.filter(col("partition_id") === partition)

if (partition == 1) {
checkAnswer(stateSnapshotDf, stateDf)
} else {
// Ensure that key 12 is not present in the final state loaded from given snapshot
val resultDfForSnapshot = stateSnapshotDf.selectExpr(
"key.value AS groupingKey",
"value.value AS count",
"partition_id")
checkAnswer(resultDfForSnapshot,
Seq(Row(16, 4L, 4),
Row(17, 1L, 4),
Row(19, 3L, 4),
Row(2, 2L, 4),
Row(6, 2L, 4),
Row(9, 1L, 4)))

// Ensure that key 12 is present in the final state loaded from the latest snapshot
val resultDf = stateDf.selectExpr(
"key.value AS groupingKey",
"value.value AS count",
"partition_id")

checkAnswer(resultDf,
Seq(Row(16, 4L, 4),
Row(17, 1L, 4),
Row(19, 3L, 4),
Row(2, 2L, 4),
Row(6, 2L, 4),
Row(9, 1L, 4),
Row(12, 4L, 4)))
}
}
}
}
}
