@@ -105,6 +105,9 @@ public class SparkScan implements Scan, SupportsReportStatistics, SupportsRuntim
// Planned input files and stats
private List<PartitionedFile> partitionedFiles = new ArrayList<>();
private long totalBytes = 0L;
// Estimated size in bytes accounting for column projection, used for query optimizer cost
// estimation
private long estimatedSizeInBytes = 0L;
private volatile boolean planned = false;

public SparkScan(
@@ -199,7 +202,7 @@ public Statistics estimateStatistics() {
return new Statistics() {
@Override
public OptionalLong sizeInBytes() {
return OptionalLong.of(estimatedSizeInBytes);
}

@Override
@@ -210,6 +213,45 @@ public OptionalLong numRows() {
};
}
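// Illustrative, hypothetical caller (not shown in this class): the optimizer consumes the
// estimate through the DSv2 Statistics interface, e.g.
//   Statistics stats = scan.estimateStatistics();
//   long sizeForPlanning = stats.sizeInBytes().orElse(Long.MAX_VALUE);
// and feeds it into cost-based decisions such as the broadcast-join threshold, which is why a
// projection-aware size matters for narrow reads of wide tables.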

/**
* Computes the estimated size in bytes accounting for column projection.
*
* <p>This mirrors what {@code SizeInBytesOnlyStatsPlanVisitor.visitUnaryNode} (from Spark code)
* would compute for a {@code Project} over a {@code LogicalRelation}: {@code sizeInBytes =
* childSizeInBytes * outputRowSize / childRowSize}
*
* <p>Where:
*
* <ul>
* <li><b>childRowSize</b> = {@code ROW_OVERHEAD + dataSchema + partitionSchema} (equivalent to
* LogicalRelation output)
* <li><b>outputRowSize</b> = {@code ROW_OVERHEAD + readDataSchema + partitionSchema}
* (equivalent to Project output)
* </ul>
*
* <p>This provides consistent statistics with the v1 code path (LogicalRelation + visitUnaryNode
* from Spark code directory).
*
* @param totalBytes the total size in bytes of the planned files (raw physical size)
* @return the estimated size in bytes after accounting for column projection
*/
private long computeEstimatedSizeWithColumnProjection(long totalBytes) {
if (totalBytes <= 0) {
return totalBytes;
}

// Row overhead constant, matching EstimationUtils.getSizePerRow (from Spark)
final int ROW_OVERHEAD = 8;

final long fullSchemaRowSize =
ROW_OVERHEAD + dataSchema.defaultSize() + partitionSchema.defaultSize();
final long outputRowSize = ROW_OVERHEAD + readSchema().defaultSize();

long estimatedBytes = (totalBytes * outputRowSize) / fullSchemaRowSize;

return Math.max(1L, estimatedBytes);
}
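// Worked example (same numbers as the column-pruning tests): with dataSchema
// (name STRING, cnt INT) and partitionSchema (date STRING, city STRING, part INT):
//   fullSchemaRowSize = 8 + (20 + 4) + (20 + 20 + 4) = 76
// Pruning the read schema to 'name' plus the partition columns gives:
//   outputRowSize = 8 + 20 + (20 + 20 + 4) = 72
// so, for illustration, 1,000,000 raw bytes would be estimated as
// (1_000_000 * 72) / 76 = 947,368 bytes.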

/**
* Get the table path from the scan state.
*
Expand Down Expand Up @@ -253,6 +295,9 @@ private void planScanFiles() {
throw new RuntimeException(e);
}
}

// Pre-compute estimated size accounting for column projection
estimatedSizeInBytes = computeEstimatedSizeWithColumnProjection(totalBytes);
}
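// Computing the estimate here, rather than inside estimateStatistics(), means repeated
// statistics calls reuse one cached value; it is refreshed only when filter() below drops
// files at runtime.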

/** Ensure the scan is planned exactly once in a thread-safe manner. */
@@ -321,10 +366,12 @@ public void filter(org.apache.spark.sql.connector.expressions.filter.Predicate[]
}
}

// Update partitionedFiles, totalBytes, and estimatedSizeInBytes if any partition is filtered
// out
if (runtimeFilteredPartitionedFiles.size() < this.partitionedFiles.size()) {
this.partitionedFiles = runtimeFilteredPartitionedFiles;
this.totalBytes = this.partitionedFiles.stream().mapToLong(PartitionedFile::fileSize).sum();
this.estimatedSizeInBytes = computeEstimatedSizeWithColumnProjection(this.totalBytes);
}
}

@@ -23,6 +23,7 @@
import org.apache.spark.sql.delta.DeltaOptions;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@@ -228,11 +229,15 @@ protected static void checkSupportsRuntimeFilters(
// make a copy for comparison after DPP
beforeDppFiles = new ArrayList<>(beforeDppFiles);
long beforeDppTotalBytes = getTotalBytes(sparkScan);
long beforeDppEstimatedSize = getEstimatedSizeInBytes(sparkScan);
assert (beforeDppFiles.size() == 5);
// Without column pruning, estimatedSizeInBytes should equal totalBytes
assertEquals(beforeDppTotalBytes, beforeDppEstimatedSize);

sparkScan.filter(runtimeFilters);
List<PartitionedFile> afterDppFiles = getPartitionedFiles(sparkScan);
long afterDppTotalBytes = getTotalBytes(sparkScan);
long afterDppEstimatedSize = getEstimatedSizeInBytes(sparkScan);
assert (beforeDppFiles.containsAll(afterDppFiles));
assert (beforeDppTotalBytes >= afterDppTotalBytes);

@@ -251,6 +256,8 @@ protected static void checkSupportsRuntimeFilters(
assertEquals(expectedPartitionFilesAfterDpp.size(), afterDppFiles.size());
assertEquals(new HashSet<>(expectedPartitionFilesAfterDpp), new HashSet<>(afterDppFiles));
assertEquals(expectedTotalBytesAfterDpp, afterDppTotalBytes);
// Without column pruning, estimatedSizeInBytes should equal totalBytes after filtering too
assertEquals(afterDppTotalBytes, afterDppEstimatedSize);
}

private static List<PartitionedFile> getPartitionedFiles(SparkScan scan) throws Exception {
@@ -267,6 +274,13 @@ private static long getTotalBytes(SparkScan scan) throws Exception {
return (long) field.get(scan);
}

private static long getEstimatedSizeInBytes(SparkScan scan) throws Exception {
scan.estimateStatistics(); // triggers ensurePlanned()
Field field = SparkScan.class.getDeclaredField("estimatedSizeInBytes");
field.setAccessible(true);
return (long) field.get(scan);
}
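// These helpers read private fields via reflection because SparkScan does not expose
// partitionedFiles/totalBytes/estimatedSizeInBytes; calling estimateStatistics() first
// triggers the lazy one-time planning so the fields are populated before being read.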

// ================================================================================================
// Tests for streaming options validation
// ================================================================================================
@@ -390,4 +404,141 @@ public void testEqualsWithDifferentFilters() {
assertNotEquals(scan1, scan2);
assertNotEquals(scan1.hashCode(), scan2.hashCode());
}

// ================================================================================================
// Tests for estimated size with column projection
// ================================================================================================

@Test
public void testEstimatedSizeMatchesStatistics() throws Exception {
// Test that estimateStatistics().sizeInBytes() returns the estimatedSizeInBytes field
SparkScanBuilder builder = (SparkScanBuilder) table.newScanBuilder(options);
SparkScan scan = (SparkScan) builder.build();

long estimatedSizeFromStats = scan.estimateStatistics().sizeInBytes().getAsLong();
long estimatedSizeFromField = getEstimatedSizeInBytes(scan);

assertEquals(estimatedSizeFromField, estimatedSizeFromStats);
}

@Test
public void testEstimatedSizeWithColumnPruning() throws Exception {
// Test that with column pruning, estimatedSizeInBytes is computed correctly
// Table schema: (part INT, date STRING, city STRING, name STRING, cnt INT)
// Partition columns: (date STRING, city STRING, part INT)
// Data columns: (name STRING, cnt INT)
//
// Formula: estimatedBytes = (totalBytes * outputRowSize) / fullSchemaRowSize
// Where:
// ROW_OVERHEAD = 8
// dataSchema.defaultSize() = 20 (STRING) + 4 (INT) = 24
// partitionSchema.defaultSize() = 20 + 20 + 4 = 44
// fullSchemaRowSize = 8 + 24 + 44 = 76
//
// With pruning to only 'name' column:
// readDataSchema.defaultSize() = 20 (STRING only)
// readSchema().defaultSize() = 20 + 44 = 64
// outputRowSize = 8 + 64 = 72
// estimatedBytes = (totalBytes * 72) / 76
SparkScanBuilder builder = (SparkScanBuilder) table.newScanBuilder(options);

// Prune columns to only include 'name' (a data column) and partition columns
// This simulates: SELECT name, date, city, part FROM table
StructType prunedSchema =
new StructType()
.add("name", DataTypes.StringType) // only one data column
.add("date", DataTypes.StringType) // partition columns are always included
.add("city", DataTypes.StringType)
.add("part", DataTypes.IntegerType);
builder.pruneColumns(prunedSchema);

SparkScan scan = (SparkScan) builder.build();

long totalBytes = getTotalBytes(scan);
long estimatedSize = getEstimatedSizeInBytes(scan);

// Calculate expected estimated size using the formula
// outputRowSize = 8 + 64 = 72, fullSchemaRowSize = 8 + 24 + 44 = 76
// Note: We don't use Math.max(1, ...) here because totalBytes is guaranteed to be large enough
// (parquet files with actual data) that the division result won't be zero.
long expectedEstimatedSize = (totalBytes * 72) / 76;

assertTrue(totalBytes > 0, "totalBytes should be positive");
assertEquals(
expectedEstimatedSize,
estimatedSize,
String.format(
"estimatedSize should be (totalBytes * 72) / 76 = (%d * 72) / 76 = %d",
totalBytes, expectedEstimatedSize));
}
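// The per-type sizes assumed in the comments above are Spark's defaults:
// DataTypes.StringType.defaultSize() == 20 and DataTypes.IntegerType.defaultSize() == 4,
// which is where 24 (data columns), 44 (partition columns), and the pruned 20 come from.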

@Test
public void testEstimatedSizeWithColumnPruningAndFiltering() throws Exception {
// Test that column pruning and runtime filtering work together correctly
// Using same formula as testEstimatedSizeWithColumnPruning:
// estimatedBytes = (totalBytes * 72) / 76
SparkScanBuilder builder = (SparkScanBuilder) table.newScanBuilder(options);

// Prune columns to only include 'name' column
StructType prunedSchema =
new StructType()
.add("name", DataTypes.StringType)
.add("date", DataTypes.StringType)
.add("city", DataTypes.StringType)
.add("part", DataTypes.IntegerType);
builder.pruneColumns(prunedSchema);

SparkScan scan = (SparkScan) builder.build();

// Get initial stats with column pruning
long initialTotalBytes = getTotalBytes(scan);
long initialEstimatedSize = getEstimatedSizeInBytes(scan);

// Verify initial estimated size matches formula
// Note: No Math.max(1, ...) needed - totalBytes from parquet files is large enough
long expectedInitialEstimated = (initialTotalBytes * 72) / 76;
assertEquals(
expectedInitialEstimated,
initialEstimatedSize,
"Initial estimatedSize should match formula");

// Apply a runtime filter
scan.filter(new Predicate[] {cityPredicate}); // city=hz

// After filtering, verify both values are updated correctly
long afterFilterTotalBytes = getTotalBytes(scan);
long afterFilterEstimatedSize = getEstimatedSizeInBytes(scan);

// Verify estimated size matches formula with new totalBytes
long expectedAfterFilterEstimated = (afterFilterTotalBytes * 72) / 76;
assertEquals(
expectedAfterFilterEstimated,
afterFilterEstimatedSize,
"After filter, estimatedSize should match formula with new totalBytes");

// Verify both values were reduced
assertTrue(afterFilterTotalBytes < initialTotalBytes, "totalBytes should be reduced");
assertTrue(afterFilterEstimatedSize < initialEstimatedSize, "estimatedSize should be reduced");
}

@Test
public void testEstimatedSizeZeroAfterFilteringOutAllFiles() throws Exception {
// Test that filtering out all files results in zero for both sizes
SparkScanBuilder builder = (SparkScanBuilder) table.newScanBuilder(options);
SparkScan scan = (SparkScan) builder.build();

// Apply filter that matches nothing
scan.filter(new Predicate[] {negativeCityPredicate}); // city=zz doesn't exist

long afterFilterTotalBytes = getTotalBytes(scan);
long afterFilterEstimatedSize = getEstimatedSizeInBytes(scan);

assertEquals(0, afterFilterTotalBytes, "totalBytes should be 0 after filtering out all files");
assertEquals(
0, afterFilterEstimatedSize, "estimatedSize should be 0 after filtering out all files");
assertEquals(
0,
scan.estimateStatistics().sizeInBytes().getAsLong(),
"Statistics sizeInBytes should be 0 after filtering out all files");
}
}