diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 4aa95ad42ec7f..101eb1e51d83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -204,7 +204,6 @@ class StatePartitionReader( } override def close(): Unit = { - store.abort() super.close() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 25a1ca249c223..146689e739f9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -31,7 +31,7 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -89,6 +89,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def abort(): Unit = {} + override def release(): Unit = {} + override def toString(): String = { s"HDFSReadStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" } @@ -112,6 +114,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with case object UPDATING extends STATE case object COMMITTED extends STATE case object ABORTED extends STATE + case object RELEASED extends STATE + + + Option(TaskContext.get()).foreach { ctxt => + ctxt.addTaskCompletionListener[Unit](ctx => { + if (state == UPDATING) { + abort() + } + }) + } private val newVersion = version + 1 @volatile private var state: STATE = UPDATING @@ -194,6 +206,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.STATE_STORE_PROVIDER, this)}") } + override def release(): Unit = { + state = RELEASED + } + /** * Get an iterator of all the store data. * This can be called only after committing all the updates made in the current thread. 
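To make the new store lifecycle explicit, here is a minimal, self-contained sketch — illustrative only, not part of the patch, with all names invented — of the rule the task-completion listener above enforces: a store that is still UPDATING when the task finishes is aborted, while a store that already committed or released is left untouched. (The RocksDB variant later in the patch additionally releases read-only stores instead of aborting them, which the sketch also models.)

// Illustrative sketch only -- not Spark code.
object StoreLifecycleSketch {
  sealed trait State
  case object Updating extends State
  case object Committed extends State
  case object Aborted extends State
  case object Released extends State

  final class FakeStore(val readOnly: Boolean) {
    @volatile var state: State = Updating
    def commit(): Unit  = { require(!readOnly, "read-only store"); state = Committed }
    def release(): Unit = { require(readOnly, "write store"); state = Released }
    def abort(): Unit   = if (state == Updating) state = Aborted
  }

  // What the registered task-completion listener does on task exit.
  def onTaskCompletion(store: FakeStore): Unit =
    if (store.state == Updating) {
      if (store.readOnly) store.release() else store.abort()
    }

  def main(args: Array[String]): Unit = {
    val write = new FakeStore(readOnly = false)
    write.commit()
    onTaskCompletion(write)               // already committed -> left untouched
    val read = new FakeStore(readOnly = true)
    onTaskCompletion(read)                // success path -> Released, not Aborted
    println(s"write=${write.state}, read=${read.state}") // write=Committed, read=Released
  }
}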
@@ -953,7 +969,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with * @param endVersion checkpoint version to end with * @return [[HDFSBackedStateStore]] */ - override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = { + override def replayStateFromSnapshot( + snapshotVersion: Long, endVersion: Long, readOnly: Boolean): StateStore = { val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion) logInfo(log"Retrieved snapshot at version " + log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 6b3bec2077037..9cfb84fdcfb4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1013,6 +1013,12 @@ class RocksDB( } } + def release(): Unit = { + if (db != null) { + release(LoadStore) + } + } + /** * Commit all the updates made as a version to DFS. The steps it needs to do to commits are: * - Flush all changes to disk diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 6efdf8d67137b..0ef39a2afc2fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec @@ -43,12 +43,13 @@ private[sql] class RocksDBStateStoreProvider with SupportsFineGrainedReplay { import RocksDBStateStoreProvider._ - class RocksDBStateStore(lastVersion: Long) extends StateStore { + class RocksDBStateStore(lastVersion: Long, var readOnly: Boolean) extends StateStore { /** Trait and classes representing the internal state of the store */ trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE case object ABORTED extends STATE + case object RELEASED extends STATE @volatile private var state: STATE = UPDATING @volatile private var isValidated = false @@ -57,6 +58,30 @@ private[sql] class RocksDBStateStoreProvider override def version: Long = lastVersion + Option(TaskContext.get()).foreach { ctxt => + ctxt.addTaskCompletionListener[Unit]( ctx => { + try { + if (state == UPDATING) { + if (readOnly) { + release() // Only release, do not throw an error because we rely on + // CompletionListener to release for read-only store in + // mapPartitionsWithReadStateStore. + } else { + abort() // Abort since this is an error if stateful task completes + } + } + } catch { + case NonFatal(e) => + logWarning("Failed to abort state store", e) + } + }) + + ctxt.addTaskFailureListener( (_, _) => { + abort() // Either the store is already aborted (this is a no-op) or + // we need to abort it. 
+ }) + } + override def createColFamilyIfAbsent( colFamilyName: String, keySchema: StructType, @@ -365,6 +390,19 @@ private[sql] class RocksDBStateStoreProvider } result } + + override def release(): Unit = { + assert(readOnly, "Release can only be called on a read-only store") + if (state != RELEASED) { + logInfo(log"Releasing ${MDC(VERSION_NUM, version + 1)} " + + log"for ${MDC(STATE_STORE_ID, id)}") + rocksDB.release() + state = RELEASED + } else { + // Optionally log at DEBUG level that it's already released + logDebug(log"State store already released") + } + } } // Test-visible method to fetch the internal RocksDBStateStore class @@ -446,38 +484,52 @@ private[sql] class RocksDBStateStoreProvider override def stateStoreId: StateStoreId = stateStoreId_ - override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + /** + * Creates and returns a state store with the specified parameters. + * + * @param version The version of the state store to load + * @param uniqueId Optional unique identifier for checkpoint + * @param readOnly Whether to open the store in read-only mode + * @param existingStore Optional existing store to reuse instead of creating a new one + * @return The loaded state store + */ + private def loadStateStore( + version: Long, + uniqueId: Option[String], + readOnly: Boolean, + existingStore: Option[ReadStateStore] = None): StateStore = { try { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None) - new RocksDBStateStore(version) - } - catch { - case e: OutOfMemoryError => - throw QueryExecutionErrors.notEnoughMemoryToLoadStore( - stateStoreId.toString, - "ROCKSDB_STORE_PROVIDER", - e) - case e: Throwable => throw StateStoreErrors.cannotLoadStore(e) - } - } - override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = { - try { - if (version < 0) { - throw QueryExecutionErrors.unexpectedStateStoreVersion(version) + // Early validation of the existing store type before loading RocksDB + existingStore.foreach { store => + if (!store.isInstanceOf[RocksDBStateStore]) { + throw new IllegalArgumentException( + s"Existing store must be a RocksDBStateStore, but got ${store.getClass.getSimpleName}") + } } + + // Load RocksDB store rocksDB.load( version, stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, - readOnly = true) - new RocksDBStateStore(version) - } - catch { + readOnly = readOnly) + + // Create or reuse store instance + existingStore match { + case Some(store: ReadStateStore) if store.isInstanceOf[RocksDBStateStore] => + // Mark store as being used for write operations + val rocksDBStateStore = store.asInstanceOf[RocksDBStateStore] + rocksDBStateStore.readOnly = readOnly + rocksDBStateStore.asInstanceOf[StateStore] + case None => + // Create new store instance + new RocksDBStateStore(version, readOnly) + case _ => null // No need for error case here since we validated earlier + } + } catch { case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( stateStoreId.toString, @@ -487,6 +539,26 @@ private[sql] class RocksDBStateStoreProvider } } + override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + loadStateStore(version, uniqueId, readOnly = false) + } + + override def upgradeReadStoreToWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore 
= { + assert(version == readStore.version, + s"Can only upgrade readStore to writeStore with the same version," + + s" readStoreVersion: ${readStore.version}, writeStoreVersion: ${version}") + assert(this.stateStoreId == readStore.id, "Can only upgrade readStore to writeStore with" + + " the same stateStoreId") + loadStateStore(version, uniqueId, readOnly = false, existingStore = Some(readStore)) + } + + override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = { + loadStateStore(version, uniqueId, readOnly = true) + } + override def doMaintenance(): Unit = { try { rocksDB.doMaintenance() @@ -572,7 +644,8 @@ private[sql] class RocksDBStateStoreProvider * @param endVersion checkpoint version to end with * @return [[StateStore]] */ - override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = { + override def replayStateFromSnapshot( + snapshotVersion: Long, endVersion: Long, readOnly: Boolean): StateStore = { try { if (snapshotVersion < 1) { throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion) @@ -581,7 +654,7 @@ private[sql] class RocksDBStateStoreProvider throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion) } rocksDB.loadFromSnapshot(snapshotVersion, endVersion) - new RocksDBStateStore(endVersion) + new RocksDBStateStore(endVersion, readOnly) } catch { case e: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 3e14d02b73da5..da8be4cef9368 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -117,6 +117,21 @@ trait ReadStateStore { * The method name is to respect backward compatibility on [[StateStore]]. */ def abort(): Unit + + /** + * Releases resources associated with this read-only state store. + * + * This method should be called when the store is no longer needed but has completed + * successfully (i.e., no errors occurred during reading). It performs any necessary + * cleanup operations without invalidating or rolling back the data that was read. + * + * In contrast to `abort()`, which is called on error paths to cancel operations, + * `release()` is the proper method to call in success scenarios when a read-only + * store is no longer needed. + * + * This method is idempotent and safe to call multiple times. + */ + def release(): Unit } /** @@ -234,6 +249,8 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { override def abort(): Unit = store.abort() + override def release(): Unit = {} + override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = store.prefixScan(prefixKey, colFamilyName) @@ -565,6 +582,29 @@ trait StateStoreProvider { version: Long, stateStoreCkptId: Option[String] = None): StateStore + /** + * Creates a writable store from an existing read-only store for the specified version. + * + * This method enables an important optimization pattern for stateful operations where + * the same state store needs to be accessed for both reading and writing within a task. + * Instead of opening two separate state store instances (which can cause contention issues), + * this method converts an existing read-only store to a writable store that can commit changes. 
+ * + * This approach is particularly beneficial when: + * - A stateful operation needs to first read the existing state, then update it + * - The state store has locking mechanisms that prevent concurrent access + * - Multiple state store connections would cause unnecessary resource duplication + * + * @param readStore The existing read-only store instance to convert to a writable store + * @param version The version of the state store (must match the read store's version) + * @param uniqueId Optional unique identifier for checkpointing + * @return A writable StateStore instance that can be used to update and commit changes + */ + def upgradeReadStoreToWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore = getStore(version, uniqueId) + /** * Return an instance of [[ReadStateStore]] representing state data of the given version * and uniqueID if provided. @@ -720,7 +760,8 @@ trait SupportsFineGrainedReplay { * @param snapshotVersion checkpoint version of the snapshot to start with * @param endVersion checkpoint version to end with */ - def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore + def replayStateFromSnapshot( + snapshotVersion: Long, endVersion: Long, readOnly: Boolean = false): StateStore /** * Return an instance of [[ReadStateStore]] representing state data of the given version. @@ -733,7 +774,7 @@ trait SupportsFineGrainedReplay { * @param endVersion checkpoint version to end with */ def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore = { - new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion)) + new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion, readOnly = true)) } /** @@ -824,7 +865,6 @@ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { } } - /** * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), @@ -951,6 +991,56 @@ object StateStore extends Logging { storeProvider.getReadStore(version, stateStoreCkptId) } + /** + * Converts an existing read-only state store to a writable state store. + * + * This method provides an optimization for stateful operations that need to both read and update + * state within the same task. Instead of opening separate read and write instances (which may + * cause resource contention or duplication), this method reuses the already loaded read store + * and transforms it into a writable store. + * + * The optimization is particularly valuable for state stores with expensive initialization costs + * or limited concurrency capabilities (like RocksDB). It eliminates redundant loading of the same + * state data and reduces resource usage. 
+ * + * @param readStore The existing read-only state store to convert to a writable store + * @param storeProviderId Unique identifier for the state store provider + * @param keySchema Schema of the state store keys + * @param valueSchema Schema of the state store values + * @param keyStateEncoderSpec Specification for encoding the state keys + * @param version The version of the state store (must match the read store's version) + * @param stateStoreCkptId Optional checkpoint identifier for the state store + * @param stateSchemaBroadcast Optional broadcast of the state schema + * @param useColumnFamilies Whether to use column families in the state store + * @param storeConf Configuration for the state store + * @param hadoopConf Hadoop configuration + * @param useMultipleValuesPerKey Whether the store supports multiple values per key + * @return A writable StateStore instance that can be used to update and commit changes + * @throws SparkException If the store cannot be loaded or if there's insufficient memory + */ + def getWriteStore( + readStore: ReadStateStore, + storeProviderId: StateStoreProviderId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + version: Long, + stateStoreCkptId: Option[String], + stateSchemaBroadcast: Option[StateSchemaBroadcast], + useColumnFamilies: Boolean, + storeConf: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): StateStore = { + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) + if (version < 0) { + throw QueryExecutionErrors.unexpectedStateStoreVersion(version) + } + val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, + keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey, + stateSchemaBroadcast) + storeProvider.upgradeReadStoreToWriteStore(readStore, version, stateStoreCkptId) + } + /** Get or create a store associated with the id. */ def get( storeProviderId: StateStoreProviderId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 4a3e045811686..c95faada419e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -27,6 +27,28 @@ import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration +/** + * Thread local storage for sharing StateStore instances between RDDs. + * This allows a ReadStateStore to be reused by a subsequent StateStore operation. 
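+ * The intended hand-off: ReadStateStoreRDD registers the store it opens via setStore(),
+ * StateStoreRDD consults getStore and, if a store is present, upgrades it in place through
+ * StateStore.getWriteStore instead of loading a second copy, and the task completion
+ * listener registered in the state package object calls clearStore().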
+ */ +object StateStoreThreadLocalTracker { + /** Case class to hold both the store and its usage state */ + + private val storeInfo: ThreadLocal[ReadStateStore] = new ThreadLocal[ReadStateStore] + + def setStore(store: ReadStateStore): Unit = { + storeInfo.set(store) + } + + def getStore: Option[ReadStateStore] = { + Option(storeInfo.get()) + } + + def clearStore(): Unit = { + storeInfo.remove() + } +} + abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], checkpointLocation: String, @@ -95,6 +117,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( stateStoreCkptIds.map(_.apply(partition.index).head), stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value) + StateStoreThreadLocalTracker.setStore(store) storeReadFunction(store, inputIter) } } @@ -130,12 +153,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( val storeProviderId = getStateProviderId(partition) val inputIter = dataRDD.iterator(partition, ctxt) - val store = StateStore.get( - storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, - uniqueId.map(_.apply(partition.index).head), - stateSchemaBroadcast, - useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, - useMultipleValuesPerKey) + val store = StateStoreThreadLocalTracker.getStore match { + case Some(readStateStore: ReadStateStore) => + StateStore.getWriteStore(readStateStore, storeProviderId, + keySchema, valueSchema, keyStateEncoderSpec, storeVersion, + uniqueId.map(_.apply(partition.index).head), + stateSchemaBroadcast, + useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, + useMultipleValuesPerKey) + case None => + StateStore.get( + storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, + uniqueId.map(_.apply(partition.index).head), + stateSchemaBroadcast, + useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, + useMultipleValuesPerKey) + } if (storeConf.unloadOnCommit) { ctxt.addTaskCompletionListener[Unit](_ => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index a82eff4812953..651c180299a9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -67,9 +67,6 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = (store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener[Unit](_ => { - if (!store.hasCommitted) store.abort() - }) cleanedF(store, iter) } @@ -109,8 +106,9 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeReadFn) val wrappedF = (store: ReadStateStore, iter: Iterator[T]) => { // Clean up the state store. 
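+      // The read store is no longer aborted here: the store implementations now clean up
+      // through their own task completion listeners (a RocksDB read-only store is released
+      // on task completion), so this wrapper only needs to clear the thread-local tracker.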
- TaskContext.get().addTaskCompletionListener[Unit](_ => { - store.abort() + val ctxt = TaskContext.get() + ctxt.addTaskCompletionListener[Unit](_ => { + StateStoreThreadLocalTracker.clearStore() }) cleanedF(store, iter) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 9a04a0c759ac4..7af06e3ab9d46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -55,6 +55,8 @@ class MemoryStateStore extends StateStore() { override def abort(): Unit = {} + override def release(): Unit = {} + override def id: StateStoreId = null override def version: Long = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 22150ffde5db6..3f207cb61df39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -137,6 +137,8 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta ret } override def hasCommitted: Boolean = innerStore.hasCommitted + + override def release(): Unit = {} } class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 1f9fd17eda600..2247fd21da597 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -228,6 +228,147 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { } } + test("SPARK-51955: ReadStateStore reuse and upgrade to WriteStore") { + withSparkSession(SparkSession.builder() + .config(sparkConf) + .config(SQLConf.STATE_STORE_PROVIDER_CLASS.key, classOf[RocksDBStateStoreProvider].getName) + .config(SQLConf.SHUFFLE_PARTITIONS.key, "1") + .getOrCreate()) { spark => + implicit val sqlContext = spark.sqlContext + val path = Utils.createDirectory(tempDir, Random.nextFloat().toString).toString + + // Create initial data in the state store (version 0) + val initialData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0))) + val setupRDD = initialData.mapPartitionsWithStateStore( + sqlContext, + operatorStateInfo(path, version = 0), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema) + ) { (store, iter) => + // Set initial values: a->1, b->2 + iter.foreach { case (s, i) => + val key = dataToKeyRow(s, i) + store.put(key, dataToValueRow(if (s == "a") 1 else 2)) + } + store.commit() + Iterator.empty + } + setupRDD.count() // Force evaluation + + // Create input data for our chained operations + val inputData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) + + var mappedReadStore: ReadStateStore = null + var mappedWriteStore: StateStore = null + + // Chain operations: first read with ReadStateStore, then write with StateStore 
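+      // The second mapPartitionsWithStateStore stage below is expected to pick up the
+      // ReadStateStore registered by StateStoreThreadLocalTracker and upgrade it in place
+      // via StateStore.getWriteStore, rather than loading a second store for version 1.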
+ val chainedResults = inputData + // First pass: read-only state store access + .mapPartitionsWithReadStateStore( + operatorStateInfo(path, version = 1), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + spark.sessionState, + Some(castToImpl(spark).streams.stateStoreCoordinator) + ) { (readStore, iter) => + mappedReadStore = readStore + + // Read values and store them for later verification + val inputItems = iter.toSeq // Materialize the input data + + val readValues = inputItems.map { case (s, i) => + val key = dataToKeyRow(s, i) + val value = Option(readStore.get(key)).map(valueRowToData) + ((s, i), value) + } + + // Also capture all state store entries + val allValues = readStore.iterator().map(rowPairToDataPair).toSeq + + // Return everything as a single tuple - only create one element in the iterator + Iterator((readValues, allValues, inputItems)) + } + // Second pass: use StateStore to write updates (should reuse the read store) + .mapPartitionsWithStateStore( + operatorStateInfo(path, version = 1), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + spark.sessionState, + Some(castToImpl(spark).streams.stateStoreCoordinator) + ) { (writeStore, writeIter) => + if (writeIter.hasNext) { + val (readValues, allStoreValues, originalItems) = writeIter.next() + mappedWriteStore = writeStore + // Get all existing values from the write store to verify reuse + val storeValues = writeStore.iterator().map(rowPairToDataPair).toSeq + + // Update values for a and c from the original items + originalItems.filter(p => p._1 == "a" || p._1 == "c").foreach { case (s, i) => + val key = dataToKeyRow(s, i) + val oldValue = Option(writeStore.get(key)).map(valueRowToData).getOrElse(0) + val newValue = oldValue + 10 // Add 10 to existing values + writeStore.put(key, dataToValueRow(newValue)) + } + writeStore.commit() + + // Return all collected information for verification + Iterator((readValues, allStoreValues, storeValues)) + } else { + Iterator.empty + } + } + + // Collect the results + val (readValues, initialStoreState, writeStoreValues) = chainedResults.collect().head + + // Verify read results + assert(readValues.toSet === Set( + ("a", 0) -> Some(1), + ("b", 0) -> Some(2), + ("c", 0) -> None + )) + + // Verify store state matches expected values + assert(initialStoreState.toSet === Set((("a", 0), 1), (("b", 0), 2))) + + // Verify the existing values in the write store (should be the same as initial state) + assert(writeStoreValues.toSet === Set((("a", 0), 1), (("b", 0), 2))) + + // Verify that the same store was used for both read and write operations + assert(mappedReadStore == mappedWriteStore, + "StateStoreThreadLocalTracker should indicate the read store was reused") + + // Create another ReadStateStoreRDD to verify the final state (version 2) + val verifyData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) + val verifyRDD = verifyData.mapPartitionsWithReadStateStore( + operatorStateInfo(path, version = 2), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + spark.sessionState, + Some(castToImpl(spark).streams.stateStoreCoordinator) + ) { (store, iter) => + iter.map { case (s, i) => + val key = dataToKeyRow(s, i) + val value = Option(store.get(key)).map(valueRowToData) + ((s, i), value) + } + } + + // Verify the final state has the expected values + // a: 1 + 10 = 11, b: 2 (unchanged), c: 0 + 10 = 10 + val finalResults = verifyRDD.collect().toSet + assert(finalResults === Set( + ("a", 0) -> Some(11), + ("b", 0) -> 
Some(2), + ("c", 0) -> Some(10) + )) + } + } + test("SPARK-51823: unload on commit") { withSparkSession( SparkSession.builder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 4226ee94e98d3..b82f48c262603 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1389,7 +1389,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1) saveStore.commit() - restoreStore.abort() + // We don't need to call restoreStore.release() since the Write Store has been committed } // check that state is correct for next batch @@ -1707,6 +1707,107 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } } + test("two concurrent StateStores - one for read-only and one for read-write with release()") { + val dir = Utils.createTempDir().getAbsolutePath + val storeId = StateStoreId(dir, 0L, 1) + val storeProviderId = StateStoreProviderId(storeId, UUID.randomUUID) + val key1 = "a" + val key2 = 0 + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + quietly { + withSpark(SparkContext.getOrCreate( + new SparkConf().setMaster("local").setAppName("test"))) { sc => + withCoordinatorRef(sc) { _ => + // Prime state + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 0, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + put(store, key1, key2, 1) + store.commit() + + // Get two state stores - one read-only and one read-write + val restoreStore = StateStore.getReadOnly( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + val saveStore = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + // Update the write store based on data from read store + put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1) + saveStore.commit() + + // Check that state is correct for next batch + val finalStore = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 2, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + assert(get(finalStore, key1, key2) === Some(2)) + } + } + } + } + + test("getWriteStore correctly uses existing read store") { + val dir = Utils.createTempDir().getAbsolutePath + val storeId = StateStoreId(dir, 0L, 1) + val storeProviderId = StateStoreProviderId(storeId, UUID.randomUUID) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + quietly { + withSpark(SparkContext.getOrCreate( + new SparkConf().setMaster("local").setAppName("test"))) { sc => + withCoordinatorRef(sc) { _ => + // Prime state + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 0, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + put(store, "a", 0, 1) + store.commit() + + // Get a read-only store + val readStore = StateStore.getReadOnly( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, 
useColumnFamilies = false, storeConf, hadoopConf) + + // Convert it to a write store using the new getWriteStore method + val writeStore = StateStore.getWriteStore( + readStore, + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + // The write store should still have access to the data + assert(get(writeStore, "a", 0) === Some(1)) + + // Update and commit with the write store + put(writeStore, "a", 0, 2) + writeStore.commit() + + // Check that the state was updated correctly + val finalStore = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 2, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + assert(get(finalStore, "a", 0) === Some(2)) + } + } + } + } + test("SPARK-42572: StateStoreProvider.validateStateRowFormat shouldn't check" + " value row format when SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED is false") { // By default, when there is an invalid pair of value row and value schema, it should throw
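Taken together, the changes establish a read-then-upgrade hand-off between the read pass and the write pass of a task. The sketch below is a self-contained model of that hand-off — illustrative only; FakeStore and Tracker are invented stand-ins for ReadStateStore/StateStore and StateStoreThreadLocalTracker — showing why the second stage can reuse the already-loaded read store instead of opening a second instance.

// Self-contained model of the read-then-upgrade hand-off (illustrative only).
object ReadThenUpgradeSketch {
  final class FakeStore(val version: Long) {
    var readOnly: Boolean = true
    def upgradeToWrite(): FakeStore = { readOnly = false; this } // reuse, no reload
  }

  // Stand-in for StateStoreThreadLocalTracker: the read pass registers its store,
  // the write pass picks it up, and the task-completion listener clears it.
  object Tracker {
    private val current = new ThreadLocal[FakeStore]
    def setStore(s: FakeStore): Unit = current.set(s)
    def getStore: Option[FakeStore] = Option(current.get())
    def clearStore(): Unit = current.remove()
  }

  def readPass(version: Long): FakeStore = {
    val store = new FakeStore(version) // "load" the store once
    Tracker.setStore(store)
    store
  }

  def writePass(version: Long): FakeStore =
    Tracker.getStore match {
      case Some(readStore) =>
        readStore.upgradeToWrite() // reuse the same instance
      case None =>
        val s = new FakeStore(version) // nothing tracked: load a fresh write store
        s.readOnly = false
        s
    }

  def main(args: Array[String]): Unit = {
    val r = readPass(1L)
    val w = writePass(1L)
    println(s"reused=${r eq w}, readOnly=${w.readOnly}") // reused=true, readOnly=false
    Tracker.clearStore()
  }
}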