diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml
index 567b9ca4f3..703284bead 100644
--- a/kernel/Cargo.toml
+++ b/kernel/Cargo.toml
@@ -65,6 +65,8 @@ tokio = { version = "1.47", optional = true, features = ["rt-multi-thread"] }
 # both arrow versions below are optional and require object_store
 object_store = { version = "0.12.3", optional = true, features = ["aws", "azure", "gcp", "http"] }
 comfy-table = { version = "7.1", optional = true }
+# used for Float trait in stats computation
+num-traits = { version = "0.2", optional = true }
 
 # arrow 56
 [dependencies.arrow_56]
@@ -118,6 +120,7 @@ default-engine-base = [
   "arrow-expression",
   "futures",
   "need-arrow",
+  "num-traits",
   "tokio",
 ]
 # the default-engine-native-tls use the reqwest crate with default features which uses native-tls. if you want
diff --git a/kernel/examples/write-table/src/main.rs b/kernel/examples/write-table/src/main.rs
index f083611c03..7649bba872 100644
--- a/kernel/examples/write-table/src/main.rs
+++ b/kernel/examples/write-table/src/main.rs
@@ -94,9 +94,9 @@ async fn try_main() -> DeltaResult<()> {
         .with_data_change(true);
 
     // Write the data using the engine
-    let write_context = Arc::new(txn.get_write_context());
+    let write_context = txn.get_write_context()?;
     let file_metadata = engine
-        .write_parquet(&sample_data, write_context.as_ref(), HashMap::new())
+        .write_parquet(&sample_data, &write_context, HashMap::new())
         .await?;
 
     // Add the file metadata to the transaction
diff --git a/kernel/src/engine/default/mod.rs b/kernel/src/engine/default/mod.rs
index 430f10a30e..d0afa86a7b 100644
--- a/kernel/src/engine/default/mod.rs
+++ b/kernel/src/engine/default/mod.rs
@@ -33,6 +33,7 @@ pub mod file_stream;
 pub mod filesystem;
 pub mod json;
 pub mod parquet;
+pub mod stats;
 pub mod storage;
 
 /// Converts a Stream-producing future to a synchronous iterator.
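For context before the `DefaultEngine` changes below, here is a condensed sketch of the updated write path. It is not part of the patch: the wrapper function, its signature, and the import paths are assumptions for illustration, and `txn`/`sample_data` stand in for the setup done in the full write-table example. The kernel calls themselves mirror the example diff above and the `DefaultEngine::write_parquet` change that follows.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Import paths assumed from the current delta_kernel layout; adjust as needed.
use delta_kernel::engine::default::{executor::TaskExecutor, DefaultEngine};
use delta_kernel::transaction::Transaction;
use delta_kernel::{DeltaResult, EngineData};

/// Illustrative wrapper only: `txn` and `sample_data` come from whatever
/// setup the caller already has (see the full write-table example).
async fn write_with_stats<E: TaskExecutor>(
    engine: Arc<DefaultEngine<E>>,
    mut txn: Transaction,
    sample_data: Box<dyn EngineData>,
) -> DeltaResult<()> {
    // Mirrors the updated example: the write context now carries the stats
    // column list derived from the table properties.
    let write_context = txn.get_write_context()?;

    // DefaultEngine::write_parquet forwards write_context.stats_columns() to
    // the parquet handler, which computes numRecords / nullCount / min / max
    // while writing the file.
    let file_metadata = engine
        .write_parquet(&sample_data, &write_context, HashMap::new())
        .await?;

    // The returned metadata already embeds the per-file stats struct, so it
    // can be handed straight to the transaction and committed as usual.
    txn.add_files(file_metadata);
    txn.commit(engine.as_ref())?;
    Ok(())
}
```

The behavioral change this illustrates: the default parquet handler now produces full per-file statistics at write time instead of only `numRecords`, so the add action's stats are populated without extra work from the engine.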
@@ -216,7 +217,12 @@ impl DefaultEngine { )?; let physical_data = logical_to_physical_expr.evaluate(data)?; self.parquet - .write_parquet_file(write_context.target_dir(), physical_data, partition_values) + .write_parquet_file( + write_context.target_dir(), + physical_data, + partition_values, + write_context.stats_columns(), + ) .await } } diff --git a/kernel/src/engine/default/parquet.rs b/kernel/src/engine/default/parquet.rs index 7ed225ddb3..f3061fc4df 100644 --- a/kernel/src/engine/default/parquet.rs +++ b/kernel/src/engine/default/parquet.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use delta_kernel_derive::internal_api; use crate::arrow::array::builder::{MapBuilder, MapFieldNames, StringBuilder}; -use crate::arrow::array::{Int64Array, RecordBatch, StringArray, StructArray}; +use crate::arrow::array::{Array, Int64Array, RecordBatch, StringArray, StructArray}; use crate::arrow::datatypes::{DataType, Field}; use crate::parquet::arrow::arrow_reader::{ ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, @@ -23,9 +23,10 @@ use object_store::{DynObjectStore, ObjectStore}; use uuid::Uuid; use super::file_stream::{FileOpenFuture, FileOpener, FileStream}; +use super::stats::StatisticsCollector; use super::UrlExt; use crate::engine::arrow_conversion::{TryFromArrow as _, TryIntoArrow as _}; -use crate::engine::arrow_data::ArrowEngineData; +use crate::engine::arrow_data::{extract_record_batch, ArrowEngineData}; use crate::engine::arrow_utils::{ fixup_parquet_read, generate_mask, get_requested_indices, ordering_needs_row_indexes, RowIndexBuilder, @@ -54,6 +55,8 @@ pub struct DataFileMetadata { file_meta: FileMeta, // NB: We use usize instead of u64 since arrow uses usize for record batch sizes num_records: usize, + /// Collected statistics for this file (optional). + stats: Option, } impl DataFileMetadata { @@ -61,9 +64,16 @@ impl DataFileMetadata { Self { file_meta, num_records, + stats: None, } } + /// Set the collected statistics for this file. + pub fn with_stats(mut self, stats: StructArray) -> Self { + self.stats = Some(stats); + self + } + /// Convert DataFileMetadata into a record batch which matches the schema returned by /// [`add_files_schema`]. /// @@ -81,6 +91,7 @@ impl DataFileMetadata { size, }, num_records, + stats, } = self; // create the record batch of the write metadata let path = Arc::new(StringArray::from(vec![location.to_string()])); @@ -104,20 +115,53 @@ impl DataFileMetadata { .map_err(|_| Error::generic("Failed to convert parquet metadata 'size' to i64"))?; let size = Arc::new(Int64Array::from(vec![size])); let modification_time = Arc::new(Int64Array::from(vec![*last_modified])); - let stats = Arc::new(StructArray::try_new_with_length( - vec![Field::new("numRecords", DataType::Int64, true)].into(), - vec![Arc::new(Int64Array::from(vec![*num_records as i64]))], - None, - 1, - )?); - Ok(Box::new(ArrowEngineData::new(RecordBatch::try_new( - Arc::new( - crate::transaction::BASE_ADD_FILES_SCHEMA - .as_ref() - .try_into_arrow()?, + // Use full stats if available, otherwise just numRecords + let stats_array: Arc = if let Some(full_stats) = stats { + Arc::new(full_stats.clone()) + } else { + Arc::new(StructArray::try_new_with_length( + vec![Field::new("numRecords", DataType::Int64, true)].into(), + vec![Arc::new(Int64Array::from(vec![*num_records as i64]))], + None, + 1, + )?) 
+ }; + + // Build schema dynamically based on stats + let stats_field = Field::new("stats", stats_array.data_type().clone(), true); + let schema = crate::arrow::datatypes::Schema::new(vec![ + Field::new("path", crate::arrow::datatypes::DataType::Utf8, false), + Field::new( + "partitionValues", + crate::arrow::datatypes::DataType::Map( + Arc::new(Field::new( + "key_value", + crate::arrow::datatypes::DataType::Struct( + vec![ + Field::new("key", crate::arrow::datatypes::DataType::Utf8, false), + Field::new("value", crate::arrow::datatypes::DataType::Utf8, true), + ] + .into(), + ), + false, + )), + false, + ), + false, + ), + Field::new("size", crate::arrow::datatypes::DataType::Int64, false), + Field::new( + "modificationTime", + crate::arrow::datatypes::DataType::Int64, + false, ), - vec![path, partitions, size, modification_time, stats], + stats_field, + ]); + + Ok(Box::new(ArrowEngineData::new(RecordBatch::try_new( + Arc::new(schema), + vec![path, partitions, size, modification_time, stats_array], )?))) } } @@ -201,8 +245,22 @@ impl DefaultParquetHandler { path: &url::Url, data: Box, partition_values: HashMap, + stats_columns: &[String], ) -> DeltaResult> { - let parquet_metadata = self.write_parquet(path, data).await?; + // Collect statistics from the data during write + let record_batch = extract_record_batch(data.as_ref())?; + + // Initialize stats collector and update with this batch + let mut stats_collector = StatisticsCollector::new(record_batch.schema(), stats_columns); + stats_collector.update(record_batch, None)?; // No mask for new file writes + let stats = stats_collector.finalize()?; + + // Write the parquet file + let mut parquet_metadata = self.write_parquet(path, data).await?; + + // Attach the collected statistics + parquet_metadata = parquet_metadata.with_stats(stats); + parquet_metadata.as_record_batch(&partition_values) } } @@ -294,6 +352,7 @@ impl ParquetHandler for DefaultParquetHandler { /// - `location` - The full URL path where the Parquet file should be written /// (e.g., `s3://bucket/path/file.parquet`, `file:///path/to/file.parquet`). /// - `data` - An iterator of engine data to be written to the Parquet file. + /// - `stats_columns` - Column names for which statistics should be collected. /// /// # Returns /// @@ -302,6 +361,7 @@ impl ParquetHandler for DefaultParquetHandler { &self, location: url::Url, mut data: Box>> + Send>, + _stats_columns: &[String], ) -> DeltaResult<()> { let store = self.store.clone(); @@ -682,6 +742,7 @@ mod tests { size, }, num_records, + .. 
} = write_metadata; let expected_location = Url::parse("memory:///data/").unwrap(); @@ -776,7 +837,7 @@ mod tests { // Test writing through the trait method let file_url = Url::parse("memory:///test/data.parquet").unwrap(); parquet_handler - .write_parquet_file(file_url.clone(), data_iter) + .write_parquet_file(file_url.clone(), data_iter, &[]) .unwrap(); // Verify we can read the file back @@ -964,7 +1025,7 @@ mod tests { // Write the data let file_url = Url::parse("memory:///roundtrip/test.parquet").unwrap(); parquet_handler - .write_parquet_file(file_url.clone(), data_iter) + .write_parquet_file(file_url.clone(), data_iter, &[]) .unwrap(); // Read it back @@ -1152,7 +1213,7 @@ mod tests { // Write the first file parquet_handler - .write_parquet_file(file_url.clone(), data_iter1) + .write_parquet_file(file_url.clone(), data_iter1, &[]) .unwrap(); // Create second data set with different data @@ -1168,7 +1229,7 @@ mod tests { // Overwrite with second file (overwrite=true) parquet_handler - .write_parquet_file(file_url.clone(), data_iter2) + .write_parquet_file(file_url.clone(), data_iter2, &[]) .unwrap(); // Read back and verify it contains the second data set @@ -1231,7 +1292,7 @@ mod tests { // Write the first file parquet_handler - .write_parquet_file(file_url.clone(), data_iter1) + .write_parquet_file(file_url.clone(), data_iter1, &[]) .unwrap(); // Create second data set @@ -1247,7 +1308,7 @@ mod tests { // Write again - should overwrite successfully (new behavior always overwrites) parquet_handler - .write_parquet_file(file_url.clone(), data_iter2) + .write_parquet_file(file_url.clone(), data_iter2, &[]) .unwrap(); // Verify the file was overwritten with the new data diff --git a/kernel/src/engine/default/stats.rs b/kernel/src/engine/default/stats.rs new file mode 100644 index 0000000000..4110ad26a1 --- /dev/null +++ b/kernel/src/engine/default/stats.rs @@ -0,0 +1,1299 @@ +//! Statistics collection for Delta Lake file writes. +//! +//! This module provides `StatisticsCollector` which accumulates statistics +//! across multiple Arrow RecordBatches during file writes. + +use std::collections::HashSet; +use std::sync::Arc; + +use crate::arrow::array::{ + new_null_array, Array, ArrayRef, BooleanArray, Int64Array, LargeStringArray, PrimitiveArray, + RecordBatch, StringArray, StringViewArray, StructArray, +}; +use crate::arrow::buffer::NullBuffer; +use crate::arrow::datatypes::{ + ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal128Type, Field, Fields, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use crate::{DeltaResult, Error}; + +// ============================================================================ +// Trait-based min/max aggregation +// ============================================================================ + +/// Trait for min/max aggregation operations. +/// Implementations define whether we're computing min or max. +trait MinMaxAggregator { + /// Aggregate two values, returning the min or max. + fn aggregate(a: T, b: T) -> T; + + /// Aggregate two floats with proper NaN handling. + fn aggregate_float(a: T, b: T) -> T; +} + +/// Min aggregation - returns the smaller value. 
+struct MinAgg; + +impl MinMaxAggregator for MinAgg { + fn aggregate(a: T, b: T) -> T { + if a <= b { + a + } else { + b + } + } + + fn aggregate_float(a: T, b: T) -> T { + a.min(b) + } +} + +/// Max aggregation - returns the larger value. +struct MaxAgg; + +impl MinMaxAggregator for MaxAgg { + fn aggregate(a: T, b: T) -> T { + if a >= b { + a + } else { + b + } + } + + fn aggregate_float(a: T, b: T) -> T { + a.max(b) + } +} + +/// Compute aggregation for primitive array types, optionally filtered by mask. +fn aggregate_primitive( + array: &PrimitiveArray, + mask: Option<&NullBuffer>, +) -> Option +where + T: ArrowPrimitiveType, + T::Native: PartialOrd, + Agg: MinMaxAggregator, +{ + if let Some(m) = mask { + array + .iter() + .enumerate() + .filter_map(|(i, opt_val)| if m.is_valid(i) { opt_val } else { None }) + .reduce(|acc, val| Agg::aggregate(acc, val)) + } else { + array + .iter() + .flatten() + .reduce(|acc, val| Agg::aggregate(acc, val)) + } +} + +/// Helper to downcast an array reference to a specific type. +fn downcast(column: &ArrayRef) -> DeltaResult<&T> { + column.as_any().downcast_ref::().ok_or_else(|| { + Error::generic(format!( + "Failed to downcast from {} to {}", + std::any::type_name_of_val(column.as_ref()), + std::any::type_name::(), + )) + }) +} + +/// Compute aggregation for a column, returning a single-element array. +/// If mask is provided, only masked-in rows are considered. +fn compute_agg( + column: &ArrayRef, + mask: Option<&NullBuffer>, +) -> DeltaResult { + match column.data_type() { + DataType::Int8 => agg_primitive::(column, mask), + DataType::Int16 => agg_primitive::(column, mask), + DataType::Int32 => agg_primitive::(column, mask), + DataType::Int64 => agg_primitive::(column, mask), + DataType::UInt8 => agg_primitive::(column, mask), + DataType::UInt16 => agg_primitive::(column, mask), + DataType::UInt32 => agg_primitive::(column, mask), + DataType::UInt64 => agg_primitive::(column, mask), + DataType::Float32 => agg_float::(column, mask), + DataType::Float64 => agg_float::(column, mask), + DataType::Date32 => agg_primitive::(column, mask), + DataType::Date64 => agg_primitive::(column, mask), + DataType::Timestamp(TimeUnit::Second, tz) => { + agg_timestamp::(column, mask, tz.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + agg_timestamp::(column, mask, tz.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + agg_timestamp::(column, mask, tz.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + agg_timestamp::(column, mask, tz.clone()) + } + DataType::Decimal128(p, s) => agg_decimal128::(column, mask, *p, *s), + DataType::Utf8 => agg_string::(column, mask), + DataType::LargeUtf8 => agg_large_string::(column, mask), + DataType::Utf8View => agg_string_view::(column, mask), + // Types without meaningful min/max stats + _ => Ok(new_null_array(column.data_type(), 1)), + } +} + +/// Aggregate primitive types (integers, dates). +fn agg_primitive(column: &ArrayRef, mask: Option<&NullBuffer>) -> DeltaResult +where + T: ArrowPrimitiveType, + T::Native: PartialOrd, + Agg: MinMaxAggregator, + PrimitiveArray: From>>, +{ + let array = downcast::>(column)?; + let result = aggregate_primitive::(array, mask); + Ok(Arc::new(PrimitiveArray::::from(vec![result]))) +} + +/// Aggregate float types with NaN handling. 
+fn agg_float(column: &ArrayRef, mask: Option<&NullBuffer>) -> DeltaResult +where + T: ArrowPrimitiveType, + T::Native: num_traits::Float, + Agg: MinMaxAggregator, + PrimitiveArray: From>>, +{ + let array = downcast::>(column)?; + let result = if let Some(m) = mask { + array + .iter() + .enumerate() + .filter_map(|(i, opt_val)| if m.is_valid(i) { opt_val } else { None }) + .reduce(|acc, val| Agg::aggregate_float(acc, val)) + } else { + array + .iter() + .flatten() + .reduce(|acc, val| Agg::aggregate_float(acc, val)) + }; + Ok(Arc::new(PrimitiveArray::::from(vec![result]))) +} + +/// Aggregate timestamp types, preserving timezone. +fn agg_timestamp( + column: &ArrayRef, + mask: Option<&NullBuffer>, + tz: Option>, +) -> DeltaResult +where + T: crate::arrow::datatypes::ArrowTimestampType, + Agg: MinMaxAggregator, + PrimitiveArray: From>>, +{ + let array = downcast::>(column)?; + let result = aggregate_primitive::(array, mask); + Ok(Arc::new( + PrimitiveArray::::from(vec![result]).with_timezone_opt(tz), + )) +} + +/// Aggregate Decimal128 types, preserving precision and scale. +fn agg_decimal128( + column: &ArrayRef, + mask: Option<&NullBuffer>, + precision: u8, + scale: i8, +) -> DeltaResult { + use crate::arrow::array::Decimal128Array; + let array = downcast::(column)?; + let result = aggregate_primitive::(array, mask); + let arr = Decimal128Array::from(vec![result]) + .with_precision_and_scale(precision, scale) + .map_err(|e| Error::generic(format!("Invalid decimal precision/scale: {e}")))?; + Ok(Arc::new(arr)) +} + +/// Aggregate Utf8 strings with truncation. +fn agg_string( + column: &ArrayRef, + mask: Option<&NullBuffer>, +) -> DeltaResult { + let array = downcast::(column)?; + let result = if let Some(m) = mask { + array + .iter() + .enumerate() + .filter_map(|(i, opt_val)| if m.is_valid(i) { opt_val } else { None }) + .reduce(|acc, val| Agg::aggregate(acc, val)) + .map(truncate_string) + } else { + array + .iter() + .flatten() + .reduce(|acc, val| Agg::aggregate(acc, val)) + .map(truncate_string) + }; + Ok(Arc::new(StringArray::from(vec![result]))) +} + +/// Aggregate LargeUtf8 strings with truncation. +fn agg_large_string( + column: &ArrayRef, + mask: Option<&NullBuffer>, +) -> DeltaResult { + let array = downcast::(column)?; + let result = if let Some(m) = mask { + array + .iter() + .enumerate() + .filter_map(|(i, opt_val)| if m.is_valid(i) { opt_val } else { None }) + .reduce(|acc, val| Agg::aggregate(acc, val)) + .map(truncate_string) + } else { + array + .iter() + .flatten() + .reduce(|acc, val| Agg::aggregate(acc, val)) + .map(truncate_string) + }; + Ok(Arc::new(LargeStringArray::from(vec![result]))) +} + +/// Aggregate StringView with truncation. +fn agg_string_view( + column: &ArrayRef, + mask: Option<&NullBuffer>, +) -> DeltaResult { + let array = downcast::(column)?; + let result: Option = if let Some(m) = mask { + array + .iter() + .enumerate() + .filter_map(|(i, opt_val)| if m.is_valid(i) { opt_val } else { None }) + .map(|s| s.to_string()) + .reduce(|acc, val| { + if Agg::aggregate(acc.as_str(), val.as_str()) == acc.as_str() { + acc + } else { + val + } + }) + .map(|s| truncate_string(&s)) + } else { + array + .iter() + .flatten() + .map(|s| s.to_string()) + .reduce(|acc, val| { + if Agg::aggregate(acc.as_str(), val.as_str()) == acc.as_str() { + acc + } else { + val + } + }) + .map(|s| truncate_string(&s)) + }; + Ok(Arc::new(StringViewArray::from(vec![result]))) +} + +/// Truncate a string to the max prefix length for stats. 
+fn truncate_string(s: &str) -> String { + s.chars().take(STRING_PREFIX_LENGTH).collect() +} + +/// Compare two single-element arrays and select min or max. +fn compare_and_select(a: &ArrayRef, b: &ArrayRef) -> DeltaResult { + match a.data_type() { + DataType::Int8 => compare_primitive::(a, b), + DataType::Int16 => compare_primitive::(a, b), + DataType::Int32 => compare_primitive::(a, b), + DataType::Int64 => compare_primitive::(a, b), + DataType::UInt8 => compare_primitive::(a, b), + DataType::UInt16 => compare_primitive::(a, b), + DataType::UInt32 => compare_primitive::(a, b), + DataType::UInt64 => compare_primitive::(a, b), + DataType::Date32 => compare_primitive::(a, b), + DataType::Date64 => compare_primitive::(a, b), + DataType::Float32 => compare_float::(a, b), + DataType::Float64 => compare_float::(a, b), + DataType::Utf8 => compare_string::(a, b), + _ => Ok(a.clone()), // Fallback + } +} + +fn compare_primitive(a: &ArrayRef, b: &ArrayRef) -> DeltaResult +where + T: ArrowPrimitiveType, + T::Native: PartialOrd, + Agg: MinMaxAggregator, + PrimitiveArray: From>, +{ + let a_arr = downcast::>(a)?; + let b_arr = downcast::>(b)?; + let result = Agg::aggregate(a_arr.value(0), b_arr.value(0)); + Ok(Arc::new(PrimitiveArray::::from(vec![result]))) +} + +fn compare_float(a: &ArrayRef, b: &ArrayRef) -> DeltaResult +where + T: ArrowPrimitiveType, + T::Native: num_traits::Float, + Agg: MinMaxAggregator, + PrimitiveArray: From>, +{ + let a_arr = downcast::>(a)?; + let b_arr = downcast::>(b)?; + let result = Agg::aggregate_float(a_arr.value(0), b_arr.value(0)); + Ok(Arc::new(PrimitiveArray::::from(vec![result]))) +} + +fn compare_string(a: &ArrayRef, b: &ArrayRef) -> DeltaResult { + let a_arr = downcast::(a)?; + let b_arr = downcast::(b)?; + let result = Agg::aggregate(a_arr.value(0), b_arr.value(0)); + Ok(Arc::new(StringArray::from(vec![result]))) +} + +/// Maximum prefix length for string statistics (Delta protocol requirement). +const STRING_PREFIX_LENGTH: usize = 32; + +/// Collects statistics from RecordBatches for Delta Lake file statistics. +/// Supports streaming accumulation across multiple batches. +#[allow(dead_code)] +pub(crate) struct StatisticsCollector { + /// Total number of records across all batches. + num_records: i64, + /// Column names from the data schema. + column_names: Vec, + /// Column names that should have stats collected. + stats_columns: HashSet, + /// Null counts per column. + null_counts: Vec, + /// Min values per column (single-element arrays). + min_values: Vec, + /// Max values per column (single-element arrays). + max_values: Vec, +} + +#[allow(dead_code)] +impl StatisticsCollector { + /// Create a new statistics collector. 
+ /// + /// # Arguments + /// * `data_schema` - The Arrow schema of the data being written + /// * `stats_columns` - Column names that should have statistics collected + pub(crate) fn new( + data_schema: Arc, + stats_columns: &[String], + ) -> DeltaResult { + let stats_set: HashSet = stats_columns.iter().cloned().collect(); + + let mut column_names = Vec::with_capacity(data_schema.fields().len()); + let mut null_counts = Vec::with_capacity(data_schema.fields().len()); + let mut min_values = Vec::with_capacity(data_schema.fields().len()); + let mut max_values = Vec::with_capacity(data_schema.fields().len()); + + for field in data_schema.fields() { + column_names.push(field.name().clone()); + null_counts.push(Self::create_zero_null_count(field.data_type())?); + let null_array = Self::create_null_array(field.data_type()); + min_values.push(null_array.clone()); + max_values.push(null_array); + } + + Ok(Self { + num_records: 0, + column_names, + stats_columns: stats_set, + null_counts, + min_values, + max_values, + }) + } + + /// Check if a column should have stats collected. + fn should_collect_stats(&self, column_name: &str) -> bool { + self.stats_columns.contains(column_name) + } + + /// Create a zero-initialized null count structure for the given data type. + fn create_zero_null_count(data_type: &DataType) -> DeltaResult { + match data_type { + DataType::Struct(fields) => { + let children: Vec = fields + .iter() + .map(|f| Self::create_zero_null_count(f.data_type())) + .collect::>()?; + let null_count_fields: Fields = fields + .iter() + .map(|f| { + let child_type = Self::null_count_data_type(f.data_type()); + Field::new(f.name(), child_type, true) + }) + .collect(); + Ok(Arc::new( + StructArray::try_new(null_count_fields, children, None).map_err(|e| { + Error::generic(format!("Failed to create null count struct: {e}")) + })?, + )) + } + _ => Ok(Arc::new(Int64Array::from(vec![0i64]))), + } + } + + /// Get the data type for null counts of a given data type. + fn null_count_data_type(data_type: &DataType) -> DataType { + match data_type { + DataType::Struct(fields) => { + let null_count_fields: Vec = fields + .iter() + .map(|f| Field::new(f.name(), Self::null_count_data_type(f.data_type()), true)) + .collect(); + DataType::Struct(null_count_fields.into()) + } + _ => DataType::Int64, + } + } + + /// Compute null counts for a column, respecting the optional mask. + fn compute_null_counts(column: &ArrayRef, mask: Option<&NullBuffer>) -> DeltaResult { + match column.data_type() { + DataType::Struct(fields) => { + let struct_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::generic("Expected StructArray for struct column"))?; + let children: Vec = (0..fields.len()) + .map(|i| Self::compute_null_counts(struct_array.column(i), mask)) + .collect::>>()?; + let null_count_fields: Fields = fields + .iter() + .map(|f| Field::new(f.name(), Self::null_count_data_type(f.data_type()), true)) + .collect(); + Ok(Arc::new( + StructArray::try_new(null_count_fields, children, None) + .map_err(|e| Error::generic(format!("null count struct: {e}")))?, + )) + } + _ => { + let null_count = match mask { + Some(m) => { + // Count nulls only for masked-in rows + (0..column.len()) + .filter(|&i| m.is_valid(i) && column.is_null(i)) + .count() as i64 + } + None => column.null_count() as i64, + }; + Ok(Arc::new(Int64Array::from(vec![null_count]))) + } + } + } + + /// Merge two null count structures by adding them together. 
+ fn merge_null_counts(existing: &ArrayRef, new: &ArrayRef) -> DeltaResult { + match existing.data_type() { + DataType::Struct(fields) => { + let existing_struct = + existing + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::generic("Expected StructArray for existing null count") + })?; + let new_struct = new + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::generic("Expected StructArray for new null count"))?; + + let children: Vec = (0..fields.len()) + .map(|i| { + Self::merge_null_counts(existing_struct.column(i), new_struct.column(i)) + }) + .collect::>()?; + + let null_count_fields: Fields = fields + .iter() + .map(|f| Field::new(f.name(), Self::null_count_data_type(f.data_type()), true)) + .collect(); + Ok(Arc::new( + StructArray::try_new(null_count_fields, children, None).map_err(|e| { + Error::generic(format!("Failed to merge null count struct: {e}")) + })?, + )) + } + _ => { + let existing_val = existing + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::generic("Expected Int64Array for existing null count"))? + .value(0); + let new_val = new + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::generic("Expected Int64Array for new null count"))? + .value(0); + Ok(Arc::new(Int64Array::from(vec![existing_val + new_val]))) + } + } + } + + /// Create a null array of the given type with length 1. + fn create_null_array(data_type: &DataType) -> ArrayRef { + new_null_array(data_type, 1) + } + + /// Compute min/max for a column, returning (min, max) as single-element arrays. + /// If mask is provided, only masked-in rows are considered. + fn compute_min_max( + column: &ArrayRef, + mask: Option<&NullBuffer>, + ) -> DeltaResult<(ArrayRef, ArrayRef)> { + let min_val = compute_agg::(column, mask)?; + let max_val = compute_agg::(column, mask)?; + Ok((min_val, max_val)) + } + + /// Merge min values, keeping the smaller one. + fn merge_min(existing: &ArrayRef, new: &ArrayRef) -> DeltaResult { + if existing.is_null(0) { + return Ok(new.clone()); + } + if new.is_null(0) { + return Ok(existing.clone()); + } + let new_min = compute_agg::(new, None)?; + let existing_min = compute_agg::(existing, None)?; + compare_and_select::(&existing_min, &new_min) + } + + /// Merge max values, keeping the larger one. + fn merge_max(existing: &ArrayRef, new: &ArrayRef) -> DeltaResult { + if existing.is_null(0) { + return Ok(new.clone()); + } + if new.is_null(0) { + return Ok(existing.clone()); + } + let new_max = compute_agg::(new, None)?; + let existing_max = compute_agg::(existing, None)?; + compare_and_select::(&existing_max, &new_max) + } + + /// Update statistics with data from a RecordBatch. + /// + /// # Arguments + /// * `batch` - The RecordBatch to accumulate statistics from + /// * `mask` - Optional mask indicating which rows to include (true = include) + /// Used for deletion vector support where masked-out rows should not + /// contribute to statistics. 
+ pub(crate) fn update( + &mut self, + batch: &RecordBatch, + mask: Option<&NullBuffer>, + ) -> DeltaResult<()> { + // Count rows, respecting mask if present + let row_count = match mask { + Some(m) => m.iter().filter(|&valid| valid).count() as i64, + None => batch.num_rows() as i64, + }; + self.num_records += row_count; + + for (col_idx, column) in batch.columns().iter().enumerate() { + let col_name = &self.column_names[col_idx]; + if self.should_collect_stats(col_name) { + // Update null counts + let batch_null_counts = Self::compute_null_counts(column, mask)?; + self.null_counts[col_idx] = + Self::merge_null_counts(&self.null_counts[col_idx], &batch_null_counts)?; + + // Update min/max + let (batch_min, batch_max) = Self::compute_min_max(column, mask)?; + self.min_values[col_idx] = Self::merge_min(&self.min_values[col_idx], &batch_min)?; + self.max_values[col_idx] = Self::merge_max(&self.max_values[col_idx], &batch_max)?; + } + } + + Ok(()) + } + + /// Finalize and return the collected statistics as a StructArray. + pub(crate) fn finalize(&self) -> DeltaResult { + let mut fields = Vec::new(); + let mut arrays: Vec> = Vec::new(); + + // numRecords + fields.push(Field::new("numRecords", DataType::Int64, true)); + arrays.push(Arc::new(Int64Array::from(vec![self.num_records]))); + + // nullCount - nested struct matching data schema + let null_count_fields: Vec = self + .column_names + .iter() + .enumerate() + .filter(|(_, name)| self.should_collect_stats(name)) + .map(|(idx, name)| Field::new(name, self.null_counts[idx].data_type().clone(), true)) + .collect(); + + if !null_count_fields.is_empty() { + let null_count_arrays: Vec = self + .column_names + .iter() + .enumerate() + .filter(|(_, name)| self.should_collect_stats(name)) + .map(|(idx, _)| self.null_counts[idx].clone()) + .collect(); + + let null_count_struct = + StructArray::try_new(null_count_fields.into(), null_count_arrays, None) + .map_err(|e| Error::generic(format!("Failed to create nullCount: {e}")))?; + + fields.push(Field::new( + "nullCount", + null_count_struct.data_type().clone(), + true, + )); + arrays.push(Arc::new(null_count_struct)); + } + + // minValues - nested struct with min values + let min_fields: Vec = self + .column_names + .iter() + .enumerate() + .filter(|(idx, name)| { + self.should_collect_stats(name) && !self.min_values[*idx].is_null(0) + }) + .map(|(idx, name)| Field::new(name, self.min_values[idx].data_type().clone(), true)) + .collect(); + + if !min_fields.is_empty() { + let min_arrays: Vec = self + .column_names + .iter() + .enumerate() + .filter(|(idx, name)| { + self.should_collect_stats(name) && !self.min_values[*idx].is_null(0) + }) + .map(|(idx, _)| self.min_values[idx].clone()) + .collect(); + + let min_struct = StructArray::try_new(min_fields.into(), min_arrays, None) + .map_err(|e| Error::generic(format!("Failed to create minValues: {e}")))?; + + fields.push(Field::new( + "minValues", + min_struct.data_type().clone(), + true, + )); + arrays.push(Arc::new(min_struct)); + } + + // maxValues - nested struct with max values + let max_fields: Vec = self + .column_names + .iter() + .enumerate() + .filter(|(idx, name)| { + self.should_collect_stats(name) && !self.max_values[*idx].is_null(0) + }) + .map(|(idx, name)| Field::new(name, self.max_values[idx].data_type().clone(), true)) + .collect(); + + if !max_fields.is_empty() { + let max_arrays: Vec = self + .column_names + .iter() + .enumerate() + .filter(|(idx, name)| { + self.should_collect_stats(name) && !self.max_values[*idx].is_null(0) + }) + 
.map(|(idx, _)| self.max_values[idx].clone()) + .collect(); + + let max_struct = StructArray::try_new(max_fields.into(), max_arrays, None) + .map_err(|e| Error::generic(format!("Failed to create maxValues: {e}")))?; + + fields.push(Field::new( + "maxValues", + max_struct.data_type().clone(), + true, + )); + arrays.push(Arc::new(max_struct)); + } + + // tightBounds + fields.push(Field::new("tightBounds", DataType::Boolean, true)); + arrays.push(Arc::new(BooleanArray::from(vec![true]))); + + StructArray::try_new(fields.into(), arrays, None) + .map_err(|e| Error::generic(format!("Failed to create stats struct: {e}"))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::array::{Array, Int64Array, StringArray}; + use crate::arrow::datatypes::Schema; + + #[test] + fn test_statistics_collector_single_batch() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let mut collector = StatisticsCollector::new(schema, &["id".to_string()]).unwrap(); + collector.update(&batch, None).unwrap(); + let stats = collector.finalize().unwrap(); + + assert_eq!(stats.len(), 1); + let num_records = stats + .column_by_name("numRecords") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(num_records.value(0), 3); + } + + #[test] + fn test_statistics_collector_null_counts() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])), + ], + ) + .unwrap(); + + let mut collector = + StatisticsCollector::new(schema, &["id".to_string(), "value".to_string()]).unwrap(); + collector.update(&batch, None).unwrap(); + let stats = collector.finalize().unwrap(); + + // Check nullCount struct + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // id has 0 nulls + let id_null_count = null_count + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id_null_count.value(0), 0); + + // value has 1 null + let value_null_count = null_count + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_null_count.value(0), 1); + } + + #[test] + fn test_statistics_collector_multiple_batches_null_counts() { + let schema = Arc::new(Schema::new(vec![Field::new("value", DataType::Utf8, true)])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec![Some("a"), None]))], + ) + .unwrap(); + + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec![None, None, Some("b")]))], + ) + .unwrap(); + + let mut collector = StatisticsCollector::new(schema, &["value".to_string()]).unwrap(); + collector.update(&batch1, None).unwrap(); + collector.update(&batch2, None).unwrap(); + let stats = collector.finalize().unwrap(); + + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let value_null_count = null_count + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + // 1 null in batch1 + 2 nulls in batch2 = 3 total + assert_eq!(value_null_count.value(0), 3); + } + + #[test] + fn 
test_statistics_collector_respects_stats_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])), + ], + ) + .unwrap(); + + // Only collect stats for "id", not "value" + let mut collector = StatisticsCollector::new(schema, &["id".to_string()]).unwrap(); + collector.update(&batch, None).unwrap(); + let stats = collector.finalize().unwrap(); + + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Only id should be present + assert!(null_count.column_by_name("id").is_some()); + assert!(null_count.column_by_name("value").is_none()); + } + + #[test] + fn test_statistics_collector_min_max() { + let schema = Arc::new(Schema::new(vec![ + Field::new("number", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![5, 1, 9, 3])), + Arc::new(StringArray::from(vec![ + Some("banana"), + Some("apple"), + Some("cherry"), + None, + ])), + ], + ) + .unwrap(); + + let mut collector = + StatisticsCollector::new(schema, &["number".to_string(), "name".to_string()]).unwrap(); + collector.update(&batch, None).unwrap(); + let stats = collector.finalize().unwrap(); + + // Check minValues + let min_values = stats + .column_by_name("minValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let number_min = min_values + .column_by_name("number") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(number_min.value(0), 1); + + let name_min = min_values + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_min.value(0), "apple"); + + // Check maxValues + let max_values = stats + .column_by_name("maxValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let number_max = max_values + .column_by_name("number") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(number_max.value(0), 9); + + let name_max = max_values + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_max.value(0), "cherry"); + } + + #[test] + fn test_statistics_collector_min_max_multiple_batches() { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + false, + )])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![5, 10, 3]))], + ) + .unwrap(); + + let batch2 = + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![1, 8]))]) + .unwrap(); + + let mut collector = StatisticsCollector::new(schema, &["value".to_string()]).unwrap(); + collector.update(&batch1, None).unwrap(); + collector.update(&batch2, None).unwrap(); + let stats = collector.finalize().unwrap(); + + let min_values = stats + .column_by_name("minValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value_min = min_values + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_min.value(0), 1); // min across both batches + + let max_values = stats + .column_by_name("maxValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value_max = max_values + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + 
.unwrap(); + assert_eq!(value_max.value(0), 10); // max across both batches + } + + #[test] + fn test_statistics_collector_with_mask() { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + + // Batch with values [1, 2, 3, 4, 5] + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ]))], + ) + .unwrap(); + + // Mask: only include rows 1, 3 (values 2, 4) + let mask = NullBuffer::from(vec![false, true, false, true, false]); + + let mut collector = StatisticsCollector::new(schema, &["value".to_string()]).unwrap(); + collector.update(&batch, Some(&mask)).unwrap(); + let stats = collector.finalize().unwrap(); + + // numRecords should be 2 (only masked-in rows) + let num_records = stats + .column_by_name("numRecords") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(num_records.value(0), 2); + + // min should be 2, max should be 4 + let min_values = stats + .column_by_name("minValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value_min = min_values + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_min.value(0), 2); + + let max_values = stats + .column_by_name("maxValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value_max = max_values + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_max.value(0), 4); + } + + #[test] + fn test_statistics_collector_with_mask_null_count() { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + + // Batch with values [1, null, 3, null, 5] + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![ + Some(1), + None, + Some(3), + None, + Some(5), + ]))], + ) + .unwrap(); + + // Mask: include rows 0, 1, 2 (values 1, null, 3) + let mask = NullBuffer::from(vec![true, true, true, false, false]); + + let mut collector = StatisticsCollector::new(schema, &["value".to_string()]).unwrap(); + collector.update(&batch, Some(&mask)).unwrap(); + let stats = collector.finalize().unwrap(); + + // nullCount should be 1 (only the null at index 1 is in masked rows) + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value_null_count = null_count + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_null_count.value(0), 1); + } +} + +/// Verifies that collected statistics match the expected schema. +/// Used for debugging and testing stats collection. +pub(crate) struct StatsVerifier; + +impl StatsVerifier { + /// Verify stats and return a structured result. 
+ #[allow(unused)] + pub(crate) fn verify( + stats: &StructArray, + expected_columns: &[String], + ) -> DeltaResult { + use crate::arrow::array::Array; + + let fields = stats.fields(); + let field_names: Vec<&str> = fields.iter().map(|f| f.name().as_str()).collect(); + + // Check numRecords + let num_records = if field_names.contains(&"numRecords") { + let num_records_idx = fields + .iter() + .position(|f| f.name() == "numRecords") + .unwrap(); + let num_records_array = stats.column(num_records_idx); + if let Some(int_array) = num_records_array.as_any().downcast_ref::() { + int_array.value(0) + } else { + 0 + } + } else { + 0 + }; + + // Check tightBounds + let tight_bounds = if field_names.contains(&"tightBounds") { + let idx = fields + .iter() + .position(|f| f.name() == "tightBounds") + .unwrap(); + let array = stats.column(idx); + if let Some(bool_array) = array.as_any().downcast_ref::() { + bool_array.value(0) + } else { + false + } + } else { + false + }; + + // Check nullCount columns + let mut present_null_count = Vec::new(); + let mut missing_null_count = Vec::new(); + + if let Some(idx) = fields.iter().position(|f| f.name() == "nullCount") { + let null_count_array = stats.column(idx); + if let Some(null_struct) = null_count_array.as_any().downcast_ref::() { + let null_fields: Vec<&str> = null_struct + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect(); + for col in expected_columns { + if null_fields.contains(&col.as_str()) { + present_null_count.push(col.clone()); + } else { + missing_null_count.push(col.clone()); + } + } + } + } + + // Get min/max columns + let mut min_max_columns = Vec::new(); + if let Some(idx) = fields.iter().position(|f| f.name() == "minValues") { + let min_array = stats.column(idx); + if let Some(min_struct) = min_array.as_any().downcast_ref::() { + for field in min_struct.fields() { + min_max_columns.push(field.name().clone()); + } + } + } + + Ok(StatsVerificationResult { + num_records, + tight_bounds, + present_null_count_columns: present_null_count, + missing_null_count_columns: missing_null_count, + min_max_columns, + }) + } + + /// Verify stats and return a detailed human-readable string. + #[allow(unused)] + pub(crate) fn verify_detailed( + stats: &StructArray, + expected_columns: &[String], + ) -> DeltaResult { + let result = Self::verify(stats, expected_columns)?; + Ok(format!( + "Stats: numRecords={}, tightBounds={}, nullCount=[{}], minMax=[{}]", + result.num_records, + result.tight_bounds, + result.present_null_count_columns.join(", "), + result.min_max_columns.join(", ") + )) + } +} + +/// Result of stats verification. +#[allow(unused)] +pub(crate) struct StatsVerificationResult { + pub num_records: i64, + pub tight_bounds: bool, + pub present_null_count_columns: Vec, + pub missing_null_count_columns: Vec, + pub min_max_columns: Vec, +} + +impl StatsVerificationResult { + /// Returns true if all expected columns have nullCount stats. 
+ pub fn has_all_null_counts(&self) -> bool { + self.missing_null_count_columns.is_empty() + } +} + +#[cfg(test)] +mod verifier_tests { + use super::*; + use crate::arrow::datatypes::Schema; + + #[test] + fn test_stats_verifier_valid_stats() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])), + ], + ) + .unwrap(); + + let mut collector = + StatisticsCollector::new(schema, &["id".to_string(), "value".to_string()]); + collector.update(&batch, None).unwrap(); + let stats = collector.finalize().unwrap(); + + let result = + StatsVerifier::verify(&stats, &["id".to_string(), "value".to_string()]).unwrap(); + + assert_eq!(result.num_records, 3); + assert!(result.tight_bounds); + assert!(result.has_all_null_counts()); + assert_eq!(result.present_null_count_columns.len(), 2); + } + + #[test] + fn test_stats_verifier_detailed_output() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let mut collector = StatisticsCollector::new(schema, &["id".to_string()]); + collector.update(&batch, None).unwrap(); + let stats = collector.finalize().unwrap(); + + let detailed = StatsVerifier::verify_detailed(&stats, &["id".to_string()]).unwrap(); + assert!(detailed.contains("numRecords=3")); + assert!(detailed.contains("tightBounds=true")); + } +} diff --git a/kernel/src/engine/sync/parquet.rs b/kernel/src/engine/sync/parquet.rs index 557ae9b226..12749a7065 100644 --- a/kernel/src/engine/sync/parquet.rs +++ b/kernel/src/engine/sync/parquet.rs @@ -79,6 +79,7 @@ impl ParquetHandler for SyncParquetHandler { /// - `location` - The full URL path where the Parquet file should be written /// (e.g., `file:///path/to/file.parquet`). /// - `data` - An iterator of engine data to be written to the Parquet file. + /// - `stats_columns` - Column names for which statistics should be collected. 
/// /// # Returns /// @@ -87,6 +88,7 @@ impl ParquetHandler for SyncParquetHandler { &self, location: Url, mut data: Box>> + Send>, + _stats_columns: &[String], ) -> DeltaResult<()> { // Convert URL to file path let path = location @@ -115,6 +117,7 @@ impl ParquetHandler for SyncParquetHandler { writer.close()?; // writer must be closed to write footer + // TODO: Implement stats collection for SyncEngine Ok(()) } @@ -174,7 +177,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data))); // Write the file - handler.write_parquet_file(url.clone(), data_iter).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter, &[]) + .unwrap(); // Verify the file exists assert!(file_path.exists()); @@ -295,7 +300,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data))); // Write the file - handler.write_parquet_file(url.clone(), data_iter).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter, &[]) + .unwrap(); // Verify the file exists assert!(file_path.exists()); @@ -370,7 +377,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data1))); // Write the first file - handler.write_parquet_file(url.clone(), data_iter1).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter1, &[]) + .unwrap(); assert!(file_path.exists()); // Create second data set with different data @@ -386,7 +395,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data2))); // Overwrite with second file (overwrite=true) - handler.write_parquet_file(url.clone(), data_iter2).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter2, &[]) + .unwrap(); // Read back and verify it contains the second data set let file = File::open(&file_path).unwrap(); @@ -445,7 +456,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data1))); // Write the first file - handler.write_parquet_file(url.clone(), data_iter1).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter1, &[]) + .unwrap(); assert!(file_path.exists()); // Create second data set @@ -461,7 +474,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data2))); // Write again - should overwrite successfully (new behavior always overwrites) - handler.write_parquet_file(url.clone(), data_iter2).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter2, &[]) + .unwrap(); // Verify the file was overwritten with the new data let file = File::open(&file_path).unwrap(); @@ -537,7 +552,9 @@ mod tests { > = Box::new(batches.into_iter()); // Write the file - handler.write_parquet_file(url.clone(), data_iter).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter, &[]) + .unwrap(); // Verify the file exists assert!(file_path.exists()); diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index d27a1c915a..9b3eefe78a 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -778,9 +778,10 @@ pub trait ParquetHandler: AsAny { predicate: Option, ) -> DeltaResult; - /// Write data to a Parquet file at the specified URL. + /// Write data to a Parquet file at the specified URL, collecting statistics. /// - /// This method writes the provided `data` to a Parquet file at the given `url`. + /// This method writes the provided `data` to a Parquet file at the given `url`, + /// and collects statistics (min, max, null count) for the specified columns. /// /// This will overwrite the file if it already exists. /// @@ -789,6 +790,7 @@ pub trait ParquetHandler: AsAny { /// - `url` - The full URL path where the Parquet file should be written /// (e.g., `s3://bucket/path/file.parquet`). 
/// - `data` - An iterator of engine data to be written to the Parquet file. + /// - `stats_columns` - Column names for which statistics should be collected. /// /// # Returns /// @@ -797,6 +799,7 @@ pub trait ParquetHandler: AsAny { &self, location: url::Url, data: Box>> + Send>, + _stats_columns: &[String], ) -> DeltaResult<()>; /// Read the footer metadata from a Parquet file without reading the data. diff --git a/kernel/src/scan/data_skipping/stats_schema.rs b/kernel/src/scan/data_skipping/stats_schema.rs index e18423b467..ff4fde8b69 100644 --- a/kernel/src/scan/data_skipping/stats_schema.rs +++ b/kernel/src/scan/data_skipping/stats_schema.rs @@ -133,6 +133,120 @@ pub(crate) fn expected_stats_schema( StructType::try_new(fields) } +/// Returns the list of column names that should have statistics collected. +/// +/// This extracts just the column names without building the full stats schema, +/// making it more efficient when only the column list is needed. +#[allow(unused)] +pub(crate) fn stats_column_names( + physical_file_schema: &Schema, + table_properties: &TableProperties, +) -> Vec { + let mut filter = StatsColumnFilter::new(table_properties); + let mut columns = Vec::new(); + filter.collect_columns(physical_file_schema, &mut columns); + columns +} + +/// Handles column filtering logic for statistics based on table properties. +/// +/// Filters columns according to: +/// * `dataSkippingStatsColumns` - explicit list of columns to include (takes precedence) +/// * `dataSkippingNumIndexedCols` - number of leaf columns to include (default 32) +struct StatsColumnFilter { + n_columns: Option, + added_columns: u64, + column_names: Option>, + path: Vec, +} + +impl StatsColumnFilter { + fn new(props: &TableProperties) -> Self { + // If data_skipping_stats_columns is specified, it takes precedence + // over data_skipping_num_indexed_cols, even if that is also specified. + if let Some(column_names) = &props.data_skipping_stats_columns { + Self { + n_columns: None, + added_columns: 0, + column_names: Some(column_names.clone()), + path: Vec::new(), + } + } else { + let n_cols = props + .data_skipping_num_indexed_cols + .unwrap_or(DataSkippingNumIndexedCols::NumColumns(32)); + Self { + n_columns: Some(n_cols), + added_columns: 0, + column_names: None, + path: Vec::new(), + } + } + } + + /// Collects column names that should have statistics. + fn collect_columns(&mut self, schema: &Schema, result: &mut Vec) { + for field in schema.fields() { + self.collect_field(field, result); + } + } + + fn collect_field(&mut self, field: &StructField, result: &mut Vec) { + if self.at_column_limit() { + return; + } + + self.path.push(field.name.clone()); + + match field.data_type() { + DataType::Struct(struct_type) => { + for child in struct_type.fields() { + self.collect_field(child, result); + } + } + _ => { + if self.should_include_current() { + result.push(ColumnName::new(&self.path)); + self.added_columns += 1; + } + } + } + + self.path.pop(); + } + + /// Returns true if the column limit has been reached. + fn at_column_limit(&self) -> bool { + matches!( + self.n_columns, + Some(DataSkippingNumIndexedCols::NumColumns(n)) if self.added_columns >= n + ) + } + + /// Returns true if the current path should be included based on column_names config. + fn should_include_current(&self) -> bool { + self.column_names + .as_ref() + .map(|ns| should_include_column(&ColumnName::new(&self.path), ns)) + .unwrap_or(true) + } + + /// Enters a field path for filtering decisions. 
+ fn enter_field(&mut self, name: &str) { + self.path.push(name.to_string()); + } + + /// Exits the current field path. + fn exit_field(&mut self) { + self.path.pop(); + } + + /// Records that a leaf column was included. + fn record_included(&mut self) { + self.added_columns += 1; + } +} + /// Transforms a schema to make all fields nullable. /// Used for stats schemas where stats may not be available for all columns. pub(crate) struct NullableStatsTransform; @@ -178,44 +292,19 @@ impl<'a> SchemaTransform<'a> for NullCountStatsTransform { /// Base stats schema in this case refers the subsets of fields in the table schema /// that may be considered for stats collection. Depending on the type of stats - min/max/nullcount/... - /// additional transformations may be applied. +/// Transforms a schema to filter columns for statistics based on table properties. /// -/// The concrete shape of the schema depends on the table configuration. -/// * `dataSkippingStatsColumns` - used to explicitly specify the columns -/// to be used for data skipping statistics. (takes precedence) -/// * `dataSkippingNumIndexedCols` - used to specify the number of columns -/// to be used for data skipping statistics. Defaults to 32. -/// -/// All fields are nullable. +/// All fields in the output are nullable. #[allow(unused)] struct BaseStatsTransform { - n_columns: Option, - added_columns: u64, - column_names: Option>, - path: Vec, + filter: StatsColumnFilter, } impl BaseStatsTransform { #[allow(unused)] fn new(props: &TableProperties) -> Self { - // If data_skipping_stats_columns is specified, it takes precedence - // over data_skipping_num_indexed_cols, even if that is also specified. - if let Some(column_names) = &props.data_skipping_stats_columns { - Self { - n_columns: None, - added_columns: 0, - column_names: Some(column_names.clone()), - path: Vec::new(), - } - } else { - let n_cols = props - .data_skipping_num_indexed_cols - .unwrap_or(DataSkippingNumIndexedCols::NumColumns(32)); - Self { - n_columns: Some(n_cols), - added_columns: 0, - column_names: None, - path: Vec::new(), - } + Self { + filter: StatsColumnFilter::new(props), } } } @@ -224,34 +313,22 @@ impl<'a> SchemaTransform<'a> for BaseStatsTransform { fn transform_struct_field(&mut self, field: &'a StructField) -> Option> { use Cow::*; - // Check if the number of columns is set and if the added columns exceed the limit - // In the constructor we assert this will always be None if column_names are specified - if let Some(DataSkippingNumIndexedCols::NumColumns(n_cols)) = self.n_columns { - if self.added_columns >= n_cols { - return None; - } + if self.filter.at_column_limit() { + return None; } - self.path.push(field.name.clone()); + self.filter.enter_field(field.name()); let data_type = field.data_type(); // We always traverse struct fields (they don't count against the column limit), // but we only include leaf fields if they qualify based on column_names config. // When column_names is None, all leaf fields are included (up to n_columns limit). if !matches!(data_type, DataType::Struct(_)) { - let should_include = self - .column_names - .as_ref() - .map(|ns| should_include_column(&ColumnName::new(&self.path), ns)) - .unwrap_or(true); - - if !should_include { - self.path.pop(); + if !self.filter.should_include_current() { + self.filter.exit_field(); return None; } - - // Increment count only for leaf columns - self.added_columns += 1; + self.filter.record_included(); } let field = match self.transform(&field.data_type)? 
{ @@ -264,7 +341,7 @@ impl<'a> SchemaTransform<'a> for BaseStatsTransform { }), }; - self.path.pop(); + self.filter.exit_field(); // exclude struct fields with no children if matches!(field.data_type(), DataType::Struct(dt) if dt.fields().len() == 0) { diff --git a/kernel/src/snapshot.rs b/kernel/src/snapshot.rs index 195b64f31d..eb94292c2e 100644 --- a/kernel/src/snapshot.rs +++ b/kernel/src/snapshot.rs @@ -441,9 +441,11 @@ impl Snapshot { let data_iter = writer.checkpoint_data(engine)?; let state = data_iter.state(); let lazy_data = data_iter.map(|r| r.and_then(|f| f.apply_selection_vector())); - engine - .parquet_handler() - .write_parquet_file(checkpoint_path.clone(), Box::new(lazy_data))?; + engine.parquet_handler().write_parquet_file( + checkpoint_path.clone(), + Box::new(lazy_data), + &[], + )?; let file_meta = engine.storage_handler().head(&checkpoint_path)?; diff --git a/kernel/src/table_configuration.rs b/kernel/src/table_configuration.rs index 9080447084..cbbc8f6738 100644 --- a/kernel/src/table_configuration.rs +++ b/kernel/src/table_configuration.rs @@ -13,7 +13,8 @@ use std::sync::Arc; use url::Url; use crate::actions::{Metadata, Protocol}; -use crate::scan::data_skipping::stats_schema::expected_stats_schema; +use crate::expressions::ColumnName; +use crate::scan::data_skipping::stats_schema::{expected_stats_schema, stats_column_names}; use crate::schema::variant_utils::validate_variant_type_feature_support; use crate::schema::{InvariantChecker, SchemaRef, StructType}; use crate::table_features::{ @@ -181,21 +182,38 @@ impl TableConfiguration { #[allow(unused)] #[internal_api] pub(crate) fn expected_stats_schema(&self) -> DeltaResult { + let physical_schema = self.physical_data_schema(); + Ok(Arc::new(expected_stats_schema( + &physical_schema, + self.table_properties(), + )?)) + } + + /// Returns the list of column names that should have statistics collected. + /// + /// This returns the leaf column paths as a flat list of column names + /// (e.g., `["id", "nested.field"]`). + #[allow(unused)] + #[internal_api] + pub(crate) fn stats_column_names(&self) -> Vec { + let physical_schema = self.physical_data_schema(); + stats_column_names(&physical_schema, self.table_properties()) + } + + /// Returns the physical schema for data columns (excludes partition columns). + /// + /// Partition columns are excluded because statistics are only collected for data columns + /// that are physically stored in the parquet files. Partition values are stored in the + /// file path, not in the file content, so they don't have file-level statistics. + fn physical_data_schema(&self) -> StructType { let partition_columns = self.metadata().partition_columns(); let column_mapping_mode = self.column_mapping_mode(); - // Partition columns are excluded because statistics are only collected for data columns - // that are physically stored in the parquet files. Partition values are stored in the - // file path, not in the file content, so they don't have file-level statistics. - let physical_schema = StructType::try_new( + StructType::new_unchecked( self.schema() .fields() .filter(|field| !partition_columns.contains(field.name())) .map(|field| field.make_physical(column_mapping_mode)), - )?; - Ok(Arc::new(expected_stats_schema( - &physical_schema, - self.table_properties(), - )?)) + ) } /// The [`Metadata`] for this table at this version. 
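Before the transaction-side changes, a small self-contained model may help make the column-selection rules concrete. Toy types and exact-path matching stand in for the kernel's `Schema`, `ColumnName`, and `should_include_column`; the walk itself follows `StatsColumnFilter::collect_columns`: an explicit `delta.dataSkippingStatsColumns` list takes precedence, otherwise the first `delta.dataSkippingNumIndexedCols` leaf columns (default 32) are taken in schema order, and struct fields are traversed without counting against the limit.

```rust
// Toy model of StatsColumnFilter::collect_columns: walk a nested schema and
// return the dotted paths of the leaf columns eligible for statistics.
enum ToyType {
    Leaf,                           // any primitive column
    Struct(Vec<(String, ToyType)>), // nested struct: (field name, field type)
}

fn collect(
    fields: &[(String, ToyType)],
    explicit: Option<&[String]>, // delta.dataSkippingStatsColumns, if set
    limit: usize,                // delta.dataSkippingNumIndexedCols (default 32)
    path: &mut Vec<String>,
    out: &mut Vec<String>,
) {
    for (name, ty) in fields {
        // The numeric limit only applies when no explicit list is configured,
        // and only leaf columns count against it.
        if explicit.is_none() && out.len() >= limit {
            return;
        }
        path.push(name.clone());
        match ty {
            ToyType::Struct(children) => collect(children, explicit, limit, path, out),
            ToyType::Leaf => {
                let dotted = path.join(".");
                // Exact-path matching here; the real should_include_column
                // decides inclusion against the configured ColumnName list.
                if explicit.map_or(true, |cols| cols.iter().any(|c| c == &dotted)) {
                    out.push(dotted);
                }
            }
        }
        path.pop();
    }
}

fn main() {
    let schema = vec![
        ("id".to_string(), ToyType::Leaf),
        (
            "nested".to_string(),
            ToyType::Struct(vec![("field".to_string(), ToyType::Leaf)]),
        ),
        ("value".to_string(), ToyType::Leaf),
    ];
    let mut out = Vec::new();
    collect(&schema, None, 2, &mut Vec::new(), &mut out);
    // With a limit of 2 and no explicit list, only the first two leaves
    // (in schema order) are selected: ["id", "nested.field"].
    assert_eq!(out, vec!["id".to_string(), "nested.field".to_string()]);
}
```

`Transaction::stats_columns` (introduced below) exposes exactly this list, with partition columns already stripped by `physical_data_schema`, and `get_write_context` stringifies it into the `WriteContext` that feeds `write_parquet_file`.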
diff --git a/kernel/src/transaction/mod.rs b/kernel/src/transaction/mod.rs index a3041b1893..e7e27c3c2a 100644 --- a/kernel/src/transaction/mod.rs +++ b/kernel/src/transaction/mod.rs @@ -41,6 +41,65 @@ use crate::{ }; use delta_kernel_derive::internal_api; +/// Visitor to validate statistics in add file metadata. +/// Uses RowVisitor pattern to extract and validate stats from EngineData. +struct StatsValidationVisitor { + rows_validated: usize, + rows_with_num_records: usize, + errors: Vec<String>, +} + +impl StatsValidationVisitor { + fn new() -> Self { + Self { + rows_validated: 0, + rows_with_num_records: 0, + errors: Vec::new(), + } + } + + fn validate(&self) -> DeltaResult<()> { + if self.rows_validated == 0 { + return Err(Error::generic("No rows to validate")); + } + if self.rows_with_num_records == 0 { + // This is a warning case, not an error - stats might be missing + // but we don't fail the commit for it + } + if !self.errors.is_empty() { + return Err(Error::generic(format!( + "Stats validation errors: {}", + self.errors.join("; ") + ))); + } + Ok(()) + } +} + +impl RowVisitor for StatsValidationVisitor { + fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]) { + static NAMES: LazyLock<Vec<ColumnName>> = LazyLock::new(|| vec![column_name!("stats")]); + static TYPES: LazyLock<Vec<DataType>> = LazyLock::new(|| { + vec![DataType::STRING] // stats is a JSON string + }); + (NAMES.as_slice(), TYPES.as_slice()) + } + + fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()> { + for row_index in 0..row_count { + self.rows_validated += 1; + if let Some(stats_str) = getters[0].get_opt(row_index, "stats")? { + let stats_str: String = stats_str; + // Check if stats has numRecords + if stats_str.contains("\"numRecords\"") { + self.rows_with_num_records += 1; + } + } + } + Ok(()) + } +} + /// Type alias for an iterator of [`EngineData`] results. pub(crate) type EngineDataResultIterator<'a> = Box<dyn Iterator<Item = DeltaResult<Box<dyn EngineData>>> + Send + 'a>; @@ -333,6 +392,14 @@ impl Transaction { /// transaction in case of a conflict so the user can retry, etc.) /// - Err(Error) indicates a non-retryable error (e.g. logic/validation error). pub fn commit(self, engine: &dyn Engine) -> DeltaResult<CommitResult> { + // Step 0: Validate stats in add file metadata + // This ensures all files have valid stats before committing + for add_metadata in &self.add_files_metadata { + let mut validator = StatsValidationVisitor::new(); + validator.visit_rows_of(add_metadata.as_ref())?; + validator.validate()?; + } + // Step 1: Check for duplicate app_ids and generate set transactions (`txn`) // Note: The commit info must always be the first action in the commit but we generate it in // step 2 to fail early on duplicate transaction appIds @@ -790,6 +857,45 @@ impl Transaction { &BASE_ADD_FILES_SCHEMA } + /// Returns the expected schema for file statistics. + /// + /// The schema structure is derived from table configuration: + /// - `delta.dataSkippingStatsColumns`: Explicit column list (if set) + /// - `delta.dataSkippingNumIndexedCols`: Column count limit (default 32) + /// - Partition columns: Always excluded + /// + /// The returned schema has the following structure: + /// ```ignore + /// { + /// numRecords: long, + /// nullCount: { ... }, // Nested struct mirroring data schema, all fields LONG + /// minValues: { ... }, // Nested struct, only min/max eligible types + /// maxValues: { ...
}, // Nested struct, only min/max eligible types + /// tightBounds: boolean, + /// } + /// ``` + /// + /// Engines should collect statistics matching this schema structure when writing files. + #[allow(unused)] + pub fn stats_schema(&self) -> DeltaResult<SchemaRef> { + self.read_snapshot + .table_configuration() + .expected_stats_schema() + } + + /// Returns the list of column names that should have statistics collected. + /// + /// This returns the leaf column paths as a flat list of column names + /// (e.g., `["id", "nested.field"]`). + /// + /// Engines can use this to determine which columns need stats during writes. + #[allow(unused)] + pub fn stats_columns(&self) -> Vec<ColumnName> { + self.read_snapshot + .table_configuration() + .stats_column_names() + } + // Generate the logical-to-physical transform expression which must be evaluated on every data // chunk before writing. At the moment, this is a transaction-wide expression. fn generate_logical_to_physical(&self) -> Expression { @@ -820,6 +926,8 @@ impl Transaction { // Note: after we introduce metadata updates (modify table schema, etc.), we need to make sure // that engines cannot call this method after a metadata change, since the write context could // have invalid metadata. + // Note: Callers that obtain a write context may still be writing data against it when the + // metadata changes, which would leave them holding invalid metadata. pub fn get_write_context(&self) -> WriteContext { let target_dir = self.read_snapshot.table_root(); let snapshot_schema = self.read_snapshot.schema(); @@ -837,11 +945,19 @@ impl Transaction { .cloned(); let physical_schema = Arc::new(StructType::new_unchecked(physical_fields)); + // Get stats columns from table configuration + let stats_columns = self + .stats_columns() + .into_iter() + .map(|c| c.to_string()) + .collect(); + WriteContext::new( target_dir.clone(), snapshot_schema, physical_schema, Arc::new(logical_to_physical), + stats_columns, ) } @@ -854,6 +970,37 @@ impl Transaction { self.add_files_metadata.push(add_metadata); } + /// Add files with statistics validation. + /// + /// Similar to [`add_files`], but validates that the metadata contains valid statistics + /// before adding. Returns an error if validation fails. + /// + /// [`add_files`]: Transaction::add_files + #[allow(unused)] + pub fn add_files_validated(&mut self, add_metadata: Box<dyn EngineData>) -> DeltaResult<()> { + let mut validator = StatsValidationVisitor::new(); + validator.visit_rows_of(add_metadata.as_ref())?; + validator.validate()?; + self.add_files_metadata.push(add_metadata); + Ok(()) + } + + /// Validate statistics in add file metadata without modifying the transaction. + /// + /// Returns a summary string describing the validation results. + #[allow(unused)] + pub fn validate_add_files_stats(add_metadata: &dyn EngineData) -> DeltaResult<String> { + let mut validator = StatsValidationVisitor::new(); + validator.visit_rows_of(add_metadata)?; + validator.validate()?; + Ok(format!( + "Validated {} rows, {} had numRecords. Errors: {}", + validator.rows_validated, + validator.rows_with_num_records, + validator.errors.join("; ") + )) + } + /// Generate add actions, handling row tracking internally if needed fn generate_adds<'a>( &'a self, @@ -1313,6 +1460,8 @@ pub struct WriteContext { logical_schema: SchemaRef, physical_schema: SchemaRef, logical_to_physical: ExpressionRef, + /// Column names that should have statistics collected during writes.
+ stats_columns: Vec<String>, } impl WriteContext { @@ -1321,12 +1470,14 @@ impl WriteContext { logical_schema: SchemaRef, physical_schema: SchemaRef, logical_to_physical: ExpressionRef, + stats_columns: Vec<String>, ) -> Self { WriteContext { target_dir, logical_schema, physical_schema, logical_to_physical, + stats_columns, } } @@ -1346,6 +1497,13 @@ impl WriteContext { self.logical_to_physical.clone() } + /// Returns the column names that should have statistics collected during writes. + /// + /// Based on table configuration (dataSkippingNumIndexedCols, dataSkippingStatsColumns). + pub fn stats_columns(&self) -> &[String] { + &self.stats_columns + } + /// Generate a new unique absolute URL for a deletion vector file. /// /// This method generates a unique file name in the table directory. diff --git a/kernel/tests/write.rs b/kernel/tests/write.rs index f47db4b92e..8e53056367 100644 --- a/kernel/tests/write.rs +++ b/kernel/tests/write.rs @@ -393,7 +393,7 @@ async fn test_append() -> Result<(), Box<dyn std::error::Error>> { "size": size, "modificationTime": 0, "dataChange": true, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":1},\"maxValues\":{\"number\":3},\"tightBounds\":true}" } }), json!({ @@ -403,7 +403,7 @@ async fn test_append() -> Result<(), Box<dyn std::error::Error>> { "size": size, "modificationTime": 0, "dataChange": true, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":4},\"maxValues\":{\"number\":6},\"tightBounds\":true}" } }), ]; @@ -601,7 +601,7 @@ async fn test_append_partitioned() -> Result<(), Box<dyn std::error::Error>> { "size": size, "modificationTime": 0, "dataChange": false, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":1},\"maxValues\":{\"number\":3},\"tightBounds\":true}" } }), json!({ @@ -613,7 +613,7 @@ async fn test_append_partitioned() -> Result<(), Box<dyn std::error::Error>> { "size": size, "modificationTime": 0, "dataChange": false, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":4},\"maxValues\":{\"number\":6},\"tightBounds\":true}" } }), ]; @@ -1078,6 +1078,7 @@ async fn test_append_variant() -> Result<(), Box<dyn std::error::Error>> { write_context.target_dir(), Box::new(ArrowEngineData::new(data.clone())), HashMap::new(), + write_context.stats_columns(), ) .await?; @@ -1251,6 +1252,7 @@ async fn test_shredded_variant_read_rejection() -> Result<(), Box