
Commit a4b7c10

Commit message: refactoring
1 parent: cf57d27

5 files changed: +164 -42 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

+12
@@ -1013,6 +1013,18 @@ class RocksDB(
     }
   }
 
+  /**
+   * Releases resources associated with this RocksDB instance without rolling back changes.
+   *
+   * This method is used in the read-then-write pattern where:
+   * 1. A read-only store is opened to retrieve existing state
+   * 2. The same store is converted to a writable store
+   * 3. After the write store commits, we need to release resources without rolling back
+   *    the changes that were just committed
+   *
+   * Unlike abort() which rolls back uncommitted changes, release() simply releases
+   * resources and locks without affecting the state data.
+   */
   def release(): Unit = {
     release(LoadStore)
   }
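
The new release() exists for the read-then-write flow described in the scaladoc above: read the previous version, convert the same store to a writable one, commit, then free resources without undoing the commit. A minimal sketch of that lifecycle, assuming a hypothetical SimpleStore stand-in rather than the real RocksDB class (only the commit/abort/release semantics mirror the diff):

// Read-then-write lifecycle sketch. SimpleStore is a hypothetical stand-in for
// the RocksDB-backed store; it only mirrors the semantics described above.
object ReadThenWriteSketch {
  final class SimpleStore(val version: Long) {
    private val data = scala.collection.mutable.Map.empty[String, String]
    def get(key: String): Option[String] = data.get(key)
    def put(key: String, value: String): Unit = data(key) = value
    def commit(): Long = { println(s"committed v${version + 1}"); version + 1 }
    def abort(): Unit = { data.clear(); println("aborted: uncommitted changes dropped") }
    def release(): Unit = println("released: locks freed, data untouched")
  }

  def main(args: Array[String]): Unit = {
    val readStore = new SimpleStore(version = 10)  // 1. open a read-only view of v10
    val previous = readStore.get("count")          //    read the existing state
    val writeStore = readStore                     // 2. "getWriteStore": same instance, now written to
    writeStore.put("count", (previous.map(_.toInt).getOrElse(0) + 1).toString)
    writeStore.commit()                            // 3. commit the update
    readStore.release()                            // 4. release, NOT abort: the commit must survive
  }
}

Calling abort() in step 4 would be wrong by definition, since abort() rolls back; release() only gives up the instance's locks and resources.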

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

+48 -26

@@ -44,7 +44,20 @@ private[sql] class RocksDBStateStoreProvider
   import RocksDBStateStoreProvider._
 
   class RocksDBStateStore(lastVersion: Long) extends StateStore {
-    /** Trait and classes representing the internal state of the store */
+    /**
+     * Trait and classes representing the internal state of the store
+     *
+     * State transitions:
+     * - Initial state: UPDATING
+     * - UPDATING -> COMMITTED: After successful commit()
+     * - UPDATING -> ABORTED: After abort() or failed commit()
+     * - UPDATING -> RELEASED: After release() without committing changes
+     * - COMMITTED -> RELEASED: After release() following a successful commit
+     * - ABORTED -> RELEASED: After release() following an abort
+     *
+     * The RELEASED state is terminal and indicates that resources have been released
+     * without affecting the underlying data (unlike ABORTED which rolls back changes).
+     */
     trait STATE
     case object UPDATING extends STATE
     case object COMMITTED extends STATE
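
The transition list in the new scaladoc reads as a small state machine. A hedged sketch that encodes exactly those transitions, with StoreState and Event as illustrative names rather than the provider's actual STATE objects:

// State-machine sketch of the documented store lifecycle.
object StoreLifecycleSketch {
  sealed trait StoreState
  case object Updating  extends StoreState
  case object Committed extends StoreState
  case object Aborted   extends StoreState
  case object Released  extends StoreState // terminal: resources freed, data untouched

  sealed trait Event
  case object Commit  extends Event
  case object Abort   extends Event
  case object Release extends Event

  def transition(state: StoreState, event: Event): StoreState = (state, event) match {
    case (Updating, Commit)   => Committed // successful commit()
    case (Updating, Abort)    => Aborted   // abort() or failed commit()
    case (Updating, Release)  => Released  // release() without committing changes
    case (Committed, Release) => Released  // release() after a successful commit
    case (Aborted, Release)   => Released  // release() after an abort
    case (Released, Release)  => Released  // release() is documented as safe to repeat
    case (s, e)               => throw new IllegalStateException(s"illegal transition: $s on $e")
  }

  def main(args: Array[String]): Unit =
    println(transition(transition(Updating, Commit), Release)) // Released
}
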
@@ -454,6 +467,15 @@ private[sql] class RocksDBStateStoreProvider
 
   override def stateStoreId: StateStoreId = stateStoreId_
 
+  /**
+   * Creates and returns a state store with the specified parameters.
+   *
+   * @param version The version of the state store to load
+   * @param uniqueId Optional unique identifier for checkpoint
+   * @param readOnly Whether to open the store in read-only mode
+   * @param existingStore Optional existing store to reuse instead of creating a new one
+   * @return The loaded state store
+   */
   /**
    * Creates and returns a state store with the specified parameters.
    *
@@ -468,32 +490,30 @@ private[sql] class RocksDBStateStoreProvider
       uniqueId: Option[String],
       readOnly: Boolean,
       existingStore: Option[ReadStateStore] = None): StateStore = {
+    if (version < 0) {
+      throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
+    }
+
     try {
-      if (version < 0) {
-        throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
-      }
-      try {
-        // Load RocksDB store
-        rocksDB.load(
-          version,
-          stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
-          readOnly = readOnly)
-
-        // Return appropriate store instance
-        existingStore match {
-          case Some(stateStore: RocksDBStateStore) =>
-            // Reuse existing store for getWriteStore case
-            stateStore
-          case Some(_) =>
-            throw new IllegalArgumentException("Existing store must be a RocksDBStateStore")
-          case None =>
-            // Create new store instance for getStore/getReadStore cases
-            new RocksDBStateStore(version)
-        }
-      } catch {
-        case e: Throwable =>
-          throw e
+      // Load RocksDB store
+      rocksDB.load(
+        version,
+        stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
+        readOnly = readOnly)
+
+      // Return appropriate store instance
+      val stateStore = existingStore match {
+        case Some(stateStore: RocksDBStateStore) =>
+          // Reuse existing store for getWriteStore case
+          stateStore
+        case Some(_) =>
+          throw new IllegalArgumentException("Existing store must be a RocksDBStateStore")
+        case None =>
+          // Create new store instance for getStore/getReadStore cases
+          new RocksDBStateStore(version)
       }
+
+      stateStore
     } catch {
       case e: SparkException
         if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) =>
@@ -503,7 +523,9 @@ private[sql] class RocksDBStateStoreProvider
           stateStoreId.toString,
           "ROCKSDB_STORE_PROVIDER",
           e)
-      case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
+      case e: Throwable =>
+        logError(s"Failed to load state store version $version with uniqueId $uniqueId", e)
+        throw QueryExecutionErrors.cannotLoadStore(e)
     }
   }
   override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
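
Besides the added scaladoc, loadStateStore changed in two ways: the version check moved ahead of the try, and the nested try with a rethrowing catch-all was collapsed into a single block whose catch-all logs before wrapping the failure in cannotLoadStore. A minimal sketch of that validate-early, log-and-wrap shape, where loadInternal and StoreLoadException are hypothetical stand-ins for rocksDB.load and the Spark error class:

// Sketch of the refactored error-handling shape; only the control flow mirrors the diff.
object LoadStoreSketch {
  final class StoreLoadException(message: String, cause: Throwable)
    extends RuntimeException(message, cause)

  private def loadInternal(version: Long): String = s"store@v$version"

  def loadStore(version: Long): String = {
    // Validate before the try, so bad arguments are not misreported as load failures.
    require(version >= 0, s"unexpected state store version: $version")
    try {
      loadInternal(version)
    } catch {
      case e: Throwable =>
        // Log with enough context to debug, then wrap once in a typed error.
        System.err.println(s"Failed to load state store version $version: $e")
        throw new StoreLoadException(s"cannot load store at version $version", e)
    }
  }

  def main(args: Array[String]): Unit = println(loadStore(7))
}
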

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

+36 -2

@@ -118,6 +118,23 @@ trait ReadStateStore {
    */
   def abort(): Unit
 
+  /**
+   * Release resources associated with this state store without rolling back changes.
+   *
+   * Unlike `abort()` which rolls back uncommitted changes, `release()` simply releases
+   * resources and locks without affecting the state data. This is particularly important
+   * in the read-then-write pattern where:
+   *
+   * 1. A read-only store is opened to retrieve existing state
+   * 2. The same store is converted to a writable store using `getWriteStore()`
+   * 3. After the write store commits, we need to release resources without rolling back
+   *    the changes that were just committed
+   *
+   * Implementations should ensure that:
+   * 1. Any locks or resources held by this store are released
+   * 2. No uncommitted changes are rolled back (unlike `abort()`)
+   * 3. The method is idempotent and safe to call multiple times
+   */
   def release(): Unit
 }

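The third requirement, idempotency, is the easiest to get wrong. A small sketch of a release() that frees a lock exactly once and never touches data, using a hypothetical LockBackedStore (not the Spark implementation):

import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock

// Illustrates the three contract points above: free resources, never roll back, stay idempotent.
final class LockBackedStore(lock: ReentrantLock) {
  private val released = new AtomicBoolean(false)

  def release(): Unit = {
    // Idempotent: only the first call actually unlocks.
    if (released.compareAndSet(false, true)) {
      if (lock.isHeldByCurrentThread) lock.unlock() // free the resource
      // No data is touched here, committed or not (that is abort()'s job).
    }
  }
}

object ReleaseContractSketch {
  def main(args: Array[String]): Unit = {
    val lock = new ReentrantLock()
    lock.lock()
    val store = new LockBackedStore(lock)
    store.release()
    store.release() // safe: the second call is a no-op
    println(s"still locked after release: ${lock.isLocked}") // false
  }
}
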
@@ -236,7 +253,7 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
 
   override def abort(): Unit = store.abort()
 
-  override def release(): Unit = {}
+  override def release(): Unit = store.release()
 
   override def prefixScan(prefixKey: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] =
@@ -569,6 +586,24 @@ trait StateStoreProvider {
       version: Long,
       stateStoreCkptId: Option[String] = None): StateStore
 
+  /**
+   * Converts a read-only state store to a writable state store.
+   *
+   * This method is a key part of the read-then-write pattern optimization that avoids
+   * lock contention issues when the same state store needs to be accessed for both
+   * reading and writing. Instead of opening two separate connections (which would block
+   * with lock hardening), this method reuses the existing read store connection.
+   *
+   * Implementations should ensure that:
+   * 1. The returned store has the same version as the input read store
+   * 2. The returned store has access to all state that was visible to the read store
+   * 3. The returned store can be used for both reading and writing operations
+   *
+   * @param readStore The read-only state store to convert to a writable store
+   * @param version The version of the state store (should match readStore.version)
+   * @param uniqueId Optional unique identifier for checkpoint
+   * @return A writable state store that reuses the same underlying connection
+   */
   def getWriteStore(
       readStore: ReadStateStore,
       version: Long,
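
One way a provider can satisfy this contract is to funnel getWriteStore through the same load path as getStore and hand the read store back in, which is what the existingStore parameter added to the RocksDB provider's loadStateStore enables. A simplified sketch with stand-in types (ReadStore, WriteStore and Provider here are illustrative, not the Spark traits):

// Sketch of getWriteStore reusing the already-loaded read store.
object GetWriteStoreSketch {
  trait ReadStore { def version: Long }
  trait WriteStore extends ReadStore { def commit(): Long }

  final class Store(val version: Long) extends WriteStore {
    def commit(): Long = version + 1
  }

  class Provider {
    // Mirrors loadStateStore(version, uniqueId, readOnly, existingStore): reuse the
    // supplied instance when there is one, otherwise create a new store.
    private def loadStore(version: Long, existing: Option[ReadStore]): WriteStore =
      existing match {
        case Some(s: Store) => s                  // getWriteStore case: same connection
        case Some(_)        => throw new IllegalArgumentException("unexpected store type")
        case None           => new Store(version) // getStore / getReadStore cases
      }

    def getReadStore(version: Long): ReadStore = loadStore(version, None)

    def getWriteStore(readStore: ReadStore, version: Long): WriteStore = {
      require(readStore.version == version, "version must match the read store")
      loadStore(version, Some(readStore))
    }
  }

  def main(args: Array[String]): Unit = {
    val p = new Provider
    val rs = p.getReadStore(5)
    val ws = p.getWriteStore(rs, 5)
    println(ws eq rs) // true: the same underlying store instance is reused
  }
}
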
@@ -1258,4 +1293,3 @@ object StateStore extends Logging {
     }
   }
 }
-

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala

+65 -14

@@ -40,6 +40,22 @@ import org.apache.spark.util.SerializableConfiguration
  * 2. The same state store is then converted to read-write mode for updates
  * 3. This avoids having two separate open connections to the same state store
  *    which would cause blocking or contention issues
+ *
+ * This pattern is particularly important for stateful aggregations where:
+ * - StateStoreRestoreExec first reads previous state using a read-only store
+ * - StateStoreSaveExec then updates the state using a writable store
+ *
+ * Without this optimization, the following pattern would cause contention:
+ *   readStore.acquire()
+ *   writeStore.acquire()  // This would block with lock hardening changes
+ *   writeStore.commit()
+ *   readStore.abort()
+ *
+ * With this optimization, the pattern becomes:
+ *   readStore = getReadStore()
+ *   writeStore = getWriteStore(readStore)  // Reuses the same store connection
+ *   writeStore.commit()
+ *   // No need to abort/release readStore as it's the same underlying store
  */
 trait StateStoreRDDProvider {
   /**
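
At the RDD level the handoff works because the read-side RDD exposes the store it opened for each partition (getStateStoreForPartition) and the write-side RDD looks that store up instead of opening a second connection. A hedged sketch of the per-partition handoff, with PartitionStoreRegistry as a hypothetical stand-in for StateStoreRDDProvider / ReadStateStoreRDD:

// Per-partition read-store handoff sketch; all names here are illustrative.
object PartitionHandoffSketch {
  final class ReadStore(val partitionId: Int, val version: Long)

  // Plays the role of ReadStateStoreRDD: remembers the store it opened per partition.
  final class PartitionStoreRegistry {
    private val stores = new java.util.concurrent.ConcurrentHashMap[Int, ReadStore]()
    def register(store: ReadStore): Unit = stores.put(store.partitionId, store)
    def getStateStoreForPartition(partitionId: Int): Option[ReadStore] =
      Option(stores.get(partitionId))
  }

  // Plays the role of StateStoreRDD.compute: reuse the upstream read store when one
  // exists for this partition, otherwise fall back to opening a fresh store.
  def storeForWrite(registry: PartitionStoreRegistry, partitionId: Int, version: Long): ReadStore =
    registry.getStateStoreForPartition(partitionId)
      .getOrElse(new ReadStore(partitionId, version))

  def main(args: Array[String]): Unit = {
    val registry = new PartitionStoreRegistry
    val readStore = new ReadStore(partitionId = 0, version = 3)
    registry.register(readStore)
    println(storeForWrite(registry, 0, 3) eq readStore) // true: same store reused
  }
}
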
@@ -108,9 +124,24 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
     sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
 
-  // Using a ConcurrentHashMap to track state stores by partition ID
-  @transient private lazy val partitionStores =
-    new java.util.concurrent.ConcurrentHashMap[Int, ReadStateStore]()
+  // Using a bounded concurrent map to track state stores by partition ID
+  // This prevents memory leaks for long-running tasks by limiting the maximum size
+  @transient private lazy val partitionStores = {
+    val maxSize = 100 // Maximum number of state stores to cache
+    java.util.Collections.synchronizedMap(
+      new java.util.LinkedHashMap[Int, ReadStateStore](16, 0.75f, true) {
+        override def removeEldestEntry(
+            eldest: java.util.Map.Entry[Int, ReadStateStore]): Boolean = {
+          val tooMany = size() > maxSize
+          if (tooMany) {
+            // Release resources for the state store being evicted
+            eldest.getValue.release()
+          }
+          tooMany
+        }
+      }
+    )
+  }
 
   override def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] = {
     Option(partitionStores.get(partitionId))
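
The bounded map above relies on java.util.LinkedHashMap's access-order mode plus the removeEldestEntry hook, which gives the cache a chance to release an evicted store. A standalone demonstration of that eviction behavior, with a FakeStore whose release() just prints and a maxSize of 2 chosen only to trigger eviction quickly:

// LRU eviction demo for the bounded partition-store map.
object BoundedStoreMapSketch {
  final class FakeStore(val id: Int) { def release(): Unit = println(s"released store $id") }

  def main(args: Array[String]): Unit = {
    val maxSize = 2
    val stores = java.util.Collections.synchronizedMap(
      new java.util.LinkedHashMap[Int, FakeStore](16, 0.75f, true) {
        override def removeEldestEntry(eldest: java.util.Map.Entry[Int, FakeStore]): Boolean = {
          val tooMany = size() > maxSize
          if (tooMany) eldest.getValue.release() // free the evicted store's resources
          tooMany
        }
      })

    stores.put(1, new FakeStore(1))
    stores.put(2, new FakeStore(2))
    stores.get(1)                   // touch 1 so that 2 becomes the eldest entry
    stores.put(3, new FakeStore(3)) // prints "released store 2"
    println(stores.keySet())        // [1, 3]
  }
}
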
@@ -182,22 +213,42 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
    * This is particularly important for stateful aggregations where StateStoreRestoreExec
    * first reads previous state and StateStoreSaveExec then updates it.
    *
-   * The method performs a depth-first search through the RDD dependency graph.
+   * The method performs an optimized depth-first search through the RDD dependency graph,
+   * prioritizing paths that are more likely to contain state store providers.
    *
    * @param rdd The starting RDD to search from
    * @return Some(provider) if a StateStoreRDDProvider is found in the lineage, None otherwise
    */
   private def findStateStoreProvider(rdd: RDD[_]): Option[StateStoreRDDProvider] = {
-    rdd match {
-      case null => None
-      case provider: StateStoreRDDProvider => Some(provider)
-      case _ if rdd.dependencies.isEmpty => None
-      case _ =>
-        // Search all dependencies
-        rdd.dependencies.view
-          .map(dep => findStateStoreProvider(dep.rdd))
-          .find(_.isDefined)
-          .flatten
+    // Early termination conditions
+    if (rdd == null) return None
+
+    // Check if the current RDD is a provider
+    if (rdd.isInstanceOf[StateStoreRDDProvider]) {
+      return Some(rdd.asInstanceOf[StateStoreRDDProvider])
+    }
+
+    // If no dependencies, we can't find a provider
+    if (rdd.dependencies.isEmpty) return None
+
+    // Prioritize narrow dependencies over wide dependencies
+    // Narrow dependencies are more likely to preserve the state store provider lineage
+    val (narrowDeps, wideDeps) = rdd.dependencies.partition(_.isInstanceOf[NarrowDependency[_]])
+
+    // First search through narrow dependencies
+    val narrowResult = narrowDeps.view
+      .map(dep => findStateStoreProvider(dep.rdd))
+      .find(_.isDefined)
+      .flatten
+
+    if (narrowResult.isDefined) {
+      narrowResult
+    } else {
+      // If not found in narrow dependencies, try wide dependencies
+      wideDeps.view
+        .map(dep => findStateStoreProvider(dep.rdd))
+        .find(_.isDefined)
+        .flatten
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala

+3
@@ -116,6 +116,9 @@ package object state {
     })
     taskContext.addTaskFailureListener(new TaskFailureListener {
      override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
+       // On task failure, abort to roll back any uncommitted changes.
+       // We do not call release() here: release() skips the rollback and could leave
+       // partially applied updates behind; abort() guarantees proper cleanup.
        store.abort()
      }
    })
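
The failure listener deliberately calls abort() rather than release(): after a failed task the uncommitted changes must be rolled back, not merely unlocked. A standalone sketch of that cleanup policy (FakeStore and runTask are illustrative; the real hook is the TaskFailureListener registration shown above):

// Cleanup-policy sketch: abort on failure, commit-then-release on success.
object FailureCleanupSketch {
  final class FakeStore {
    var state = "UPDATING"
    def commit(): Unit  = { state = "COMMITTED" }
    def abort(): Unit   = { state = "ABORTED" }  // rolls back uncommitted changes
    def release(): Unit = { state = "RELEASED" } // frees resources, keeps data
  }

  // Mimics a task body guarded by a failure listener: abort on any error so a
  // half-written batch is never left behind.
  def runTask(store: FakeStore)(body: => Unit): Unit =
    try {
      body
      store.commit()
      store.release()
    } catch {
      case e: Throwable =>
        store.abort() // NOT release(): we must roll back, not just unlock
        throw e
    }

  def main(args: Array[String]): Unit = {
    val ok = new FakeStore
    runTask(ok) { () }
    println(ok.state) // RELEASED

    val failed = new FakeStore
    try runTask(failed) { throw new RuntimeException("boom") } catch { case _: Throwable => () }
    println(failed.state) // ABORTED
  }
}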
