From b3c584bc265aaca52b05520f852ae80164695818 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 30 Apr 2025 11:43:40 -0700 Subject: [PATCH 01/19] using threadlocal instead --- .../state/HDFSBackedStateStoreProvider.scala | 4 + .../execution/streaming/state/RocksDB.scala | 6 + .../state/RocksDBStateStoreProvider.scala | 92 +++++++++++----- .../streaming/state/StateStore.scala | 63 +++++++++++ .../streaming/state/StateStoreRDD.scala | 46 ++++++-- .../execution/streaming/state/package.scala | 12 +- .../streaming/state/MemoryStateStore.scala | 2 + ...sDBStateStoreCheckpointFormatV2Suite.scala | 2 + .../streaming/state/StateStoreSuite.scala | 103 +++++++++++++++++- 9 files changed, 292 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 98d49596d11b4..54f991ab1db08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -89,6 +89,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def abort(): Unit = {} + override def release(): Unit = {} + override def toString(): String = { s"HDFSReadStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" } @@ -194,6 +196,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.STATE_STORE_PROVIDER, this)}") } + override def release(): Unit = {} + /** * Get an iterator of all the store data. * This can be called only after committing all the updates made in the current thread. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 6b3bec2077037..9cfb84fdcfb4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1013,6 +1013,12 @@ class RocksDB( } } + def release(): Unit = { + if (db != null) { + release(LoadStore) + } + } + /** * Commit all the updates made as a version to DFS. 
The steps it needs to do to commits are: * - Flush all changes to disk diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 6a36b8c015196..3bbcece9a67fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -49,6 +49,7 @@ private[sql] class RocksDBStateStoreProvider case object UPDATING extends STATE case object COMMITTED extends STATE case object ABORTED extends STATE + case object RELEASED extends STATE @volatile private var state: STATE = UPDATING @volatile private var isValidated = false @@ -365,6 +366,18 @@ private[sql] class RocksDBStateStoreProvider } result } + + override def release(): Unit = { + if (state != RELEASED) { + logInfo(log"Releasing ${MDC(VERSION_NUM, version + 1)} " + + log"for ${MDC(STATE_STORE_ID, id)}") + rocksDB.release() + state = RELEASED + } else { + // Optionally log at DEBUG level that it's already released + logDebug(log"State store already released") + } + } } // Test-visible method to fetch the internal RocksDBStateStore class @@ -446,17 +459,47 @@ private[sql] class RocksDBStateStoreProvider override def stateStoreId: StateStoreId = stateStoreId_ - override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + /** + * Creates and returns a state store with the specified parameters. + * + * @param version The version of the state store to load + * @param uniqueId Optional unique identifier for checkpoint + * @param readOnly Whether to open the store in read-only mode + * @param existingStore Optional existing store to reuse instead of creating a new one + * @return The loaded state store + */ + private def loadStateStore( + version: Long, + uniqueId: Option[String], + readOnly: Boolean, + existingStore: Option[ReadStateStore] = None): StateStore = { try { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None) - new RocksDBStateStore(version) - } - catch { + try { + // Load RocksDB store + rocksDB.load( + version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + readOnly = readOnly) + + // Return appropriate store instance + existingStore match { + case Some(stateStore: RocksDBStateStore) => + // Reuse existing store for getWriteStore case + stateStore + case Some(_) => + throw new IllegalArgumentException("Existing store must be a RocksDBStateStore") + case None => + // Create new store instance for getStore/getReadStore cases + new RocksDBStateStore(version) + } + } catch { + case e: Throwable => + throw e + } + } catch { case e: SparkException if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) => throw e @@ -468,29 +511,22 @@ private[sql] class RocksDBStateStoreProvider case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e) } } + override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + loadStateStore(version, uniqueId, readOnly = false) + } + + override def upgradeReadStoreToWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore = { + assert(version == readStore.version, + s"Can only upgrade 
readStore to writeStore with the same version," + + s" readStoreVersion: ${readStore.version}, writeStoreVersion: ${version}") + loadStateStore(version, uniqueId, readOnly = false, existingStore = Some(readStore)) + } override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = { - try { - if (version < 0) { - throw QueryExecutionErrors.unexpectedStateStoreVersion(version) - } - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, - readOnly = true) - new RocksDBStateStore(version) - } - catch { - case e: SparkException - if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) => - throw e - case e: OutOfMemoryError => - throw QueryExecutionErrors.notEnoughMemoryToLoadStore( - stateStoreId.toString, - "ROCKSDB_STORE_PROVIDER", - e) - case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e) - } + loadStateStore(version, uniqueId, readOnly = true) } override def doMaintenance(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ffaba5ef1502f..08591179db985 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -117,6 +117,21 @@ trait ReadStateStore { * The method name is to respect backward compatibility on [[StateStore]]. */ def abort(): Unit + + /** + * Releases resources associated with this read-only state store. + * + * This method should be called when the store is no longer needed but has completed + * successfully (i.e., no errors occurred during reading). It performs any necessary + * cleanup operations without invalidating or rolling back the data that was read. + * + * In contrast to `abort()`, which is called on error paths to cancel operations, + * `release()` is the proper method to call in success scenarios when a read-only + * store is no longer needed. + * + * This method is idempotent and safe to call multiple times. + */ + def release(): Unit } /** @@ -234,6 +249,8 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { override def abort(): Unit = store.abort() + override def release(): Unit = {} + override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = store.prefixScan(prefixKey, colFamilyName) @@ -565,6 +582,29 @@ trait StateStoreProvider { version: Long, stateStoreCkptId: Option[String] = None): StateStore + /** + * Creates a writable store from an existing read-only store for the specified version. + * + * This method enables an important optimization pattern for stateful operations where + * the same state store needs to be accessed for both reading and writing within a task. + * Instead of opening two separate state store instances (which can cause contention issues), + * this method converts an existing read-only store to a writable store that can commit changes. 
+ * + * This approach is particularly beneficial when: + * - A stateful operation needs to first read the existing state, then update it + * - The state store has locking mechanisms that prevent concurrent access + * - Multiple state store connections would cause unnecessary resource duplication + * + * @param readStore The existing read-only store instance to convert to a writable store + * @param version The version of the state store (must match the read store's version) + * @param uniqueId Optional unique identifier for checkpointing + * @return A writable StateStore instance that can be used to update and commit changes + */ + def upgradeReadStoreToWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore = getStore(version, uniqueId) + /** * Return an instance of [[ReadStateStore]] representing state data of the given version * and uniqueID if provided. @@ -950,6 +990,29 @@ object StateStore extends Logging { storeProvider.getReadStore(version, stateStoreCkptId) } + def getWriteStore( + readStore: ReadStateStore, + storeProviderId: StateStoreProviderId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + version: Long, + stateStoreCkptId: Option[String], + stateSchemaBroadcast: Option[StateSchemaBroadcast], + useColumnFamilies: Boolean, + storeConf: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): StateStore = { + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) + if (version < 0) { + throw QueryExecutionErrors.unexpectedStateStoreVersion(version) + } + val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, + keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey, + stateSchemaBroadcast) + storeProvider.upgradeReadStoreToWriteStore(readStore, version, stateStoreCkptId) + } + /** Get or create a store associated with the id. */ def get( storeProviderId: StateStoreProviderId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 70b4932af6017..52a420482b3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,16 +17,37 @@ package org.apache.spark.sql.execution.streaming.state + import java.util.UUID import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration +/** + * Thread local storage for sharing StateStore instances between RDDs. + * This allows a ReadStateStore to be reused by a subsequent StateStore operation. 
+ */ +object StateStoreThreadLocalTracker { + private val readStore: ThreadLocal[ReadStateStore] = new ThreadLocal[ReadStateStore] + private val usedForWriteStore: ThreadLocal[Boolean] = new ThreadLocal[Boolean] + def setStore(store: ReadStateStore): Unit = readStore.set(store) + + def getStore: Option[ReadStateStore] = { + usedForWriteStore.set(true) + Option(readStore.get()) + } + + def isUsedForWriteStore: Boolean = usedForWriteStore.get() + + def clearStore(): Unit = readStore.remove() +} + abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], checkpointLocation: String, @@ -95,6 +116,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( stateStoreCkptIds.map(_.apply(partition.index).head), stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value) + StateStoreThreadLocalTracker.setStore(store) storeReadFunction(store, inputIter) } } @@ -122,7 +144,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( extraOptions: Map[String, String] = Map.empty, useMultipleValuesPerKey: Boolean = false) extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId, - sessionState, storeCoordinator, extraOptions) { + sessionState, storeCoordinator, extraOptions) with Logging { override protected def getPartitions: Array[Partition] = dataRDD.partitions @@ -130,12 +152,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( val storeProviderId = getStateProviderId(partition) val inputIter = dataRDD.iterator(partition, ctxt) - val store = StateStore.get( - storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, - uniqueId.map(_.apply(partition.index).head), - stateSchemaBroadcast, - useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, - useMultipleValuesPerKey) + val store = StateStoreThreadLocalTracker.getStore match { + case Some(readStateStore: ReadStateStore) => + StateStore.getWriteStore(readStateStore, storeProviderId, + keySchema, valueSchema, keyStateEncoderSpec, storeVersion, + uniqueId.map(_.apply(partition.index).head), + stateSchemaBroadcast, + useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, + useMultipleValuesPerKey) + case None => + StateStore.get( + storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, + uniqueId.map(_.apply(partition.index).head), + stateSchemaBroadcast, + useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, + useMultipleValuesPerKey) + } if (storeConf.unloadOnCommit) { ctxt.addTaskCompletionListener[Unit](_ => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index a82eff4812953..b662959695e4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType +import org.apache.spark.util.TaskFailureListener package object state { @@ -109,8 +110,15 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeReadFn) val wrappedF = (store: ReadStateStore, iter: Iterator[T]) => { // Clean up the state store. 
- TaskContext.get().addTaskCompletionListener[Unit](_ => { - store.abort() + val ctxt = TaskContext.get() + ctxt.addTaskCompletionListener[Unit](_ => { + if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.release() + StateStoreThreadLocalTracker.clearStore() + }) + ctxt.addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.abort() + StateStoreThreadLocalTracker.clearStore() }) cleanedF(store, iter) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 9a04a0c759ac4..7af06e3ab9d46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -55,6 +55,8 @@ class MemoryStateStore extends StateStore() { override def abort(): Unit = {} + override def release(): Unit = {} + override def id: StateStoreId = null override def version: Long = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 22150ffde5db6..3f207cb61df39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -137,6 +137,8 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta ret } override def hasCommitted: Boolean = innerStore.hasCommitted + + override def release(): Unit = {} } class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 08648148b4af4..1034a5edbdfc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1351,7 +1351,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1) saveStore.commit() - restoreStore.abort() + // We don't need to call restoreStore.release() since the Write Store has been committed } // check that state is correct for next batch @@ -1669,6 +1669,107 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } } + test("two concurrent StateStores - one for read-only and one for read-write with release()") { + val dir = Utils.createTempDir().getAbsolutePath + val storeId = StateStoreId(dir, 0L, 1) + val storeProviderId = StateStoreProviderId(storeId, UUID.randomUUID) + val key1 = "a" + val key2 = 0 + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + quietly { + withSpark(SparkContext.getOrCreate( + new SparkConf().setMaster("local").setAppName("test"))) { sc => + withCoordinatorRef(sc) { _ => + // Prime state + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), 
+ 0, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + put(store, key1, key2, 1) + store.commit() + + // Get two state stores - one read-only and one read-write + val restoreStore = StateStore.getReadOnly( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + val saveStore = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + // Update the write store based on data from read store + put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1) + saveStore.commit() + + // Check that state is correct for next batch + val finalStore = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 2, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + assert(get(finalStore, key1, key2) === Some(2)) + } + } + } + } + + test("getWriteStore correctly uses existing read store") { + val dir = Utils.createTempDir().getAbsolutePath + val storeId = StateStoreId(dir, 0L, 1) + val storeProviderId = StateStoreProviderId(storeId, UUID.randomUUID) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + quietly { + withSpark(SparkContext.getOrCreate( + new SparkConf().setMaster("local").setAppName("test"))) { sc => + withCoordinatorRef(sc) { _ => + // Prime state + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 0, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + put(store, "a", 0, 1) + store.commit() + + // Get a read-only store + val readStore = StateStore.getReadOnly( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + // Convert it to a write store using the new getWriteStore method + val writeStore = StateStore.getWriteStore( + readStore, + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 1, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + // The write store should still have access to the data + assert(get(writeStore, "a", 0) === Some(1)) + + // Update and commit with the write store + put(writeStore, "a", 0, 2) + writeStore.commit() + + // Check that the state was updated correctly + val finalStore = StateStore.get( + storeProviderId, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 2, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + assert(get(finalStore, "a", 0) === Some(2)) + } + } + } + } + test("SPARK-42572: StateStoreProvider.validateStateRowFormat shouldn't check" + " value row format when SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED is false") { // By default, when there is an invalid pair of value row and value schema, it should throw From 63573db6575d6cb71c4265ca1917641a1793455d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 30 Apr 2025 11:46:21 -0700 Subject: [PATCH 02/19] unnecessary change --- .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 52a420482b3ac..c293a1290ac93 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.streaming.state - import java.util.UUID import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -144,7 +142,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( extraOptions: Map[String, String] = Map.empty, useMultipleValuesPerKey: Boolean = false) extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId, - sessionState, storeCoordinator, extraOptions) with Logging { + sessionState, storeCoordinator, extraOptions) { override protected def getPartitions: Array[Partition] = dataRDD.partitions From f7d0e7008639c5d629ddd39cd8545517e49b6690 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 30 Apr 2025 15:45:43 -0700 Subject: [PATCH 03/19] removing clera --- .../streaming/state/RocksDBStateStoreProvider.scala | 1 + .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 5 ++++- .../apache/spark/sql/execution/streaming/state/package.scala | 2 -- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 3bbcece9a67fa..f588d99bb868c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -488,6 +488,7 @@ private[sql] class RocksDBStateStoreProvider existingStore match { case Some(stateStore: RocksDBStateStore) => // Reuse existing store for getWriteStore case + StateStoreThreadLocalTracker.setUsedForWriteStore(true) stateStore case Some(_) => throw new IllegalArgumentException("Existing store must be a RocksDBStateStore") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index c293a1290ac93..2e7e012866a61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -37,10 +37,13 @@ object StateStoreThreadLocalTracker { def setStore(store: ReadStateStore): Unit = readStore.set(store) def getStore: Option[ReadStateStore] = { - usedForWriteStore.set(true) Option(readStore.get()) } + def setUsedForWriteStore(used: Boolean): Unit = { + usedForWriteStore.set(used) + } + def isUsedForWriteStore: Boolean = usedForWriteStore.get() def clearStore(): Unit = readStore.remove() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b662959695e4d..dd45d99dd41df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -113,12 +113,10 @@ package object state { val ctxt = TaskContext.get() 
ctxt.addTaskCompletionListener[Unit](_ => { if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.release() - StateStoreThreadLocalTracker.clearStore() }) ctxt.addTaskFailureListener(new TaskFailureListener { override def onTaskFailure(context: TaskContext, error: Throwable): Unit = if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.abort() - StateStoreThreadLocalTracker.clearStore() }) cleanedF(store, iter) } From 560c5c7c88b8546a4c4db867c2e91db97e423c78 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 30 Apr 2025 16:01:45 -0700 Subject: [PATCH 04/19] Added cleanup --- .../execution/streaming/state/package.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index dd45d99dd41df..e54711e384fdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -68,8 +68,15 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = (store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener[Unit](_ => { + val ctxt = TaskContext.get() + ctxt.addTaskCompletionListener[Unit](_ => { if (!store.hasCommitted) store.abort() + StateStoreThreadLocalTracker.clearStore() + }) + ctxt.addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + store.abort() + StateStoreThreadLocalTracker.clearStore() }) cleanedF(store, iter) } @@ -112,11 +119,17 @@ package object state { // Clean up the state store. 
val ctxt = TaskContext.get() ctxt.addTaskCompletionListener[Unit](_ => { - if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.release() + if (!StateStoreThreadLocalTracker.isUsedForWriteStore) { + store.release() + StateStoreThreadLocalTracker.clearStore() + } }) ctxt.addTaskFailureListener(new TaskFailureListener { override def onTaskFailure(context: TaskContext, error: Throwable): Unit = - if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.abort() + if (!StateStoreThreadLocalTracker.isUsedForWriteStore) { + store.abort() + StateStoreThreadLocalTracker.clearStore() + } }) cleanedF(store, iter) } From 81c0eedfcb084733307603cc1078942f44df6477 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 30 Apr 2025 17:11:18 -0700 Subject: [PATCH 05/19] adding assertion --- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index f588d99bb868c..70499f069f569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -523,6 +523,8 @@ private[sql] class RocksDBStateStoreProvider assert(version == readStore.version, s"Can only upgrade readStore to writeStore with the same version," + s" readStoreVersion: ${readStore.version}, writeStoreVersion: ${version}") + assert(this.stateStoreId == readStore.id, "Can only upgrade readStore to writeStore with" + + " the same stateStoreId") loadStateStore(version, uniqueId, readOnly = false, existingStore = Some(readStore)) } From 0465454acb4fc315a2e2b5e7c3e3063fdc9918a0 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 1 May 2025 11:25:16 -0700 Subject: [PATCH 06/19] setting usedForWriteStore in StateStoreRDD --- .../streaming/state/RocksDBStateStoreProvider.scala | 2 -- .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 6 +++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 70499f069f569..51afee6e51364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -487,8 +487,6 @@ private[sql] class RocksDBStateStoreProvider // Return appropriate store instance existingStore match { case Some(stateStore: RocksDBStateStore) => - // Reuse existing store for getWriteStore case - StateStoreThreadLocalTracker.setUsedForWriteStore(true) stateStore case Some(_) => throw new IllegalArgumentException("Existing store must be a RocksDBStateStore") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 2e7e012866a61..2c91234f4ea04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -155,12 +155,16 @@ class StateStoreRDD[T: 
ClassTag, U: ClassTag]( val inputIter = dataRDD.iterator(partition, ctxt) val store = StateStoreThreadLocalTracker.getStore match { case Some(readStateStore: ReadStateStore) => - StateStore.getWriteStore(readStateStore, storeProviderId, + val writeStore = StateStore.getWriteStore(readStateStore, storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, uniqueId.map(_.apply(partition.index).head), stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, useMultipleValuesPerKey) + if (writeStore.equals(readStateStore)) { + StateStoreThreadLocalTracker.setUsedForWriteStore(true) + } + writeStore case None => StateStore.get( storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, From 93f014a43ab0afe5ccdd3e3b20d04221d4ef4755 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 1 May 2025 12:10:35 -0700 Subject: [PATCH 07/19] case class --- .../streaming/state/StateStoreRDD.scala | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 2c91234f4ea04..82a521a3ee03b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -32,21 +32,36 @@ import org.apache.spark.util.SerializableConfiguration * This allows a ReadStateStore to be reused by a subsequent StateStore operation. */ object StateStoreThreadLocalTracker { - private val readStore: ThreadLocal[ReadStateStore] = new ThreadLocal[ReadStateStore] - private val usedForWriteStore: ThreadLocal[Boolean] = new ThreadLocal[Boolean] - def setStore(store: ReadStateStore): Unit = readStore.set(store) + /** Case class to hold both the store and its usage state */ + case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false) + + private val storeInfo: ThreadLocal[StoreInfo] = new ThreadLocal[StoreInfo] + + def setStore(store: ReadStateStore): Unit = { + Option(storeInfo.get()) match { + case Some(info) => storeInfo.set(info.copy(store = store)) + case None => storeInfo.set(StoreInfo(store)) + } + } def getStore: Option[ReadStateStore] = { - Option(readStore.get()) + Option(storeInfo.get()).map(_.store) } def setUsedForWriteStore(used: Boolean): Unit = { - usedForWriteStore.set(used) + Option(storeInfo.get()) match { + case Some(info) => storeInfo.set(info.copy(usedForWriteStore = used)) + case None => // If there's no store set, we don't need to track usage + } } - def isUsedForWriteStore: Boolean = usedForWriteStore.get() + def isUsedForWriteStore: Boolean = { + Option(storeInfo.get()).exists(_.usedForWriteStore) + } - def clearStore(): Unit = readStore.remove() + def clearStore(): Unit = { + storeInfo.remove() + } } abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag]( From 226f99b2e0539f93cefa27359df54e0b956da2b6 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 1 May 2025 16:31:17 -0700 Subject: [PATCH 08/19] changes --- .../state/RocksDBStateStoreProvider.scala | 14 +- .../streaming/state/StateStore.scala | 3 + .../streaming/state/StateStoreRDD.scala | 6 +- .../execution/streaming/state/package.scala | 3 +- .../streaming/state/StateStoreRDDSuite.scala | 137 ++++++++++++++++++ 5 files changed, 155 insertions(+), 8 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 51afee6e51364..74065e1598291 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider with SupportsFineGrainedReplay { import RocksDBStateStoreProvider._ - class RocksDBStateStore(lastVersion: Long) extends StateStore { + class RocksDBStateStore(lastVersion: Long) extends StateStore with UpgradeableReadStore { /** Trait and classes representing the internal state of the store */ trait STATE case object UPDATING extends STATE @@ -486,10 +486,13 @@ private[sql] class RocksDBStateStoreProvider // Return appropriate store instance existingStore match { - case Some(stateStore: RocksDBStateStore) => - stateStore - case Some(_) => - throw new IllegalArgumentException("Existing store must be a RocksDBStateStore") + // We need to match like this as opposed to case Some(ss: RocksDBStateStore) + // because of how the tests create the class in StateStoreRDDSuite + case Some(stateStore: ReadStateStore) if stateStore.isInstanceOf[RocksDBStateStore] => + stateStore.asInstanceOf[StateStore] + case Some(other) => + throw new IllegalArgumentException(s"Existing store must be a RocksDBStateStore," + + s" store is actually ${other.getClass.getSimpleName}") case None => // Create new store instance for getStore/getReadStore cases new RocksDBStateStore(version) @@ -510,6 +513,7 @@ private[sql] class RocksDBStateStoreProvider case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e) } } + override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { loadStateStore(version, uniqueId, readOnly = false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 08591179db985..2c5039a769078 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -863,6 +863,9 @@ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { } } +// Trait on whether we can upgrade a given ReadStore to a WriteStore for Streaming +// Aggregations +trait UpgradeableReadStore /** * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 82a521a3ee03b..c3dfc5f1af9ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -176,8 +176,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, useMultipleValuesPerKey) - if (writeStore.equals(readStateStore)) { - StateStoreThreadLocalTracker.setUsedForWriteStore(true) + readStateStore match { + case _: UpgradeableReadStore => + 
StateStoreThreadLocalTracker.setUsedForWriteStore(true) + case _ => } writeStore case None => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index e54711e384fdc..d00a8e2d8ff77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -74,9 +74,10 @@ package object state { StateStoreThreadLocalTracker.clearStore() }) ctxt.addTaskFailureListener(new TaskFailureListener { - override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = { store.abort() StateStoreThreadLocalTracker.clearStore() + } }) cleanedF(store, iter) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 1f9fd17eda600..2f3b5f7e4f4a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -228,6 +228,143 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { } } + test("SPARK-51823: ReadStateStore reuse and upgrade to WriteStore") { + withSparkSession(SparkSession.builder() + .config(sparkConf) + .config(SQLConf.STATE_STORE_PROVIDER_CLASS.key, classOf[RocksDBStateStoreProvider].getName) + .config(SQLConf.SHUFFLE_PARTITIONS.key, "1") + .getOrCreate()) { spark => + implicit val sqlContext = spark.sqlContext + val path = Utils.createDirectory(tempDir, Random.nextFloat().toString).toString + + // Create initial data in the state store (version 0) + val initialData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0))) + val setupRDD = initialData.mapPartitionsWithStateStore( + sqlContext, + operatorStateInfo(path, version = 0), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema) + ) { (store, iter) => + // Set initial values: a->1, b->2 + iter.foreach { case (s, i) => + val key = dataToKeyRow(s, i) + store.put(key, dataToValueRow(if (s == "a") 1 else 2)) + } + store.commit() + Iterator.empty + } + setupRDD.count() // Force evaluation + + // Create input data for our chained operations + val inputData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) + + // Chain operations: first read with ReadStateStore, then write with StateStore + val chainedResults = inputData + // First pass: read-only state store access + .mapPartitionsWithReadStateStore( + operatorStateInfo(path, version = 1), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + spark.sessionState, + Some(castToImpl(spark).streams.stateStoreCoordinator) + ) { (readStore, iter) => + // Read values and store them for later verification + val readValues = iter.map { case (s, i) => + val key = dataToKeyRow(s, i) + val value = Option(readStore.get(key)).map(valueRowToData) + ((s, i), value) + }.toSeq + + // Also capture all state store entries + val allValues = readStore.iterator().map(rowPairToDataPair).toSeq + + // Pass along both to the next stage - this keeps them in the same partition + // Also pass through the original items + Iterator((readValues, allValues, iter.toSeq)) + } + // Second pass: use StateStore to write updates (should 
reuse the read store) + .mapPartitionsWithStateStore( + operatorStateInfo(path, version = 1), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + spark.sessionState, + Some(castToImpl(spark).streams.stateStoreCoordinator) + ) { (writeStore, writeIter) => + if (writeIter.hasNext) { + val (readValues, allStoreValues, originalItems) = writeIter.next() + val usedForWriteStore = StateStoreThreadLocalTracker.isUsedForWriteStore + // Get all existing values from the write store to verify reuse + val storeValues = writeStore.iterator().map(rowPairToDataPair).toSeq + + // Update values for a and c from the original items we passed through + originalItems.filter(p => p._1 == "a" || p._1 == "c").foreach { case (s, i) => + val key = dataToKeyRow(s, i) + val oldValue = Option(writeStore.get(key)).map(valueRowToData).getOrElse(0) + val newValue = oldValue + 10 // Add 10 to existing values + writeStore.put(key, dataToValueRow(newValue)) + } + writeStore.commit() + + // Return all collected information for verification + Iterator((readValues, allStoreValues, storeValues, + usedForWriteStore)) + } else { + Iterator.empty + } + } + + // Collect the results + val (readValues, initialStoreState, + writeStoreValues, storeWasReused) = chainedResults.collect().head + + // Verify read results + assert(readValues.toSet === Set( + ("a", 0) -> Some(1), + ("b", 0) -> Some(2), + ("c", 0) -> None + )) + + // Verify store state matches expected values + assert(initialStoreState.toSet === Set((("a", 0), 1), (("b", 0), 2))) + + // Verify the existing values in the write store (should be the same as initial state) + assert(writeStoreValues.toSet === Set((("a", 0), 1), (("b", 0), 2))) + + // Verify the thread local flag indicates reuse + assert(storeWasReused, + "StateStoreThreadLocalTracker should indicate the read store was reused") + + // Create another ReadStateStoreRDD to verify the final state (version 2) + val verifyData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) + val verifyRDD = verifyData.mapPartitionsWithReadStateStore( + operatorStateInfo(path, version = 2), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + spark.sessionState, + Some(castToImpl(spark).streams.stateStoreCoordinator) + ) { (store, iter) => + iter.map { case (s, i) => + val key = dataToKeyRow(s, i) + val value = Option(store.get(key)).map(valueRowToData) + ((s, i), value) + } + } + + // Verify the final state has the expected values + // a: 1 + 10 = 11, b: 2 (unchanged), c: 0 + 10 = 10 + val finalResults = verifyRDD.collect().toSet + assert(finalResults === Set( + ("a", 0) -> Some(11), + ("b", 0) -> Some(2), + ("c", 0) -> Some(10) + )) + } + } + test("SPARK-51823: unload on commit") { withSparkSession( SparkSession.builder() From a892142880daed31ad75d6c6f40b003b0afd2300 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 1 May 2025 16:42:57 -0700 Subject: [PATCH 09/19] test --- .../streaming/state/StateStoreRDDSuite.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 2f3b5f7e4f4a1..6095b26ecd6fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -228,7 +228,7 @@ class 
StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { } } - test("SPARK-51823: ReadStateStore reuse and upgrade to WriteStore") { + test("SPARK-51955: ReadStateStore reuse and upgrade to WriteStore") { withSparkSession(SparkSession.builder() .config(sparkConf) .config(SQLConf.STATE_STORE_PROVIDER_CLASS.key, classOf[RocksDBStateStoreProvider].getName) @@ -271,18 +271,19 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { Some(castToImpl(spark).streams.stateStoreCoordinator) ) { (readStore, iter) => // Read values and store them for later verification - val readValues = iter.map { case (s, i) => + val inputItems = iter.toSeq // Materialize the input data + + val readValues = inputItems.map { case (s, i) => val key = dataToKeyRow(s, i) val value = Option(readStore.get(key)).map(valueRowToData) ((s, i), value) - }.toSeq + } // Also capture all state store entries val allValues = readStore.iterator().map(rowPairToDataPair).toSeq - // Pass along both to the next stage - this keeps them in the same partition - // Also pass through the original items - Iterator((readValues, allValues, iter.toSeq)) + // Return everything as a single tuple - only create one element in the iterator + Iterator((readValues, allValues, inputItems)) } // Second pass: use StateStore to write updates (should reuse the read store) .mapPartitionsWithStateStore( @@ -299,7 +300,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { // Get all existing values from the write store to verify reuse val storeValues = writeStore.iterator().map(rowPairToDataPair).toSeq - // Update values for a and c from the original items we passed through + // Update values for a and c from the original items originalItems.filter(p => p._1 == "a" || p._1 == "c").foreach { case (s, i) => val key = dataToKeyRow(s, i) val oldValue = Option(writeStore.get(key)).map(valueRowToData).getOrElse(0) @@ -309,16 +310,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { writeStore.commit() // Return all collected information for verification - Iterator((readValues, allStoreValues, storeValues, - usedForWriteStore)) + Iterator((readValues, allStoreValues, storeValues, usedForWriteStore)) } else { Iterator.empty } } // Collect the results - val (readValues, initialStoreState, - writeStoreValues, storeWasReused) = chainedResults.collect().head + val (readValues, initialStoreState, writeStoreValues, + storeWasReused) = chainedResults.collect().head // Verify read results assert(readValues.toSet === Set( From 768346a875d1b55e76bee800ce88b3d8f77259db Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 1 May 2025 17:14:35 -0700 Subject: [PATCH 10/19] upgradeable read store --- .../streaming/state/RocksDBStateStoreProvider.scala | 2 +- .../spark/sql/execution/streaming/state/StateStore.scala | 4 ---- .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 6 ++---- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 74065e1598291..8f4a7041e6581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider with SupportsFineGrainedReplay { import 
RocksDBStateStoreProvider._ - class RocksDBStateStore(lastVersion: Long) extends StateStore with UpgradeableReadStore { + class RocksDBStateStore(lastVersion: Long) extends StateStore { /** Trait and classes representing the internal state of the store */ trait STATE case object UPDATING extends STATE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 2c5039a769078..ea1085749e26b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -863,10 +863,6 @@ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { } } -// Trait on whether we can upgrade a given ReadStore to a WriteStore for Streaming -// Aggregations -trait UpgradeableReadStore - /** * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index c3dfc5f1af9ed..82a521a3ee03b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -176,10 +176,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, useMultipleValuesPerKey) - readStateStore match { - case _: UpgradeableReadStore => - StateStoreThreadLocalTracker.setUsedForWriteStore(true) - case _ => + if (writeStore.equals(readStateStore)) { + StateStoreThreadLocalTracker.setUsedForWriteStore(true) } writeStore case None => From 55664c4373f69387ed2f55aed03d8999a71ee3dc Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 2 May 2025 11:45:39 -0700 Subject: [PATCH 11/19] adding state --- .../streaming/state/HDFSBackedStateStoreProvider.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 54f991ab1db08..20189c8007ef7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -114,6 +114,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with case object UPDATING extends STATE case object COMMITTED extends STATE case object ABORTED extends STATE + case object RELEASED extends STATE private val newVersion = version + 1 @volatile private var state: STATE = UPDATING @@ -196,7 +197,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.STATE_STORE_PROVIDER, this)}") } - override def release(): Unit = {} + override def release(): Unit = { + state = RELEASED + } /** * Get an iterator of all the store data. 
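Taken together, the pieces introduced up to this point (StateStoreThreadLocalTracker, StateStore.getWriteStore, and StateStoreProvider.upgradeReadStoreToWriteStore) chain the read pass and the write pass of a stateful operator through a thread-local handle instead of loading the same RocksDB instance twice. The sketch below only illustrates that flow; it is not code from the patch, the helper names readPass/writePass are hypothetical, and the schema, checkpoint-id, and partition arguments are elided.

// Illustrative sketch only (not part of the patch): how the read pass and the
// write pass are intended to share one store via StateStoreThreadLocalTracker.
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, StateStoreProvider, StateStoreThreadLocalTracker}

object ReadThenWriteSketch {
  // Read pass: register the read-only store for the current task thread
  // (ReadStateStoreRDD does this after obtaining the store).
  def readPass[T](readStore: ReadStateStore, iter: Iterator[T]): Iterator[T] = {
    StateStoreThreadLocalTracker.setStore(readStore)
    iter // read-only gets against readStore happen while this is consumed
  }

  // Write pass: upgrade the registered read store instead of loading the
  // provider's store a second time.
  def writePass[T](provider: StateStoreProvider, version: Long, iter: Iterator[T]): Iterator[T] = {
    val store: StateStore = StateStoreThreadLocalTracker.getStore match {
      case Some(readStore) =>
        val writeStore = provider.upgradeReadStoreToWriteStore(readStore, version)
        // Mark the store as reused so the read pass's completion listener
        // skips release() and the write path owns the lifecycle.
        StateStoreThreadLocalTracker.setUsedForWriteStore(true)
        writeStore
      case None =>
        provider.getStore(version)
    }
    // puts against `store` and store.commit() happen while this is consumed
    iter
  }
}

In the RDDs themselves this same wiring is done by ReadStateStoreRDD.compute (which registers the store) and StateStoreRDD.compute (which checks the tracker before falling back to StateStore.get), as shown in PATCH 01 above.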
From 8d28ea21b715cbcc7988f7711281be40f7dfb137 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 5 May 2025 10:52:51 -0700 Subject: [PATCH 12/19] moving clearStore to only mapPartitionsWithReadStateStore --- .../execution/streaming/state/StateStoreRDD.scala | 6 ++++-- .../sql/execution/streaming/state/package.scala | 12 ++++++------ .../sql/execution/streaming/state/RocksDBSuite.scala | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 82a521a3ee03b..47950d964f730 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -22,6 +22,7 @@ import java.util.UUID import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -31,7 +32,7 @@ import org.apache.spark.util.SerializableConfiguration * Thread local storage for sharing StateStore instances between RDDs. * This allows a ReadStateStore to be reused by a subsequent StateStore operation. */ -object StateStoreThreadLocalTracker { +object StateStoreThreadLocalTracker extends Logging { /** Case class to hold both the store and its usage state */ case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false) @@ -50,7 +51,8 @@ object StateStoreThreadLocalTracker { def setUsedForWriteStore(used: Boolean): Unit = { Option(storeInfo.get()) match { - case Some(info) => storeInfo.set(info.copy(usedForWriteStore = used)) + case Some(info) => + storeInfo.set(info.copy(usedForWriteStore = used)) case None => // If there's no store set, we don't need to track usage } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index d00a8e2d8ff77..bc1da7315d2cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.classic.ClassicConversions.castToImpl @@ -27,7 +28,7 @@ import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.TaskFailureListener -package object state { +package object state extends Logging { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { @@ -71,12 +72,10 @@ package object state { val ctxt = TaskContext.get() ctxt.addTaskCompletionListener[Unit](_ => { if (!store.hasCommitted) store.abort() - StateStoreThreadLocalTracker.clearStore() }) ctxt.addTaskFailureListener(new TaskFailureListener { override def onTaskFailure(context: TaskContext, error: Throwable): Unit = { store.abort() - StateStoreThreadLocalTracker.clearStore() } }) cleanedF(store, iter) @@ -122,15 +121,16 @@ package object state { ctxt.addTaskCompletionListener[Unit](_ => { if 
(!StateStoreThreadLocalTracker.isUsedForWriteStore) { store.release() - StateStoreThreadLocalTracker.clearStore() } + StateStoreThreadLocalTracker.clearStore() }) ctxt.addTaskFailureListener(new TaskFailureListener { - override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = { if (!StateStoreThreadLocalTracker.isUsedForWriteStore) { store.abort() - StateStoreThreadLocalTracker.clearStore() } + StateStoreThreadLocalTracker.clearStore() + } }) cleanedF(store, iter) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index bd9c838eaa6a8..21984a969797b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -3479,7 +3479,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession } } - test("RocksDB task completion listener correctly releases for failed task") { + ignore("RocksDB task completion listener correctly releases for failed task") { // This test verifies that a thread that locks the DB and then fails // can rely on the completion listener to release the lock. From 807a3c174d92bbc0a907183ac2e9c594e0d69273 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 5 May 2025 11:38:49 -0700 Subject: [PATCH 13/19] ignore -> test --- .../spark/sql/execution/streaming/state/RocksDBSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 21984a969797b..bd9c838eaa6a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -3479,7 +3479,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession } } - ignore("RocksDB task completion listener correctly releases for failed task") { + test("RocksDB task completion listener correctly releases for failed task") { // This test verifies that a thread that locks the DB and then fails // can rely on the completion listener to release the lock. 
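The task-listener wiring in package.scala above separates the success path (release) from the failure path (abort) for a read-only store, and skips both when that store has been upgraded into the write path. A minimal sketch of that contract, assuming a hypothetical withReadStore helper in place of the real listener registration, is:

// Sketch of the release()/abort() contract the listeners rely on;
// `withReadStore` is a hypothetical helper, not code from the patch.
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreThreadLocalTracker}

object ReadStoreLifecycleSketch {
  def withReadStore[T](store: ReadStateStore)(body: ReadStateStore => T): T = {
    try {
      val result = body(store)
      // Success path: release resources without rolling back what was read,
      // unless the same store is now owned by the write path.
      if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.release()
      result
    } catch {
      case t: Throwable =>
        // Error path: abort, mirroring the task-failure handling.
        if (!StateStoreThreadLocalTracker.isUsedForWriteStore) store.abort()
        throw t
    } finally {
      // Always drop the thread-local so the next task starts clean.
      StateStoreThreadLocalTracker.clearStore()
    }
  }
}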
From 7530d541254cb33cdbf2f240a4f6cddcacc6b3ea Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 5 May 2025 11:40:04 -0700 Subject: [PATCH 14/19] logging --- .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 3 +-- .../apache/spark/sql/execution/streaming/state/package.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 47950d964f730..d78c5229e0ac2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -22,7 +22,6 @@ import java.util.UUID import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -32,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration * Thread local storage for sharing StateStore instances between RDDs. * This allows a ReadStateStore to be reused by a subsequent StateStore operation. */ -object StateStoreThreadLocalTracker extends Logging { +object StateStoreThreadLocalTracker { /** Case class to hold both the store and its usage state */ case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index bc1da7315d2cc..b0a94052c9900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.classic.ClassicConversions.castToImpl @@ -28,7 +27,7 @@ import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.TaskFailureListener -package object state extends Logging { +package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { From 95db8398b3775f2b0df9ee599f096f9b531a3e3f Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 6 May 2025 13:01:32 -0700 Subject: [PATCH 15/19] removing failure listener --- .../spark/sql/execution/streaming/state/package.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b0a94052c9900..bcb8190f7531a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -123,14 +123,6 @@ package object state { } StateStoreThreadLocalTracker.clearStore() }) - ctxt.addTaskFailureListener(new TaskFailureListener { - override def onTaskFailure(context: TaskContext, error: Throwable): Unit = { - if (!StateStoreThreadLocalTracker.isUsedForWriteStore) { - store.abort() - } - 
StateStoreThreadLocalTracker.clearStore() - } - }) cleanedF(store, iter) } new ReadStateStoreRDD( From 8f15229820405e2a5ce9fb860f1689f9e312caf8 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 8 May 2025 09:29:03 -0700 Subject: [PATCH 16/19] initial feedback --- .../state/RocksDBStateStoreProvider.scala | 27 ++++++++++++------- .../streaming/state/StateStore.scala | 27 +++++++++++++++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 8f4a7041e6581..ad906ed896c2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -477,6 +477,15 @@ private[sql] class RocksDBStateStoreProvider if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } + + // Early validation of the existing store type before loading RocksDB + existingStore.foreach { store => + if (!store.isInstanceOf[RocksDBStateStore]) { + throw new IllegalArgumentException( + s"Existing store must be a RocksDBStateStore, but got ${store.getClass.getSimpleName}") + } + } + try { // Load RocksDB store rocksDB.load( @@ -484,22 +493,20 @@ private[sql] class RocksDBStateStoreProvider stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, readOnly = readOnly) - // Return appropriate store instance + // Create or reuse store instance existingStore match { - // We need to match like this as opposed to case Some(ss: RocksDBStateStore) - // because of how the tests create the class in StateStoreRDDSuite - case Some(stateStore: ReadStateStore) if stateStore.isInstanceOf[RocksDBStateStore] => - stateStore.asInstanceOf[StateStore] - case Some(other) => - throw new IllegalArgumentException(s"Existing store must be a RocksDBStateStore," + - s" store is actually ${other.getClass.getSimpleName}") + case Some(store: RocksDBStateStore) => + // Mark store as being used for write operations + StateStoreThreadLocalTracker.setUsedForWriteStore(true) + store case None => - // Create new store instance for getStore/getReadStore cases + // Create new store instance new RocksDBStateStore(version) + // No need for error case here since we validated earlier } } catch { case e: Throwable => - throw e + throw QueryExecutionErrors.cannotLoadStore(e) } } catch { case e: SparkException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ea1085749e26b..3097b07756bcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -989,6 +989,33 @@ object StateStore extends Logging { storeProvider.getReadStore(version, stateStoreCkptId) } + /** + * Converts an existing read-only state store to a writable state store. + * + * This method provides an optimization for stateful operations that need to both read and update + * state within the same task. 
Instead of opening separate read and write instances (which may + * cause resource contention or duplication), this method reuses the already loaded read store + * and transforms it into a writable store. + * + * The optimization is particularly valuable for state stores with expensive initialization costs + * or limited concurrency capabilities (like RocksDB). It eliminates redundant loading of the same + * state data and reduces resource usage. + * + * @param readStore The existing read-only state store to convert to a writable store + * @param storeProviderId Unique identifier for the state store provider + * @param keySchema Schema of the state store keys + * @param valueSchema Schema of the state store values + * @param keyStateEncoderSpec Specification for encoding the state keys + * @param version The version of the state store (must match the read store's version) + * @param stateStoreCkptId Optional checkpoint identifier for the state store + * @param stateSchemaBroadcast Optional broadcast of the state schema + * @param useColumnFamilies Whether to use column families in the state store + * @param storeConf Configuration for the state store + * @param hadoopConf Hadoop configuration + * @param useMultipleValuesPerKey Whether the store supports multiple values per key + * @return A writable StateStore instance that can be used to update and commit changes + * @throws SparkException If the store cannot be loaded or if there's insufficient memory + */ def getWriteStore( readStore: ReadStateStore, storeProviderId: StateStoreProviderId, From 32e55456067497c402388f3556053b47965b3474 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Thu, 8 May 2025 13:22:53 -0700 Subject: [PATCH 17/19] moving listeners from package.scala to StateStores --- .../v2/state/StatePartitionReader.scala | 1 - .../state/HDFSBackedStateStoreProvider.scala | 14 +++++- .../state/RocksDBStateStoreProvider.scala | 45 +++++++++++++++---- .../streaming/state/StateStore.scala | 5 ++- .../streaming/state/StateStoreRDD.scala | 28 ++---------- .../execution/streaming/state/package.scala | 13 ------ .../streaming/state/StateStoreRDDSuite.scala | 16 ++++--- 7 files changed, 65 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 4aa95ad42ec7f..101eb1e51d83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -204,7 +204,6 @@ class StatePartitionReader( } override def close(): Unit = { - store.abort() super.close() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 20189c8007ef7..6008cb9662a45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -31,7 +31,7 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, 
TaskContext} import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -116,6 +116,15 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with case object ABORTED extends STATE case object RELEASED extends STATE + + Option(TaskContext.get()).foreach { ctxt => + ctxt.addTaskCompletionListener[Unit](ctx => { + if (state == UPDATING) { + abort() + } + }) + } + private val newVersion = version + 1 @volatile private var state: STATE = UPDATING private val finalDeltaFile: Path = deltaFile(newVersion) @@ -960,7 +969,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with * @param endVersion checkpoint version to end with * @return [[HDFSBackedStateStore]] */ - override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = { + override def replayStateFromSnapshot( + snapshotVersion: Long, endVersion: Long, readOnly: Boolean): StateStore = { val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion) logInfo(log"Retrieved snapshot at version " + log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index ad906ed896c2e..ad31b430195db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec @@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider with SupportsFineGrainedReplay { import RocksDBStateStoreProvider._ - class RocksDBStateStore(lastVersion: Long) extends StateStore { + class RocksDBStateStore(lastVersion: Long, var readOnly: Boolean) extends StateStore { /** Trait and classes representing the internal state of the store */ trait STATE case object UPDATING extends STATE @@ -58,6 +58,30 @@ private[sql] class RocksDBStateStoreProvider override def version: Long = lastVersion + Option(TaskContext.get()).foreach { ctxt => + ctxt.addTaskCompletionListener[Unit]( ctx => { + try { + if (state == UPDATING) { + if (readOnly) { + release() // Only release, do not throw an error because we rely on + // CompletionListener to release for read-only store in + // mapPartitionsWithReadStateStore. + } else { + abort() // Abort since this is an error if stateful task completes + } + } + } catch { + case NonFatal(e) => + logWarning("Failed to abort state store", e) + } + }) + + ctxt.addTaskFailureListener( (_, _) => { + abort() // Either the store is already aborted (this is a no-op) or + // we need to abort it. 
+ }) + } + override def createColFamilyIfAbsent( colFamilyName: String, keySchema: StructType, @@ -368,6 +392,7 @@ private[sql] class RocksDBStateStoreProvider } override def release(): Unit = { + assert(readOnly, "Release can only be called on a read-only store") if (state != RELEASED) { logInfo(log"Releasing ${MDC(VERSION_NUM, version + 1)} " + log"for ${MDC(STATE_STORE_ID, id)}") @@ -495,14 +520,15 @@ private[sql] class RocksDBStateStoreProvider // Create or reuse store instance existingStore match { - case Some(store: RocksDBStateStore) => + case Some(store: ReadStateStore) if store.isInstanceOf[RocksDBStateStore] => // Mark store as being used for write operations - StateStoreThreadLocalTracker.setUsedForWriteStore(true) - store + val rocksDBStateStore = store.asInstanceOf[RocksDBStateStore] + rocksDBStateStore.readOnly = readOnly + rocksDBStateStore.asInstanceOf[StateStore] case None => // Create new store instance - new RocksDBStateStore(version) - // No need for error case here since we validated earlier + new RocksDBStateStore(version, readOnly) + case _ => null // No need for error case here since we validated earlier } } catch { case e: Throwable => @@ -626,7 +652,8 @@ private[sql] class RocksDBStateStoreProvider * @param endVersion checkpoint version to end with * @return [[StateStore]] */ - override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = { + override def replayStateFromSnapshot( + snapshotVersion: Long, endVersion: Long, readOnly: Boolean): StateStore = { try { if (snapshotVersion < 1) { throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion) @@ -635,7 +662,7 @@ private[sql] class RocksDBStateStoreProvider throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion) } rocksDB.loadFromSnapshot(snapshotVersion, endVersion) - new RocksDBStateStore(endVersion) + new RocksDBStateStore(endVersion, readOnly) } catch { case e: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 3097b07756bcb..d70e75482e6e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -759,7 +759,8 @@ trait SupportsFineGrainedReplay { * @param snapshotVersion checkpoint version of the snapshot to start with * @param endVersion checkpoint version to end with */ - def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore + def replayStateFromSnapshot( + snapshotVersion: Long, endVersion: Long, readOnly: Boolean = false): StateStore /** * Return an instance of [[ReadStateStore]] representing state data of the given version. 
@@ -772,7 +773,7 @@ trait SupportsFineGrainedReplay { * @param endVersion checkpoint version to end with */ def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore = { - new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion)) + new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion, readOnly = true)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index d78c5229e0ac2..b01aed8abcd66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -33,31 +33,15 @@ import org.apache.spark.util.SerializableConfiguration */ object StateStoreThreadLocalTracker { /** Case class to hold both the store and its usage state */ - case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false) - private val storeInfo: ThreadLocal[StoreInfo] = new ThreadLocal[StoreInfo] + private val storeInfo: ThreadLocal[ReadStateStore] = new ThreadLocal[ReadStateStore] def setStore(store: ReadStateStore): Unit = { - Option(storeInfo.get()) match { - case Some(info) => storeInfo.set(info.copy(store = store)) - case None => storeInfo.set(StoreInfo(store)) - } + storeInfo.set(store) } def getStore: Option[ReadStateStore] = { - Option(storeInfo.get()).map(_.store) - } - - def setUsedForWriteStore(used: Boolean): Unit = { - Option(storeInfo.get()) match { - case Some(info) => - storeInfo.set(info.copy(usedForWriteStore = used)) - case None => // If there's no store set, we don't need to track usage - } - } - - def isUsedForWriteStore: Boolean = { - Option(storeInfo.get()).exists(_.usedForWriteStore) + Option(storeInfo.get()) } def clearStore(): Unit = { @@ -171,16 +155,12 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( val inputIter = dataRDD.iterator(partition, ctxt) val store = StateStoreThreadLocalTracker.getStore match { case Some(readStateStore: ReadStateStore) => - val writeStore = StateStore.getWriteStore(readStateStore, storeProviderId, + StateStore.getWriteStore(readStateStore, storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, uniqueId.map(_.apply(partition.index).head), stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, useMultipleValuesPerKey) - if (writeStore.equals(readStateStore)) { - StateStoreThreadLocalTracker.setUsedForWriteStore(true) - } - writeStore case None => StateStore.get( storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index bcb8190f7531a..651c180299a9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType -import org.apache.spark.util.TaskFailureListener package object state { @@ -68,15 +67,6 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = 
(store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - val ctxt = TaskContext.get() - ctxt.addTaskCompletionListener[Unit](_ => { - if (!store.hasCommitted) store.abort() - }) - ctxt.addTaskFailureListener(new TaskFailureListener { - override def onTaskFailure(context: TaskContext, error: Throwable): Unit = { - store.abort() - } - }) cleanedF(store, iter) } @@ -118,9 +108,6 @@ package object state { // Clean up the state store. val ctxt = TaskContext.get() ctxt.addTaskCompletionListener[Unit](_ => { - if (!StateStoreThreadLocalTracker.isUsedForWriteStore) { - store.release() - } StateStoreThreadLocalTracker.clearStore() }) cleanedF(store, iter) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 6095b26ecd6fe..2247fd21da597 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -259,6 +259,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { // Create input data for our chained operations val inputData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) + var mappedReadStore: ReadStateStore = null + var mappedWriteStore: StateStore = null + // Chain operations: first read with ReadStateStore, then write with StateStore val chainedResults = inputData // First pass: read-only state store access @@ -270,6 +273,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { spark.sessionState, Some(castToImpl(spark).streams.stateStoreCoordinator) ) { (readStore, iter) => + mappedReadStore = readStore + // Read values and store them for later verification val inputItems = iter.toSeq // Materialize the input data @@ -296,7 +301,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { ) { (writeStore, writeIter) => if (writeIter.hasNext) { val (readValues, allStoreValues, originalItems) = writeIter.next() - val usedForWriteStore = StateStoreThreadLocalTracker.isUsedForWriteStore + mappedWriteStore = writeStore // Get all existing values from the write store to verify reuse val storeValues = writeStore.iterator().map(rowPairToDataPair).toSeq @@ -310,15 +315,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { writeStore.commit() // Return all collected information for verification - Iterator((readValues, allStoreValues, storeValues, usedForWriteStore)) + Iterator((readValues, allStoreValues, storeValues)) } else { Iterator.empty } } // Collect the results - val (readValues, initialStoreState, writeStoreValues, - storeWasReused) = chainedResults.collect().head + val (readValues, initialStoreState, writeStoreValues) = chainedResults.collect().head // Verify read results assert(readValues.toSet === Set( @@ -333,8 +337,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { // Verify the existing values in the write store (should be the same as initial state) assert(writeStoreValues.toSet === Set((("a", 0), 1), (("b", 0), 2))) - // Verify the thread local flag indicates reuse - assert(storeWasReused, + // Verify that the same store was used for both read and write operations + assert(mappedReadStore == mappedWriteStore, "StateStoreThreadLocalTracker should indicate the read store was reused") // Create another ReadStateStoreRDD to verify the final state 
(version 2) From 18793fd48c8f3454158e4399569a6fe0b5a7e586 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 9 May 2025 11:05:56 -0700 Subject: [PATCH 18/19] merging --- .../state/RocksDBStateStoreProvider.scala | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 9188b25794746..b7f6770a414d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -511,33 +511,25 @@ private[sql] class RocksDBStateStoreProvider } } - try { - // Load RocksDB store - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, - readOnly = readOnly) - - // Create or reuse store instance - existingStore match { - case Some(store: ReadStateStore) if store.isInstanceOf[RocksDBStateStore] => - // Mark store as being used for write operations - val rocksDBStateStore = store.asInstanceOf[RocksDBStateStore] - rocksDBStateStore.readOnly = readOnly - rocksDBStateStore.asInstanceOf[StateStore] - case None => - // Create new store instance - new RocksDBStateStore(version, readOnly) - case _ => null // No need for error case here since we validated earlier - } - } catch { - case e: Throwable => - throw QueryExecutionErrors.cannotLoadStore(e) + // Load RocksDB store + rocksDB.load( + version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + readOnly = readOnly) + + // Create or reuse store instance + existingStore match { + case Some(store: ReadStateStore) if store.isInstanceOf[RocksDBStateStore] => + // Mark store as being used for write operations + val rocksDBStateStore = store.asInstanceOf[RocksDBStateStore] + rocksDBStateStore.readOnly = readOnly + rocksDBStateStore.asInstanceOf[StateStore] + case None => + // Create new store instance + new RocksDBStateStore(version, readOnly) + case _ => null // No need for error case here since we validated earlier } } catch { - case e: SparkException - if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) => - throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( stateStoreId.toString, From 3270b8418d8bbd611832fc9375f24265c57fdfda Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 9 May 2025 11:07:18 -0700 Subject: [PATCH 19/19] removing SparkException --- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index b7f6770a414d5..0ef39a2afc2fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import 
org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec
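
Taken together, the series lands on the following handshake: a ReadStateStoreRDD loads the store and publishes it in StateStoreThreadLocalTracker, and a downstream StateStoreRDD on the same task thread upgrades it in place through StateStore.getWriteStore instead of loading the version again, while lifecycle cleanup now lives in the stores' own task completion listeners. The sketch below is a self-contained model written for this document, not Spark code: the types and Tracker are local stand-ins that mirror ReadStateStore, StateStore, RocksDBStateStore, and StateStoreThreadLocalTracker.

// Self-contained model of the read-to-write reuse path (illustrative only; all types local).
object StoreReuseModel {
  trait ReadStore { def version: Long; def release(): Unit }
  trait WriteStore extends ReadStore { def commit(): Long }

  // Mirrors StateStoreThreadLocalTracker after PATCH 17: one ThreadLocal slot per task thread.
  object Tracker {
    private val slot = new ThreadLocal[ReadStore]
    def set(store: ReadStore): Unit = slot.set(store)
    def get: Option[ReadStore] = Option(slot.get())
    def clear(): Unit = slot.remove()
  }

  // Mirrors RocksDBStateStore(lastVersion, var readOnly): one instance, mode flipped in place.
  final class RocksLikeStore(val version: Long, var readOnly: Boolean) extends WriteStore {
    def release(): Unit = ()                 // read-only cleanup path
    def commit(): Long = version + 1         // write path returns the committed version
  }

  // Pass 1 (ReadStateStoreRDD): load once for reading and publish to the thread-local slot.
  def openForRead(version: Long): ReadStore = {
    val store = new RocksLikeStore(version, readOnly = true)
    Tracker.set(store)
    store
  }

  // Pass 2 (StateStoreRDD): reuse the published store, flipping readOnly in place,
  // roughly what RocksDBStateStoreProvider.loadStateStore does when existingStore is set.
  def openForWrite(version: Long): WriteStore = Tracker.get match {
    case Some(store: RocksLikeStore) if store.version == version =>
      store.readOnly = false
      store
    case _ =>
      new RocksLikeStore(version, readOnly = false)
  }

  def main(args: Array[String]): Unit = {
    val readStore = openForRead(41)
    val writeStore = openForWrite(41)
    assert(readStore eq writeStore)          // same instance: no second load or DB lock
    writeStore.commit()
    Tracker.clear()                          // task completion: slot is reset for the next task
  }
}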