
Commit 52df0cd

anishshri-db authored and HeartSaVioR committed
[SPARK-50162][SS][TESTS] Add tests for loading snapshot with given version for transformWithState operator state and state data source reader
### What changes were proposed in this pull request?

Add tests for loading snapshot with given version for transformWithState operator state and state data source reader.

### Why are the changes needed?

To add test coverage for the snapshotStartBatchId integration of transformWithState and the state data source reader.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Test-only change; added unit tests.

```
===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.datasources.v2.state.StateDataSourceTransformWithStateSuite, threads: ForkJoinPool.commonPool-worker-6 (daemon=true), ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 (daemon=true), ForkJoinPool.commonPool-worker-7 (daemon=true), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), ForkJoinPool.commonPool-worker-2 (daemon=true), rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPo...
[info] Run completed in 2 minutes, 5 seconds.
[info] Total number of tests run: 23
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 23, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48710 from anishshri-db/task/SPARK-50162.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 667cff7 commit 52df0cd
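The option under test, `snapshotStartBatchId`, tells the state data source to reconstruct a state store from the snapshot taken at that batch and then replay the later changelog files, rather than reading only the latest checkpointed state; it has to be paired with `snapshotPartitionId` to select a single partition, and transformWithState state variables are addressed through `stateVarName`. A minimal usage sketch of the read path the new tests exercise follows; the checkpoint path is a placeholder and an active SparkSession `spark` is assumed, so treat it as illustration rather than code from this PR:

```scala
// Sketch only: reading transformWithState state through the state data source.
// `checkpointDir` is a placeholder for a streaming query's checkpoint location.
val checkpointDir = "/tmp/tws-checkpoint"

// Rebuild partition 4 starting from the snapshot taken at batch 1, replaying
// the subsequent changelog files on top of that snapshot.
val stateFromSnapshot = spark.read
  .format("statestore")
  .option("snapshotStartBatchId", 1)
  .option("snapshotPartitionId", 4)
  .option("stateVarName", "countState")
  .load(checkpointDir)

// Regular read of the same state variable, reflecting the latest committed batch.
val latestState = spark.read
  .format("statestore")
  .option("stateVarName", "countState")
  .load(checkpointDir)
```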


2 files changed: +160, -10 lines changed


sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala

Lines changed: 8 additions & 8 deletions
@@ -457,19 +457,19 @@ class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
     testSnapshotPartitionId()
   }
 
-  test("snapshotStatBatchId on limit state") {
+  test("snapshotStartBatchId on limit state") {
     testSnapshotOnLimitState("hdfs")
   }
 
-  test("snapshotStatBatchId on aggregation state") {
+  test("snapshotStartBatchId on aggregation state") {
     testSnapshotOnAggregateState("hdfs")
   }
 
-  test("snapshotStatBatchId on deduplication state") {
+  test("snapshotStartBatchId on deduplication state") {
     testSnapshotOnDeduplicateState("hdfs")
   }
 
-  test("snapshotStatBatchId on join state") {
+  test("snapshotStartBatchId on join state") {
     testSnapshotOnJoinState("hdfs", 1)
     testSnapshotOnJoinState("hdfs", 2)
   }
@@ -550,19 +550,19 @@ StateDataSourceReadSuite {
     testSnapshotPartitionId()
   }
 
-  test("snapshotStatBatchId on limit state") {
+  test("snapshotStartBatchId on limit state") {
     testSnapshotOnLimitState("rocksdb")
   }
 
-  test("snapshotStatBatchId on aggregation state") {
+  test("snapshotStartBatchId on aggregation state") {
     testSnapshotOnAggregateState("rocksdb")
  }
 
-  test("snapshotStatBatchId on deduplication state") {
+  test("snapshotStartBatchId on deduplication state") {
     testSnapshotOnDeduplicateState("rocksdb")
   }
 
-  test("snapshotStatBatchId on join state") {
+  test("snapshotStartBatchId on join state") {
     testSnapshotOnJoinState("rocksdb", 1)
     testSnapshotOnJoinState("rocksdb", 2)
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala

Lines changed: 152 additions & 2 deletions
@@ -16,15 +16,20 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
+import java.io.File
 import java.time.Duration
 
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
-import org.apache.spark.sql.functions.{explode, timestamp_seconds}
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
 import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
 
 /** Stateful processor of single value state var with non-primitive type */
 class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -997,4 +1002,149 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  /**
+   * Note that we cannot use the golden files approach for transformWithState. The new schema
+   * format keeps track of the schema file path as an absolute path which cannot be used with
+   * the getResource model used in other similar tests. Hence, we force the snapshot creation
+   * for given versions and ensure that we are loading from given start snapshot version for loading
+   * the state data.
+   */
+  testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") {
+    class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] {
+      @transient protected var _countState: ValueState[Long] = _
+
+      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+        _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong,
+          TTLConfig.NONE)
+      }
+
+      override def handleInputRows(
+          key: Int,
+          inputRows: Iterator[(Int, Long)],
+          timerValues: TimerValues): Iterator[(Int, Long)] = {
+        val count = _countState.getOption().getOrElse(0L)
+        var totalSum = 0L
+        inputRows.foreach { entry =>
+          totalSum += entry._2
+        }
+        _countState.update(count + totalSum)
+        Iterator((key, count + totalSum))
+      }
+    }
+
+    withTempDir { tmpDir =>
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
+        val inputData = MemoryStream[(Int, Long)]
+        val query = inputData
+          .toDS()
+          .groupByKey(_._1)
+          .transformWithState(new AggregationStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+        testStream(query)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (5, 1L), (6, 2L), (7, 3L), (8, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (9, 1L), (10, 2L), (11, 3L), (12, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (13, 1L), (14, 2L), (15, 3L), (16, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (17, 1L), (18, 2L), (19, 3L), (20, 4L)),
+          ProcessAllAvailable(),
+          // Ensure that we get a chance to upload created snapshots
+          Execute { _ => Thread.sleep(5000) },
+          StopStream
+        )
+      }
+
+      // Create a file manager for the state store with opId=0 and partition=4
+      val dfsRootDir = new File(tmpDir.getAbsolutePath + "/state/0/4")
+      val fileManager = new RocksDBFileManager(
+        dfsRootDir.getAbsolutePath, Utils.createTempDir(), new Configuration,
+        CompressionCodec.LZ4)
+
+      // Read the changelog for one of the partitions at version 3 and
+      // ensure that we have two entries
+      // For this test - keys 9 and 12 are written at version 3 for partition 4
+      val changelogReader = fileManager.getChangelogReader(3, true)
+      val entries = changelogReader.toSeq
+      assert(entries.size == 2)
+      val retainEntry = entries.head
+
+      // Retain one of the entries and delete the changelog file
+      val changelogFilePath = dfsRootDir.getAbsolutePath + "/3.changelog"
+      Utils.deleteRecursively(new File(changelogFilePath))
+
+      // Write the retained entry back to the changelog
+      val changelogWriter = fileManager.getChangeLogWriter(3, true)
+      changelogWriter.put(retainEntry._2, retainEntry._3)
+      changelogWriter.commit()
+
+      // Ensure that we have only one entry in the changelog for version 3
+      // For this test - key 9 is retained and key 12 is deleted
+      val changelogReader1 = fileManager.getChangelogReader(3, true)
+      val entries1 = changelogReader1.toSeq
+      assert(entries1.size == 1)
+
+      // Ensure that the state matches for the partition that is not modified and does not match for
+      // the other partition
+      Seq(1, 4).foreach { partition =>
+        val stateSnapshotDf = spark
+          .read
+          .format("statestore")
+          .option("snapshotPartitionId", partition)
+          .option("snapshotStartBatchId", 1)
+          .option("stateVarName", "countState")
+          .load(tmpDir.getCanonicalPath)
+
+        val stateDf = spark
+          .read
+          .format("statestore")
+          .option("stateVarName", "countState")
+          .load(tmpDir.getCanonicalPath)
+          .filter(col("partition_id") === partition)
+
+        if (partition == 1) {
+          checkAnswer(stateSnapshotDf, stateDf)
+        } else {
+          // Ensure that key 12 is not present in the final state loaded from given snapshot
+          val resultDfForSnapshot = stateSnapshotDf.selectExpr(
+            "key.value AS groupingKey",
+            "value.value AS count",
+            "partition_id")
+          checkAnswer(resultDfForSnapshot,
+            Seq(Row(16, 4L, 4),
+              Row(17, 1L, 4),
+              Row(19, 3L, 4),
+              Row(2, 2L, 4),
+              Row(6, 2L, 4),
+              Row(9, 1L, 4)))
+
+          // Ensure that key 12 is present in the final state loaded from the latest snapshot
+          val resultDf = stateDf.selectExpr(
+            "key.value AS groupingKey",
+            "value.value AS count",
+            "partition_id")
+
+          checkAnswer(resultDf,
+            Seq(Row(16, 4L, 4),
+              Row(17, 1L, 4),
+              Row(19, 3L, 4),
+              Row(2, 2L, 4),
+              Row(6, 2L, 4),
+              Row(9, 1L, 4),
+              Row(12, 4L, 4)))
        }
+      }
+    }
+  }
 }
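The comment at the top of the new test explains why the golden-files approach is not used here: the test instead forces RocksDB snapshots to be created while the query runs, so that snapshotStartBatchId has concrete snapshot versions to start from. A small sketch of the two settings responsible for that, with comments describing their intended effect; this mirrors the withSQLConf block in the diff above rather than introducing anything new, and it assumes the suite's withSQLConf test helper is in scope:

```scala
import org.apache.spark.sql.internal.SQLConf

// Sketch: force frequent RocksDB snapshot uploads so that snapshotStartBatchId
// has real snapshot versions to load from (same values as the test above).
withSQLConf(
  // Consolidate changelog (delta) files into a full snapshot as soon as one exists.
  SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
  // Run the state store maintenance task, which uploads snapshots, every 100 ms.
  SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100") {
  // run the streaming query here; the trailing Thread.sleep in the test gives
  // the maintenance task time to finish uploading before the store is inspected
}
```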
