
Commit 48d4901

Add a flag to force stats collection during query optimizations

1 parent 2a57836 · commit 48d4901

File tree: 3 files changed (+179 / -4 lines)

spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -392,6 +392,16 @@ trait DeltaSQLConfBase extends DeltaSQLConfUtils {
       .booleanConf
       .createWithDefault(true)
 
+  val DELTA_ALWAYS_COLLECT_STATS =
+    buildConf("alwaysCollectStats.enabled")
+      .internal()
+      .doc("When true, row counts are collected from file statistics even when there are no " +
+        "data filters. This is useful for ensuring PreparedDeltaFileIndex always has row count " +
+        "information available. Note: this may have a small performance overhead as it requires " +
+        "summing numRecords from all files.")
+      .booleanConf
+      .createWithDefault(false)
+
   val DELTA_LIMIT_PUSHDOWN_ENABLED =
     buildConf("stats.limitPushdown.enabled")
       .internal()
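A minimal sketch of enabling the new flag for a session, assuming the standard "spark.databricks.delta." prefix that DeltaSQLConf.buildConf prepends to key names (so the full key would be spark.databricks.delta.alwaysCollectStats.enabled); the session setup and table path below are illustrative assumptions, not part of this commit:

// Hedged sketch: enable forced row-count collection for one session.
// Assumes the "spark.databricks.delta." key prefix and a hypothetical table path.
import org.apache.spark.sql.SparkSession

object AlwaysCollectStatsDemo extends App {
  val spark = SparkSession.builder()
    .appName("always-collect-stats-demo")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog",
      "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()

  // The conf is internal, so it is set directly by key.
  spark.conf.set("spark.databricks.delta.alwaysCollectStats.enabled", "true")

  // Reads planned after this point take the modified DataSkippingReader paths below,
  // so PreparedDeltaFileIndex carries a row count even when there are no data filters.
  val df = spark.read.format("delta").load("/tmp/delta/events")
  df.count()
}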

spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala

Lines changed: 36 additions & 4 deletions
@@ -1238,7 +1238,10 @@ trait DataSkippingReaderBase
       partitionFilters: Seq[Expression],
       keepNumRecords: Boolean): (Seq[AddFile], DataSize) = recordFrameProfile(
       "Delta", "DataSkippingReader.filterOnPartitions") {
-    val df = if (keepNumRecords) {
+    val forceCollectRowCount =
+      spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS)
+    val shouldCollectStats = keepNumRecords || forceCollectRowCount
+    val df = if (shouldCollectStats) {
       // use withStats instead of allFiles so the `stats` column is already parsed
       val filteredFiles =
         DeltaLog.filterFileList(metadata.partitionSchema, withStats, partitionFilters)
@@ -1253,7 +1256,26 @@ trait DataSkippingReaderBase
     }
     val files = convertDataFrameToAddFiles(df)
     val sizeInBytesByPartitionFilters = files.map(_.size).sum
-    files.toSeq -> DataSize(Some(sizeInBytesByPartitionFilters), None, Some(files.size))
+    // Compute row count if we have stats available and forceCollectRowCount is enabled
+    val rowCount = if (forceCollectRowCount) {
+      sumRowCounts(files)
+    } else {
+      None
+    }
+    files.toSeq -> DataSize(Some(sizeInBytesByPartitionFilters), rowCount, Some(files.size))
+  }
+
+  /**
+   * Sums up the numPhysicalRecords from the given AddFile objects.
+   * Returns None if any file is missing stats (to indicate incomplete row count).
+   */
+  private def sumRowCounts(files: Seq[AddFile]): Option[Long] = {
+    files.foldLeft(Option(0L)) { (accOpt, file) =>
+      for {
+        acc <- accOpt
+        count <- file.numPhysicalRecords
+      } yield acc + count
+    }
   }
 
   /**
@@ -1310,13 +1332,23 @@ trait DataSkippingReaderBase
     if (filters == Seq(TrueLiteral) || filters.isEmpty || schema.isEmpty) {
       recordDeltaOperation(deltaLog, "delta.skipping.none") {
         // When there are no filters we can just return allFiles with no extra processing
+        val forceCollectRowCount =
+          spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS)
+        val shouldCollectStats = keepNumRecords || forceCollectRowCount
+        val files = getAllFiles(shouldCollectStats)
+        // Compute row count if forceCollectRowCount is enabled
+        val rowCount = if (forceCollectRowCount) {
+          sumRowCounts(files)
+        } else {
+          None
+        }
         val dataSize = DataSize(
           bytesCompressed = sizeInBytesIfKnown,
-          rows = None,
+          rows = rowCount,
           files = numOfFilesIfKnown)
         return DeltaScan(
           version = version,
-          files = getAllFiles(keepNumRecords),
+          files = files,
           total = dataSize,
           partition = dataSize,
           scanned = dataSize)(
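The new sumRowCounts helper uses an all-or-nothing Option fold: as soon as one AddFile is missing numPhysicalRecords, the whole sum collapses to None rather than reporting a misleading partial count. A standalone sketch of the same pattern, with a hypothetical FileStats case class standing in for AddFile:

// Standalone illustration of the Option-folding pattern used by sumRowCounts.
// FileStats is a hypothetical stand-in for AddFile; only numPhysicalRecords matters here.
object SumRowCountsSketch extends App {
  final case class FileStats(path: String, numPhysicalRecords: Option[Long])

  def sumRowCounts(files: Seq[FileStats]): Option[Long] =
    files.foldLeft(Option(0L)) { (accOpt, file) =>
      for {
        acc <- accOpt                    // once the accumulator is None, it stays None
        count <- file.numPhysicalRecords // a file without stats turns the result into None
      } yield acc + count
    }

  val allWithStats = Seq(FileStats("a.parquet", Some(40L)), FileStats("b.parquet", Some(60L)))
  val oneMissing = allWithStats :+ FileStats("c.parquet", None)

  assert(sumRowCounts(allWithStats).contains(100L)) // every file has stats: exact row count
  assert(sumRowCounts(oneMissing).isEmpty)          // any missing stats: row count is unknown
}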

PreparedDeltaFileIndexRowCountSuite.scala

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.stats
+
+// scalastyle:off import.ordering.noEmptyLine
+import org.apache.spark.sql.delta.{DeltaLog, DeltaTable}
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.functions._
+
+/**
+ * Test suite to verify when preparedScan.scanned.rows is populated in PreparedDeltaFileIndex,
+ * and the behavior of the DELTA_ALWAYS_COLLECT_STATS flag.
+ */
+class PreparedDeltaFileIndexRowCountSuite
+  extends QueryTest
+  with DeltaSQLCommandTest {
+
+  import testImplicits._
+
+  private def getDeltaScan(df: DataFrame): DeltaScan = {
+    val scans = df.queryExecution.optimizedPlan.collect {
+      case DeltaTable(prepared: PreparedDeltaFileIndex) => prepared.preparedScan
+    }
+    assert(scans.size == 1, s"Expected 1 DeltaScan, found ${scans.size}")
+    scans.head
+  }
+
+  /**
+   * Test utility that creates a partitioned Delta table and verifies scanned.rows behavior.
+   *
+   * @param alwaysCollectStats value of the DELTA_ALWAYS_COLLECT_STATS flag
+   * @param queryTransform function to transform the base DataFrame (apply filters)
+   * @param expectedRowsDefined whether scanned.rows should be defined
+   * @param expectedRowCount expected row count if defined (None to skip validation)
+   */
+  private def testRowCountBehavior(
+      alwaysCollectStats: Boolean,
+      queryTransform: DataFrame => DataFrame,
+      expectedRowsDefined: Boolean,
+      expectedRowCount: Option[Long] = None): Unit = {
+    withTempDir { dir =>
+      withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "true") {
+        spark.range(100).toDF("id")
+          .withColumn("part", $"id" % 4)
+          .repartition(4)
+          .write.format("delta").partitionBy("part").save(dir.getAbsolutePath)
+      }
+
+      DeltaLog.clearCache()
+
+      withSQLConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS.key -> alwaysCollectStats.toString) {
+        val df = spark.read.format("delta").load(dir.getAbsolutePath)
+        val scan = getDeltaScan(queryTransform(df))
+
+        if (expectedRowsDefined) {
+          assert(scan.scanned.rows.isDefined, "scanned.rows should be defined")
+          expectedRowCount.foreach { expected =>
+            assert(scan.scanned.rows.get == expected,
+              s"Expected $expected rows, got ${scan.scanned.rows.get}")
+          }
+        } else {
+          assert(scan.scanned.rows.isEmpty, "scanned.rows should be None")
+        }
+      }
+    }
+  }
+
+  // Define query cases: (name, transform function, always collects rows)
+  // Note: In the Edge code path, DataSkippingReaderEdge.filterOnPartitions always collects
+  // row counts for partition filter cases (see keepNumRecords = true in that method).
+  // So only "no filter" and "TrueLiteral filter" depend on the alwaysCollectStats flag.
+  private val queryCases: Seq[(String, DataFrame => DataFrame, Boolean)] = Seq(
+    ("no filter", identity[DataFrame], false),
+    ("TrueLiteral filter", _.where(lit(true)), false),
+    ("partition filter only", _.where($"part" === 1), false),
+    ("data filter", _.where($"id" === 50), true),
+    ("partition + data filter", _.where($"part" === 1).where($"id" === 50), true)
+  )
+
+  // Grid test: all query cases x flag values
+  for {
+    (caseName, queryTransform, alwaysCollectsRows) <- queryCases
+    alwaysCollectStats <- Seq(false, true)
+  } {
+    val flagDesc = s"alwaysCollectStats=$alwaysCollectStats"
+    // If the query type always collects rows, rows is always defined; otherwise depends on flag
+    val expectedRowsDefined = alwaysCollectsRows || alwaysCollectStats
+
+    test(s"$caseName - $flagDesc") {
+      testRowCountBehavior(
+        alwaysCollectStats = alwaysCollectStats,
+        queryTransform = queryTransform,
+        expectedRowsDefined = expectedRowsDefined
+      )
+    }
+  }
+
+  test("alwaysCollectStats with missing stats returns None") {
+    withTempDir { dir =>
+      // Create table without stats
+      withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false") {
+        spark.range(100).toDF("id")
+          .write.format("delta").save(dir.getAbsolutePath)
+      }
+
+      DeltaLog.clearCache()
+
+      withSQLConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS.key -> "true") {
+        val df = spark.read.format("delta").load(dir.getAbsolutePath)
+        val scan = getDeltaScan(df)
+        assert(scan.scanned.rows.isEmpty, "scanned.rows should be None when stats are missing")
+      }
+    }
+  }
+}
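For context, the row count these tests assert on is what a downstream consumer would read off the prepared scan. A hedged sketch of that consumption, reusing the DeltaTable / PreparedDeltaFileIndex plan-collection pattern from getDeltaScan above; the active SparkSession named spark and the table path are assumptions for illustration:

// Hedged sketch: reading the planner-visible row count from a Delta query,
// using the same plan-collection pattern as getDeltaScan in the suite above.
// Assumes an active SparkSession named `spark` and a hypothetical table path.
import org.apache.spark.sql.delta.DeltaTable
import org.apache.spark.sql.delta.stats.PreparedDeltaFileIndex

val df = spark.read.format("delta").load("/tmp/delta/events")
val rowsIfKnown: Option[Long] = df.queryExecution.optimizedPlan.collect {
  case DeltaTable(index: PreparedDeltaFileIndex) => index.preparedScan.scanned.rows
}.headOption.flatten

// With alwaysCollectStats.enabled set to true (and stats collected at write time),
// rowsIfKnown is expected to be defined even for an unfiltered read.
println(s"Planner-visible row count: $rowsIfKnown")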
