@@ -16,15 +16,20 @@
  */

 package org.apache.spark.sql.execution.datasources.v2.state

+import java.io.File
 import java.time.Duration

+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
-import org.apache.spark.sql.functions.{explode, timestamp_seconds}
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
 import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils

 /** Stateful processor of single value state var with non-primitive type */
 class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -997,4 +1002,149 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  /**
+   * Note that we cannot use the golden files approach for transformWithState. The new schema
+   * format tracks the schema file path as an absolute path, which cannot be used with the
+   * getResource model used in other similar tests. Hence, we force snapshot creation for the
+   * given versions and verify that the state data is loaded from the given start snapshot
+   * version.
+   */
+  testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") {
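+    // Stateful processor that maintains a running sum of the input values per
+    // grouping key in a single value state variable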
+    class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] {
+      @transient protected var _countState: ValueState[Long] = _
+
+      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+        _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong,
+          TTLConfig.NONE)
+      }
+
+      override def handleInputRows(
+          key: Int,
+          inputRows: Iterator[(Int, Long)],
+          timerValues: TimerValues): Iterator[(Int, Long)] = {
+        val count = _countState.getOption().getOrElse(0L)
+        var totalSum = 0L
+        inputRows.foreach { entry =>
+          totalSum += entry._2
+        }
+        _countState.update(count + totalSum)
+        Iterator((key, count + totalSum))
+      }
+    }
+
+    withTempDir { tmpDir =>
+      // Run maintenance frequently and create a snapshot for every delta file so
+      // that snapshot versions are produced and uploaded quickly
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
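+        // Run five batches so that multiple delta files, and hence multiple
+        // snapshot versions, are created for the state store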
+        val inputData = MemoryStream[(Int, Long)]
+        val query = inputData
+          .toDS()
+          .groupByKey(_._1)
+          .transformWithState(new AggregationStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+        testStream(query)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (5, 1L), (6, 2L), (7, 3L), (8, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (9, 1L), (10, 2L), (11, 3L), (12, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (13, 1L), (14, 2L), (15, 3L), (16, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (17, 1L), (18, 2L), (19, 3L), (20, 4L)),
+          ProcessAllAvailable(),
+          // Ensure that we get a chance to upload the created snapshots
+          Execute { _ => Thread.sleep(5000) },
+          StopStream
+        )
+      }
+
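+      // The query has stopped, so the checkpoint files can now be modified
+      // directly without racing against ongoing commits or maintenance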
+      // Create a file manager for the state store with opId=0 and partition=4
+      val dfsRootDir = new File(tmpDir.getAbsolutePath + "/state/0/4")
+      val fileManager = new RocksDBFileManager(
+        dfsRootDir.getAbsolutePath, Utils.createTempDir(), new Configuration,
+        CompressionCodec.LZ4)
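+      // The file manager gives direct access to the snapshot and changelog files
+      // that the RocksDB state store provider keeps under the checkpoint location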
+
+      // Read the changelog for one of the partitions at version 3 and ensure
+      // that it has two entries. For this test, keys 9 and 12 are written at
+      // version 3 for partition 4.
+      val changelogReader = fileManager.getChangelogReader(3, true)
+      val entries = changelogReader.toSeq
+      assert(entries.size == 2)
+      val retainEntry = entries.head
+
+      // Delete the changelog file for version 3
+      val changelogFilePath = dfsRootDir.getAbsolutePath + "/3.changelog"
+      Utils.deleteRecursively(new File(changelogFilePath))
+
+      // Write only the retained entry (the key and value bytes of the first
+      // changelog record) back to the changelog for version 3
+      val changelogWriter = fileManager.getChangeLogWriter(3, true)
+      changelogWriter.put(retainEntry._2, retainEntry._3)
+      changelogWriter.commit()
+
+      // Ensure that we now have only one entry in the changelog for version 3.
+      // For this test, key 9 is retained and key 12 is dropped.
+      val changelogReader1 = fileManager.getChangelogReader(3, true)
+      val entries1 = changelogReader1.toSeq
+      assert(entries1.size == 1)
+
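+      // Reading with snapshotStartBatchId = 1 reconstructs state from the batch 1
+      // snapshot by replaying the later changelogs, including the rewritten version 3
+      // changelog. A default read serves the latest state from the latest snapshot,
+      // which was uploaded before the changelog was rewritten.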
+      // Ensure that the state matches for the partition that is not modified and
+      // does not match for the other partition
+      Seq(1, 4).foreach { partition =>
+        val stateSnapshotDf = spark
+          .read
+          .format("statestore")
+          .option("snapshotPartitionId", partition)
+          .option("snapshotStartBatchId", 1)
+          .option("stateVarName", "countState")
+          .load(tmpDir.getCanonicalPath)
+
+        val stateDf = spark
+          .read
+          .format("statestore")
+          .option("stateVarName", "countState")
+          .load(tmpDir.getCanonicalPath)
+          .filter(col("partition_id") === partition)
+
+        if (partition == 1) {
+          checkAnswer(stateSnapshotDf, stateDf)
+        } else {
+          // Ensure that key 12 is not present in the final state loaded from the
+          // given snapshot
+          val resultDfForSnapshot = stateSnapshotDf.selectExpr(
+            "key.value AS groupingKey",
+            "value.value AS count",
+            "partition_id")
+          checkAnswer(resultDfForSnapshot,
+            Seq(Row(16, 4L, 4),
+              Row(17, 1L, 4),
+              Row(19, 3L, 4),
+              Row(2, 2L, 4),
+              Row(6, 2L, 4),
+              Row(9, 1L, 4)))
+
+          // Ensure that key 12 is present in the final state loaded from the latest
+          // snapshot
+          val resultDf = stateDf.selectExpr(
+            "key.value AS groupingKey",
+            "value.value AS count",
+            "partition_id")
+
+          checkAnswer(resultDf,
+            Seq(Row(16, 4L, 4),
+              Row(17, 1L, 4),
+              Row(19, 3L, 4),
+              Row(2, 2L, 4),
+              Row(6, 2L, 4),
+              Row(9, 1L, 4),
+              Row(12, 4L, 4)))
+        }
+      }
+    }
+  }
 }