Commit bbc0b1a

Add a flag to force stats collection during query optimizations
1 parent 2a57836 commit bbc0b1a

3 files changed, +176 −3 lines changed

spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -392,6 +392,16 @@ trait DeltaSQLConfBase extends DeltaSQLConfUtils {
       .booleanConf
       .createWithDefault(true)
 
+  val DELTA_ALWAYS_COLLECT_STATS =
+    buildConf("alwaysCollectStats.enabled")
+      .internal()
+      .doc("When true, row counts are collected from file statistics even when there are no " +
+        "data filters. This is useful for ensuring PreparedDeltaFileIndex always has row count " +
+        "information available. Note: this may have a small performance overhead as it requires " +
+        "summing numRecords from all files.")
+      .booleanConf
+      .createWithDefault(false)
+
   val DELTA_LIMIT_PUSHDOWN_ENABLED =
     buildConf("stats.limitPushdown.enabled")
       .internal()
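
As a usage sketch (not part of the commit): DeltaSQLConf.buildConf conventionally prefixes keys with spark.databricks.delta., so assuming that prefix the new flag can be toggled per session as below, and read back through the typed entry exactly as the diff that follows does.

  // Sketch: enable forced row-count collection for the current session.
  // Assumes the standard DeltaSQLConf key prefix "spark.databricks.delta.".
  spark.conf.set("spark.databricks.delta.alwaysCollectStats.enabled", "true")

  // Delta-side code reads the flag through the typed conf entry:
  val forceCollectRowCount =
    spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS)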

spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala

Lines changed: 34 additions & 3 deletions
@@ -1238,7 +1238,10 @@ trait DataSkippingReaderBase
       partitionFilters: Seq[Expression],
       keepNumRecords: Boolean): (Seq[AddFile], DataSize) = recordFrameProfile(
       "Delta", "DataSkippingReader.filterOnPartitions") {
-    val df = if (keepNumRecords) {
+    val forceCollectRowCount =
+      spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS)
+    val shouldCollectStats = keepNumRecords || forceCollectRowCount
+    val df = if (shouldCollectStats) {
       // use withStats instead of allFiles so the `stats` column is already parsed
       val filteredFiles =
         DeltaLog.filterFileList(metadata.partitionSchema, withStats, partitionFilters)
@@ -1253,7 +1256,26 @@ trait DataSkippingReaderBase
     }
     val files = convertDataFrameToAddFiles(df)
     val sizeInBytesByPartitionFilters = files.map(_.size).sum
-    files.toSeq -> DataSize(Some(sizeInBytesByPartitionFilters), None, Some(files.size))
+    // Compute row count if we have stats available and forceCollectRowCount is enabled
+    val rowCount = if (forceCollectRowCount) {
+      sumRowCounts(files)
+    } else {
+      None
+    }
+    files.toSeq -> DataSize(Some(sizeInBytesByPartitionFilters), rowCount, Some(files.size))
+  }
+
+  /**
+   * Sums up the numPhysicalRecords from the given AddFile objects.
+   * Returns None if any file is missing stats (to indicate incomplete row count).
+   */
+  private def sumRowCounts(files: Seq[AddFile]): Option[Long] = {
+    files.foldLeft(Option(0L)) { (accOpt, file) =>
+      for {
+        acc <- accOpt
+        count <- file.numPhysicalRecords
+      } yield acc + count
+    }
   }
 
   /**
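
The sumRowCounts helper added above relies on a standard fold-over-Option pattern: the running total lives in an Option, and a single file without numPhysicalRecords collapses the whole sum to None. A self-contained sketch of the same pattern with made-up values:

  // Sum Option[Long] values; any None short-circuits the result to None.
  def sumOptional(counts: Seq[Option[Long]]): Option[Long] =
    counts.foldLeft(Option(0L)) { (accOpt, countOpt) =>
      for {
        acc <- accOpt
        count <- countOpt
      } yield acc + count
    }

  sumOptional(Seq(Some(10L), Some(25L)))        // Some(35)
  sumOptional(Seq(Some(10L), None, Some(25L)))  // None, like a file with missing stats
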
@@ -1310,9 +1332,18 @@ trait DataSkippingReaderBase
     if (filters == Seq(TrueLiteral) || filters.isEmpty || schema.isEmpty) {
       recordDeltaOperation(deltaLog, "delta.skipping.none") {
         // When there are no filters we can just return allFiles with no extra processing
+        val forceCollectRowCount =
+          spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS)
+        val shouldCollectStats = keepNumRecords || forceCollectRowCount
+        // Compute row count if forceCollectRowCount is enabled
+        val rowCount = if (forceCollectRowCount) {
+          sumRowCounts(files)
+        } else {
+          None
+        }
         val dataSize = DataSize(
           bytesCompressed = sizeInBytesIfKnown,
-          rows = None,
+          rows = rowCount,
           files = numOfFilesIfKnown)
         return DeltaScan(
           version = version,
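
For context on the effect of this change (a sketch, not part of the commit): with the flag enabled, the no-filter path now produces a DeltaScan whose scanned.rows is populated whenever every file carries numRecords stats. Assuming a Delta table at a hypothetical path and the same imports as the new test suite below, the value can be observed like this:

  spark.conf.set("spark.databricks.delta.alwaysCollectStats.enabled", "true")
  val df = spark.read.format("delta").load("/tmp/delta/events")  // hypothetical path
  val scannedRows = df.queryExecution.optimizedPlan.collect {
    case DeltaTable(index: PreparedDeltaFileIndex) => index.preparedScan.scanned.rows
  }.headOption.flatten
  // Some(totalRowCount) when all files have numRecords stats; None otherwise.
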
PreparedDeltaFileIndexRowCountSuite.scala (new test suite, package org.apache.spark.sql.delta.stats)

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@

/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.stats

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta.{DeltaLog, DeltaTable}
import org.apache.spark.sql.delta.files.PreparedDeltaFileIndex
import org.apache.spark.sql.delta.sources.DeltaSQLConf

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession

/**
 * Test suite to verify when preparedScan.scanned.rows is populated in PreparedDeltaFileIndex,
 * and the behavior of the DELTA_ALWAYS_COLLECT_STATS flag.
 */
class PreparedDeltaFileIndexRowCountSuite
  extends QueryTest
  with SharedSparkSession {

  import testImplicits._

  private def getDeltaScan(df: DataFrame): DeltaScan = {
    val scans = df.queryExecution.optimizedPlan.collect {
      case DeltaTable(prepared: PreparedDeltaFileIndex) => prepared.preparedScan
    }
    assert(scans.size == 1, s"Expected 1 DeltaScan, found ${scans.size}")
    scans.head
  }

  /**
   * Test utility that creates a partitioned Delta table and verifies scanned.rows behavior.
   *
   * @param alwaysCollectStats value of the DELTA_ALWAYS_COLLECT_STATS flag
   * @param queryTransform function to transform the base DataFrame (apply filters)
   * @param expectedRowsDefined whether scanned.rows should be defined
   * @param expectedRowCount expected row count if defined (None to skip validation)
   */
  private def testRowCountBehavior(
      alwaysCollectStats: Boolean,
      queryTransform: DataFrame => DataFrame,
      expectedRowsDefined: Boolean,
      expectedRowCount: Option[Long] = None): Unit = {
    withTempDir { dir =>
      withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "true") {
        spark.range(100).toDF("id")
          .withColumn("part", $"id" % 4)
          .repartition(4)
          .write.format("delta").partitionBy("part").save(dir.getAbsolutePath)
      }

      DeltaLog.clearCache()

      withSQLConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS.key -> alwaysCollectStats.toString) {
        val df = spark.read.format("delta").load(dir.getAbsolutePath)
        val scan = getDeltaScan(queryTransform(df))

        if (expectedRowsDefined) {
          assert(scan.scanned.rows.isDefined, "scanned.rows should be defined")
          expectedRowCount.foreach { expected =>
            assert(scan.scanned.rows.get == expected,
              s"Expected $expected rows, got ${scan.scanned.rows.get}")
          }
        } else {
          assert(scan.scanned.rows.isEmpty, "scanned.rows should be None")
        }
      }
    }
  }

  // Define query cases: (name, transform function, always collects rows)
  // Note: In the Edge code path, DataSkippingReaderEdge.filterOnPartitions always collects
  // row counts for partition filter cases (see keepNumRecords = true in that method).
  // So only "no filter" and "TrueLiteral filter" depend on the alwaysCollectStats flag.
  private val queryCases: Seq[(String, DataFrame => DataFrame, Boolean)] = Seq(
    ("no filter", identity[DataFrame], false),
    ("TrueLiteral filter", _.where(lit(true)), false),
    ("partition filter only", _.where($"part" === 1), false),
    ("data filter", _.where($"id" === 50), true),
    ("partition + data filter", _.where($"part" === 1).where($"id" === 50), true)
  )

  // Grid test: all query cases x flag values
  for {
    (caseName, queryTransform, alwaysCollectsRows) <- queryCases
    alwaysCollectStats <- Seq(false, true)
  } {
    val flagDesc = s"alwaysCollectStats=$alwaysCollectStats"
    // If the query type always collects rows, rows is always defined; otherwise depends on flag
    val expectedRowsDefined = alwaysCollectsRows || alwaysCollectStats

    test(s"$caseName - $flagDesc") {
      testRowCountBehavior(
        alwaysCollectStats = alwaysCollectStats,
        queryTransform = queryTransform,
        expectedRowsDefined = expectedRowsDefined
      )
    }
  }

  test("alwaysCollectStats with missing stats returns None") {
    withTempDir { dir =>
      // Create table without stats
      withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false") {
        spark.range(100).toDF("id")
          .write.format("delta").save(dir.getAbsolutePath)
      }

      DeltaLog.clearCache()

      withSQLConf(DeltaSQLConf.DELTA_ALWAYS_COLLECT_STATS.key -> "true") {
        val df = spark.read.format("delta").load(dir.getAbsolutePath)
        val scan = getDeltaScan(df)
        assert(scan.scanned.rows.isEmpty, "scanned.rows should be None when stats are missing")
      }
    }
  }
}
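
The for-comprehension above registers 5 query cases × 2 flag values = 10 parameterized tests, plus the missing-stats test, for 11 tests in total. Assuming the suite sits in the spark sbt module of the Delta build (an assumption, not stated in the commit), it could be run with something like build/sbt 'spark/testOnly *PreparedDeltaFileIndexRowCountSuite'.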
