@@ -40,6 +40,22 @@ import org.apache.spark.util.SerializableConfiguration
40
40
* 2. The same state store is then converted to read-write mode for updates
41
41
* 3. This avoids having two separate open connections to the same state store
42
42
* which would cause blocking or contention issues
43
+ *
44
+ * This pattern is particularly important for stateful aggregations where:
45
+ * - StateStoreRestoreExec first reads previous state using a read-only store
46
+ * - StateStoreSaveExec then updates the state using a writable store
47
+ *
48
+ * Without this optimization, the following pattern would cause contention:
49
+ * readStore.acquire()
50
+ * writeStore.acquire() // This would block with lock hardening changes
51
+ * writeStore.commit()
52
+ * readStore.abort()
53
+ *
54
+ * With this optimization, the pattern becomes:
55
+ * readStore = getReadStore()
56
+ * writeStore = getWriteStore(readStore) // Reuses the same store connection
57
+ * writeStore.commit()
58
+ * // No need to abort/release readStore as it's the same underlying store
43
59
*/
44
60
trait StateStoreRDDProvider {
45
61
/**
@@ -108,9 +124,24 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
108
124
extends BaseStateStoreRDD [T , U ](dataRDD, checkpointLocation, queryRunId, operatorId,
109
125
sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
110
126
111
- // Using a ConcurrentHashMap to track state stores by partition ID
112
- @ transient private lazy val partitionStores =
113
- new java.util.concurrent.ConcurrentHashMap [Int , ReadStateStore ]()
127
+ // Using a bounded concurrent map to track state stores by partition ID
128
+ // This prevents memory leaks for long-running tasks by limiting the maximum size
129
+ @ transient private lazy val partitionStores = {
130
+ val maxSize = 100 // Maximum number of state stores to cache
131
+ java.util.Collections .synchronizedMap(
132
+ new java.util.LinkedHashMap [Int , ReadStateStore ](16 , 0.75f , true ) {
133
+ override def removeEldestEntry (
134
+ eldest : java.util.Map .Entry [Int , ReadStateStore ]): Boolean = {
135
+ val tooMany = size() > maxSize
136
+ if (tooMany) {
137
+ // Release resources for the state store being evicted
138
+ eldest.getValue.release()
139
+ }
140
+ tooMany
141
+ }
142
+ }
143
+ )
144
+ }
114
145
115
146
override def getStateStoreForPartition (partitionId : Int ): Option [ReadStateStore ] = {
116
147
Option (partitionStores.get(partitionId))
@@ -182,22 +213,42 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
182
213
* This is particularly important for stateful aggregations where StateStoreRestoreExec
183
214
* first reads previous state and StateStoreSaveExec then updates it.
184
215
*
185
- * The method performs a depth-first search through the RDD dependency graph.
216
+ * The method performs an optimized depth-first search through the RDD dependency graph,
217
+ * prioritizing paths that are more likely to contain state store providers.
186
218
*
187
219
* @param rdd The starting RDD to search from
188
220
* @return Some(provider) if a StateStoreRDDProvider is found in the lineage, None otherwise
189
221
*/
190
222
private def findStateStoreProvider (rdd : RDD [_]): Option [StateStoreRDDProvider ] = {
191
- rdd match {
192
- case null => None
193
- case provider : StateStoreRDDProvider => Some (provider)
194
- case _ if rdd.dependencies.isEmpty => None
195
- case _ =>
196
- // Search all dependencies
197
- rdd.dependencies.view
198
- .map(dep => findStateStoreProvider(dep.rdd))
199
- .find(_.isDefined)
200
- .flatten
223
+ // Early termination conditions
224
+ if (rdd == null ) return None
225
+
226
+ // Check if the current RDD is a provider
227
+ if (rdd.isInstanceOf [StateStoreRDDProvider ]) {
228
+ return Some (rdd.asInstanceOf [StateStoreRDDProvider ])
229
+ }
230
+
231
+ // If no dependencies, we can't find a provider
232
+ if (rdd.dependencies.isEmpty) return None
233
+
234
+ // Prioritize narrow dependencies over wide dependencies
235
+ // Narrow dependencies are more likely to preserve the state store provider lineage
236
+ val (narrowDeps, wideDeps) = rdd.dependencies.partition(_.isInstanceOf [NarrowDependency [_]])
237
+
238
+ // First search through narrow dependencies
239
+ val narrowResult = narrowDeps.view
240
+ .map(dep => findStateStoreProvider(dep.rdd))
241
+ .find(_.isDefined)
242
+ .flatten
243
+
244
+ if (narrowResult.isDefined) {
245
+ narrowResult
246
+ } else {
247
+ // If not found in narrow dependencies, try wide dependencies
248
+ wideDeps.view
249
+ .map(dep => findStateStoreProvider(dep.rdd))
250
+ .find(_.isDefined)
251
+ .flatten
201
252
}
202
253
}
203
254
0 commit comments