diff --git a/kernel/src/engine/default/mod.rs b/kernel/src/engine/default/mod.rs index 430f10a30e..cb01d3b022 100644 --- a/kernel/src/engine/default/mod.rs +++ b/kernel/src/engine/default/mod.rs @@ -33,6 +33,7 @@ pub mod file_stream; pub mod filesystem; pub mod json; pub mod parquet; +pub mod stats; pub mod storage; /// Converts a Stream-producing future to a synchronous iterator. @@ -216,7 +217,12 @@ impl DefaultEngine { )?; let physical_data = logical_to_physical_expr.evaluate(data)?; self.parquet - .write_parquet_file(write_context.target_dir(), physical_data, partition_values) + .write_parquet_file( + write_context.target_dir(), + physical_data, + partition_values, + Some(write_context.stats_columns()), + ) .await } } diff --git a/kernel/src/engine/default/parquet.rs b/kernel/src/engine/default/parquet.rs index 7ed225ddb3..4ca0e2d431 100644 --- a/kernel/src/engine/default/parquet.rs +++ b/kernel/src/engine/default/parquet.rs @@ -7,8 +7,8 @@ use std::sync::Arc; use delta_kernel_derive::internal_api; use crate::arrow::array::builder::{MapBuilder, MapFieldNames, StringBuilder}; -use crate::arrow::array::{Int64Array, RecordBatch, StringArray, StructArray}; -use crate::arrow::datatypes::{DataType, Field}; +use crate::arrow::array::{Array, Int64Array, RecordBatch, StringArray, StructArray}; +use crate::arrow::datatypes::{DataType, Field, Schema}; use crate::parquet::arrow::arrow_reader::{ ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, }; @@ -23,6 +23,7 @@ use object_store::{DynObjectStore, ObjectStore}; use uuid::Uuid; use super::file_stream::{FileOpenFuture, FileOpener, FileStream}; +use super::stats::collect_stats; use super::UrlExt; use crate::engine::arrow_conversion::{TryFromArrow as _, TryIntoArrow as _}; use crate::engine::arrow_data::ArrowEngineData; @@ -46,22 +47,16 @@ pub struct DefaultParquetHandler { } /// Metadata of a data file (typically a parquet file). -/// -/// Currently just includes the the number of records as statistics, but will expand to include -/// more statistics and other metadata in the future. #[derive(Debug)] pub struct DataFileMetadata { file_meta: FileMeta, - // NB: We use usize instead of u64 since arrow uses usize for record batch sizes - num_records: usize, + /// Collected statistics for this file (includes numRecords, tightBounds, etc.). + stats: StructArray, } impl DataFileMetadata { - pub fn new(file_meta: FileMeta, num_records: usize) -> Self { - Self { - file_meta, - num_records, - } + pub fn new(file_meta: FileMeta, stats: StructArray) -> Self { + Self { file_meta, stats } } /// Convert DataFileMetadata into a record batch which matches the schema returned by @@ -80,7 +75,8 @@ impl DataFileMetadata { last_modified, size, }, - num_records, + stats, + .. 
} = self; // create the record batch of the write metadata let path = Arc::new(StringArray::from(vec![location.to_string()])); @@ -104,20 +100,35 @@ impl DataFileMetadata { .map_err(|_| Error::generic("Failed to convert parquet metadata 'size' to i64"))?; let size = Arc::new(Int64Array::from(vec![size])); let modification_time = Arc::new(Int64Array::from(vec![*last_modified])); - let stats = Arc::new(StructArray::try_new_with_length( - vec![Field::new("numRecords", DataType::Int64, true)].into(), - vec![Arc::new(Int64Array::from(vec![*num_records as i64]))], - None, - 1, - )?); - Ok(Box::new(ArrowEngineData::new(RecordBatch::try_new( - Arc::new( - crate::transaction::BASE_ADD_FILES_SCHEMA - .as_ref() - .try_into_arrow()?, + let stats_array = Arc::new(stats.clone()); + + // Build schema dynamically based on stats (stats schema varies based on collected statistics) + let key_value_struct = DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, true), + ] + .into(), + ); + let schema = Schema::new(vec![ + Field::new("path", DataType::Utf8, false), + Field::new( + "partitionValues", + DataType::Map( + Arc::new(Field::new("key_value", key_value_struct, false)), + false, + ), + false, ), - vec![path, partitions, size, modification_time, stats], + Field::new("size", DataType::Int64, false), + Field::new("modificationTime", DataType::Int64, false), + Field::new("stats", stats_array.data_type().clone(), true), + ]); + + Ok(Box::new(ArrowEngineData::new(RecordBatch::try_new( + Arc::new(schema), + vec![path, partitions, size, modification_time, stats_array], )?))) } } @@ -148,10 +159,13 @@ impl DefaultParquetHandler { &self, path: &url::Url, data: Box, + stats_columns: &[String], ) -> DeltaResult { let batch: Box<_> = ArrowEngineData::try_from_engine_data(data)?; let record_batch = batch.record_batch(); - let num_records = record_batch.num_rows(); + + // Collect statistics before writing (includes numRecords) + let stats = collect_stats(record_batch, stats_columns)?; let mut buffer = vec![]; let mut writer = ArrowWriter::try_new(&mut buffer, record_batch.schema(), None)?; @@ -185,7 +199,7 @@ impl DefaultParquetHandler { } let file_meta = FileMeta::new(path, modification_time, size); - Ok(DataFileMetadata::new(file_meta, num_records)) + Ok(DataFileMetadata::new(file_meta, stats)) } /// Write `data` to `{path}/.parquet` as parquet using ArrowWriter and return the parquet @@ -201,8 +215,11 @@ impl DefaultParquetHandler { path: &url::Url, data: Box, partition_values: HashMap, + stats_columns: Option<&[String]>, ) -> DeltaResult> { - let parquet_metadata = self.write_parquet(path, data).await?; + let parquet_metadata = self + .write_parquet(path, data, stats_columns.unwrap_or(&[])) + .await?; parquet_metadata.as_record_batch(&partition_values) } } @@ -294,6 +311,7 @@ impl ParquetHandler for DefaultParquetHandler { /// - `location` - The full URL path where the Parquet file should be written /// (e.g., `s3://bucket/path/file.parquet`, `file:///path/to/file.parquet`). /// - `data` - An iterator of engine data to be written to the Parquet file. + /// - `stats_columns` - Optional column names for which statistics should be collected. 
/// /// # Returns /// @@ -302,6 +320,7 @@ impl ParquetHandler for DefaultParquetHandler { &self, location: url::Url, mut data: Box>> + Send>, + _stats_columns: Option<&[String]>, ) -> DeltaResult<()> { let store = self.store.clone(); @@ -606,19 +625,26 @@ mod tests { let last_modified = 10000000000; let num_records = 10; let file_metadata = FileMeta::new(location.clone(), last_modified, size); - let data_file_metadata = DataFileMetadata::new(file_metadata, num_records); + let stats = StructArray::try_new( + vec![ + Field::new("numRecords", ArrowDataType::Int64, true), + Field::new("tightBounds", ArrowDataType::Boolean, true), + ] + .into(), + vec![ + Arc::new(Int64Array::from(vec![num_records as i64])), + Arc::new(BooleanArray::from(vec![true])), + ], + None, + ) + .unwrap(); + let data_file_metadata = DataFileMetadata::new(file_metadata, stats.clone()); let partition_values = HashMap::from([("partition1".to_string(), "a".to_string())]); let actual = data_file_metadata .as_record_batch(&partition_values) .unwrap(); let actual = ArrowEngineData::try_from_engine_data(actual).unwrap(); - let schema = Arc::new( - crate::transaction::BASE_ADD_FILES_SCHEMA - .as_ref() - .try_into_arrow() - .unwrap(), - ); let mut partition_values_builder = MapBuilder::new( Some(MapFieldNames { entry: "key_value".to_string(), @@ -632,13 +658,33 @@ mod tests { partition_values_builder.values().append_value("a"); partition_values_builder.append(true).unwrap(); let partition_values = partition_values_builder.finish(); - let stats_struct = StructArray::try_new_with_length( - vec![Field::new("numRecords", ArrowDataType::Int64, true)].into(), - vec![Arc::new(Int64Array::from(vec![num_records as i64]))], - None, - 1, - ) - .unwrap(); + + // Build expected schema dynamically based on stats + let stats_field = Field::new("stats", stats.data_type().clone(), true); + let schema = Arc::new(crate::arrow::datatypes::Schema::new(vec![ + Field::new("path", ArrowDataType::Utf8, false), + Field::new( + "partitionValues", + ArrowDataType::Map( + Arc::new(Field::new( + "key_value", + ArrowDataType::Struct( + vec![ + Field::new("key", ArrowDataType::Utf8, false), + Field::new("value", ArrowDataType::Utf8, true), + ] + .into(), + ), + false, + )), + false, + ), + false, + ), + Field::new("size", ArrowDataType::Int64, false), + Field::new("modificationTime", ArrowDataType::Int64, false), + stats_field, + ])); let expected = RecordBatch::try_new( schema, @@ -647,7 +693,7 @@ mod tests { Arc::new(partition_values), Arc::new(Int64Array::from(vec![size as i64])), Arc::new(Int64Array::from(vec![last_modified])), - Arc::new(stats_struct), + Arc::new(stats), ], ) .unwrap(); @@ -670,7 +716,7 @@ mod tests { )); let write_metadata = parquet_handler - .write_parquet(&Url::parse("memory:///data/").unwrap(), data) + .write_parquet(&Url::parse("memory:///data/").unwrap(), data, &[]) .await .unwrap(); @@ -681,7 +727,7 @@ mod tests { last_modified, size, }, - num_records, + ref stats, } = write_metadata; let expected_location = Url::parse("memory:///data/").unwrap(); @@ -699,6 +745,15 @@ mod tests { assert_eq!(&expected_location.join(filename).unwrap(), location); assert_eq!(expected_size, size); assert!(now - last_modified < 10_000); + + // Check numRecords from stats + let num_records = stats + .column_by_name("numRecords") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); assert_eq!(num_records, 3); // check we can read back @@ -741,7 +796,7 @@ mod tests { assert_result_error_with_message( parquet_handler - 
.write_parquet(&Url::parse("memory:///data").unwrap(), data) + .write_parquet(&Url::parse("memory:///data").unwrap(), data, &[]) .await, "Generic delta kernel error: Path must end with a trailing slash: memory:///data", ); @@ -776,7 +831,7 @@ mod tests { // Test writing through the trait method let file_url = Url::parse("memory:///test/data.parquet").unwrap(); parquet_handler - .write_parquet_file(file_url.clone(), data_iter) + .write_parquet_file(file_url.clone(), data_iter, None) .unwrap(); // Verify we can read the file back @@ -964,7 +1019,7 @@ mod tests { // Write the data let file_url = Url::parse("memory:///roundtrip/test.parquet").unwrap(); parquet_handler - .write_parquet_file(file_url.clone(), data_iter) + .write_parquet_file(file_url.clone(), data_iter, None) .unwrap(); // Read it back @@ -1152,7 +1207,7 @@ mod tests { // Write the first file parquet_handler - .write_parquet_file(file_url.clone(), data_iter1) + .write_parquet_file(file_url.clone(), data_iter1, None) .unwrap(); // Create second data set with different data @@ -1168,7 +1223,7 @@ mod tests { // Overwrite with second file (overwrite=true) parquet_handler - .write_parquet_file(file_url.clone(), data_iter2) + .write_parquet_file(file_url.clone(), data_iter2, None) .unwrap(); // Read back and verify it contains the second data set @@ -1231,7 +1286,7 @@ mod tests { // Write the first file parquet_handler - .write_parquet_file(file_url.clone(), data_iter1) + .write_parquet_file(file_url.clone(), data_iter1, None) .unwrap(); // Create second data set @@ -1247,7 +1302,7 @@ mod tests { // Write again - should overwrite successfully (new behavior always overwrites) parquet_handler - .write_parquet_file(file_url.clone(), data_iter2) + .write_parquet_file(file_url.clone(), data_iter2, None) .unwrap(); // Verify the file was overwritten with the new data diff --git a/kernel/src/engine/default/stats.rs b/kernel/src/engine/default/stats.rs new file mode 100644 index 0000000000..f89fd4f6c5 --- /dev/null +++ b/kernel/src/engine/default/stats.rs @@ -0,0 +1,724 @@ +//! Statistics collection for Delta Lake file writes. +//! +//! Provides `collect_stats` to compute min, max, and null count statistics +//! for a single RecordBatch during file writes. + +use std::collections::HashSet; +use std::sync::Arc; + +use crate::arrow::array::{ + Array, ArrayRef, BooleanArray, Decimal128Array, Int64Array, LargeStringArray, PrimitiveArray, + RecordBatch, StringArray, StringViewArray, StructArray, +}; +use crate::arrow::compute::kernels::aggregate::{max, max_string, min, min_string}; +use crate::arrow::datatypes::{ + ArrowPrimitiveType, DataType, Date32Type, Date64Type, Field, Fields, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, +}; +use crate::{DeltaResult, Error}; + +/// Maximum prefix length for string statistics (Delta protocol requirement). +const STRING_PREFIX_LENGTH: usize = 32; + +// ============================================================================ +// Min/Max computation using Arrow compute kernels +// ============================================================================ + +/// Aggregation type selector. +#[derive(Clone, Copy)] +enum Agg { + Min, + Max, +} + +/// Truncate string to maximum prefix length for Delta statistics. 
+fn truncate_string(s: &str) -> String { + s.chars().take(STRING_PREFIX_LENGTH).collect() +} + +/// Downcast helper with descriptive error message. +fn downcast(column: &ArrayRef) -> DeltaResult<&T> { + column.as_any().downcast_ref::().ok_or_else(|| { + Error::generic(format!( + "Failed to downcast column to {}", + std::any::type_name::(), + )) + }) +} + +/// Compute aggregation for a primitive array. +fn agg_primitive(column: &ArrayRef, agg: Agg) -> DeltaResult> +where + T: ArrowPrimitiveType, + T::Native: PartialOrd, + PrimitiveArray: From>>, +{ + let array = downcast::>(column)?; + let result = match agg { + Agg::Min => min(array), + Agg::Max => max(array), + }; + Ok(result.map(|v| Arc::new(PrimitiveArray::::from(vec![Some(v)])) as ArrayRef)) +} + +/// Compute aggregation for a timestamp array, preserving timezone. +fn agg_timestamp( + column: &ArrayRef, + tz: Option>, + agg: Agg, +) -> DeltaResult> +where + T: crate::arrow::datatypes::ArrowTimestampType, + PrimitiveArray: From>>, +{ + let array = downcast::>(column)?; + let result = match agg { + Agg::Min => min(array), + Agg::Max => max(array), + }; + Ok(result.map(|v| { + Arc::new(PrimitiveArray::::from(vec![Some(v)]).with_timezone_opt(tz)) as ArrayRef + })) +} + +/// Compute aggregation for a decimal128 array, preserving precision and scale. +fn agg_decimal( + column: &ArrayRef, + precision: u8, + scale: i8, + agg: Agg, +) -> DeltaResult> { + let array = downcast::(column)?; + let result = match agg { + Agg::Min => min(array), + Agg::Max => max(array), + }; + result + .map(|v| { + Decimal128Array::from(vec![Some(v)]) + .with_precision_and_scale(precision, scale) + .map(|arr| Arc::new(arr) as ArrayRef) + }) + .transpose() + .map_err(|e| Error::generic(format!("Invalid decimal precision/scale: {e}"))) +} + +/// Compute aggregation for a string array with truncation. +fn agg_string(column: &ArrayRef, agg: Agg) -> DeltaResult> { + let array = downcast::(column)?; + let result = match agg { + Agg::Min => min_string(array), + Agg::Max => max_string(array), + }; + Ok(result.map(|s| Arc::new(StringArray::from(vec![Some(truncate_string(s))])) as ArrayRef)) +} + +/// Compute aggregation for a large string array with truncation. +fn agg_large_string(column: &ArrayRef, agg: Agg) -> DeltaResult> { + let array = downcast::(column)?; + let result = match agg { + Agg::Min => array.iter().flatten().min(), + Agg::Max => array.iter().flatten().max(), + }; + Ok( + result + .map(|s| Arc::new(LargeStringArray::from(vec![Some(truncate_string(s))])) as ArrayRef), + ) +} + +/// Compute aggregation for a string view array with truncation. +fn agg_string_view(column: &ArrayRef, agg: Agg) -> DeltaResult> { + let array = downcast::(column)?; + let result: Option<&str> = match agg { + Agg::Min => array.iter().flatten().min(), + Agg::Max => array.iter().flatten().max(), + }; + Ok(result.map(|s| Arc::new(StringViewArray::from(vec![Some(truncate_string(s))])) as ArrayRef)) +} + +/// Compute aggregation for a struct array by recursively processing child fields. +fn agg_struct(column: &ArrayRef, fields: &Fields, agg: Agg) -> DeltaResult> { + let struct_array = downcast::(column)?; + + let mut result_fields = Vec::new(); + let mut result_arrays = Vec::new(); + + for (i, field) in fields.iter().enumerate() { + let child = struct_array.column(i); + if let Some(child_agg) = compute_agg(child, agg)? 
{ + result_fields.push(Field::new( + field.name(), + child_agg.data_type().clone(), + true, + )); + result_arrays.push(child_agg); + } + } + + if result_fields.is_empty() { + Ok(None) + } else { + Ok(Some(Arc::new( + StructArray::try_new(result_fields.into(), result_arrays, None) + .map_err(|e| Error::generic(format!("Failed to create struct: {e}")))?, + ) as ArrayRef)) + } +} + +/// Compute min or max for a column based on its data type. +fn compute_agg(column: &ArrayRef, agg: Agg) -> DeltaResult> { + match column.data_type() { + // Nested struct - recurse into children + DataType::Struct(fields) => agg_struct(column, fields, agg), + + // Integer types + DataType::Int8 => agg_primitive::(column, agg), + DataType::Int16 => agg_primitive::(column, agg), + DataType::Int32 => agg_primitive::(column, agg), + DataType::Int64 => agg_primitive::(column, agg), + DataType::UInt8 => agg_primitive::(column, agg), + DataType::UInt16 => agg_primitive::(column, agg), + DataType::UInt32 => agg_primitive::(column, agg), + DataType::UInt64 => agg_primitive::(column, agg), + + // Float types + DataType::Float32 => agg_primitive::(column, agg), + DataType::Float64 => agg_primitive::(column, agg), + + // Date types + DataType::Date32 => agg_primitive::(column, agg), + DataType::Date64 => agg_primitive::(column, agg), + + // Timestamp types (preserve timezone) + DataType::Timestamp(TimeUnit::Second, tz) => { + agg_timestamp::(column, tz.clone(), agg) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + agg_timestamp::(column, tz.clone(), agg) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + agg_timestamp::(column, tz.clone(), agg) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + agg_timestamp::(column, tz.clone(), agg) + } + + // Decimal type (preserve precision/scale) + DataType::Decimal128(p, s) => agg_decimal(column, *p, *s, agg), + + // String types (with truncation) + DataType::Utf8 => agg_string(column, agg), + DataType::LargeUtf8 => agg_large_string(column, agg), + DataType::Utf8View => agg_string_view(column, agg), + + // Unsupported types return no min/max + _ => Ok(None), + } +} + +// ============================================================================ +// Null count computation +// ============================================================================ + +/// Compute null count for a column, handling nested structs recursively. +fn compute_null_count(column: &ArrayRef) -> DeltaResult { + match column.data_type() { + DataType::Struct(fields) => { + let struct_array = downcast::(column)?; + let children: Vec = (0..fields.len()) + .map(|i| compute_null_count(struct_array.column(i))) + .collect::>>()?; + let null_count_fields: Fields = fields + .iter() + .map(|f| Field::new(f.name(), null_count_data_type(f.data_type()), true)) + .collect(); + Ok(Arc::new( + StructArray::try_new(null_count_fields, children, None) + .map_err(|e| Error::generic(format!("null count struct: {e}")))?, + )) + } + _ => Ok(Arc::new(Int64Array::from(vec![column.null_count() as i64]))), + } +} + +/// Get the data type for null counts of a given data type. 
+fn null_count_data_type(data_type: &DataType) -> DataType { + match data_type { + DataType::Struct(fields) => { + let null_count_fields: Vec = fields + .iter() + .map(|f| Field::new(f.name(), null_count_data_type(f.data_type()), true)) + .collect(); + DataType::Struct(null_count_fields.into()) + } + _ => DataType::Int64, + } +} + +// ============================================================================ +// Public API +// ============================================================================ + +/// Accumulates (field_name, array) pairs for building a stats struct. +struct StatsAccumulator { + name: &'static str, + fields: Vec, + arrays: Vec, +} + +impl StatsAccumulator { + fn new(name: &'static str) -> Self { + Self { + name, + fields: Vec::new(), + arrays: Vec::new(), + } + } + + fn push(&mut self, field_name: &str, array: ArrayRef) { + self.fields + .push(Field::new(field_name, array.data_type().clone(), true)); + self.arrays.push(array); + } + + fn build(self) -> DeltaResult)>> { + if self.fields.is_empty() { + return Ok(None); + } + let struct_arr = StructArray::try_new(self.fields.into(), self.arrays, None) + .map_err(|e| Error::generic(format!("Failed to create {}: {e}", self.name)))?; + let field = Field::new(self.name, struct_arr.data_type().clone(), true); + Ok(Some((field, Arc::new(struct_arr) as Arc))) + } +} + +/// Collect statistics from a RecordBatch for Delta Lake file statistics. +/// +/// Returns a StructArray with the following fields: +/// - `numRecords`: total row count +/// - `nullCount`: nested struct with null counts per column +/// - `minValues`: nested struct with min values per column +/// - `maxValues`: nested struct with max values per column +/// - `tightBounds`: always true for new file writes +/// +/// # Arguments +/// * `batch` - The RecordBatch to collect statistics from +/// * `stats_columns` - Column names that should have statistics collected (allowlist). +/// Only these columns will appear in nullCount/minValues/maxValues. +pub(crate) fn collect_stats( + batch: &RecordBatch, + stats_columns: &[String], +) -> DeltaResult { + let stats_set: HashSet<&str> = stats_columns.iter().map(|s| s.as_str()).collect(); + let schema = batch.schema(); + + // Collect all stats in a single pass over columns + let mut null_counts = StatsAccumulator::new("nullCount"); + let mut min_values = StatsAccumulator::new("minValues"); + let mut max_values = StatsAccumulator::new("maxValues"); + + for (col_idx, field) in schema.fields().iter().enumerate() { + if !stats_set.contains(field.name().as_str()) { + continue; + } + + let column = batch.column(col_idx); + null_counts.push(field.name(), compute_null_count(column)?); + + if let Some(arr) = compute_agg(column, Agg::Min)? { + min_values.push(field.name(), arr); + } + if let Some(arr) = compute_agg(column, Agg::Max)? { + max_values.push(field.name(), arr); + } + } + + // Build output struct + let mut fields = vec![Field::new("numRecords", DataType::Int64, true)]; + let mut arrays: Vec> = + vec![Arc::new(Int64Array::from(vec![batch.num_rows() as i64]))]; + + for acc in [null_counts, min_values, max_values] { + if let Some((field, array)) = acc.build()? 
{ + fields.push(field); + arrays.push(array); + } + } + + // tightBounds + fields.push(Field::new("tightBounds", DataType::Boolean, true)); + arrays.push(Arc::new(BooleanArray::from(vec![true]))); + + StructArray::try_new(fields.into(), arrays, None) + .map_err(|e| Error::generic(format!("Failed to create stats struct: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::array::{Array, Int64Array, StringArray}; + use crate::arrow::datatypes::Schema; + + #[test] + fn test_collect_stats_single_batch() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + + let batch = + RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap(); + + let stats = collect_stats(&batch, &["id".to_string()]).unwrap(); + + assert_eq!(stats.len(), 1); + let num_records = stats + .column_by_name("numRecords") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(num_records.value(0), 3); + } + + #[test] + fn test_collect_stats_null_counts() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])), + ], + ) + .unwrap(); + + let stats = collect_stats(&batch, &["id".to_string(), "value".to_string()]).unwrap(); + + // Check nullCount struct + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // id has 0 nulls + let id_null_count = null_count + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id_null_count.value(0), 0); + + // value has 1 null + let value_null_count = null_count + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_null_count.value(0), 1); + } + + #[test] + fn test_collect_stats_respects_stats_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])), + ], + ) + .unwrap(); + + // Only collect stats for "id", not "value" + let stats = collect_stats(&batch, &["id".to_string()]).unwrap(); + + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Only id should be present + assert!(null_count.column_by_name("id").is_some()); + assert!(null_count.column_by_name("value").is_none()); + } + + #[test] + fn test_collect_stats_min_max() { + let schema = Arc::new(Schema::new(vec![ + Field::new("number", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![5, 1, 9, 3])), + Arc::new(StringArray::from(vec![ + Some("banana"), + Some("apple"), + Some("cherry"), + None, + ])), + ], + ) + .unwrap(); + + let stats = collect_stats(&batch, &["number".to_string(), "name".to_string()]).unwrap(); + + // Check minValues + let min_values = stats + .column_by_name("minValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let number_min = min_values + .column_by_name("number") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(number_min.value(0), 1); + + let name_min = 
min_values + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_min.value(0), "apple"); + + // Check maxValues + let max_values = stats + .column_by_name("maxValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let number_max = max_values + .column_by_name("number") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(number_max.value(0), 9); + + let name_max = max_values + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_max.value(0), "cherry"); + } + + #[test] + fn test_collect_stats_all_nulls() { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int64Array::from(vec![ + None as Option, + None, + None, + ]))], + ) + .unwrap(); + + let stats = collect_stats(&batch, &["value".to_string()]).unwrap(); + + // numRecords should be 3 + let num_records = stats + .column_by_name("numRecords") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(num_records.value(0), 3); + + // nullCount should be 3 + let null_count = stats + .column_by_name("nullCount") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value_null_count = null_count + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_null_count.value(0), 3); + + // minValues/maxValues should not have "value" field (all nulls) + if let Some(min_values) = stats.column_by_name("minValues") { + let min_struct = min_values.as_any().downcast_ref::().unwrap(); + assert!(min_struct.column_by_name("value").is_none()); + } + } + + #[test] + fn test_collect_stats_empty_stats_columns() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + + let batch = + RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap(); + + // No stats columns requested + let stats = collect_stats(&batch, &[]).unwrap(); + + // Should still have numRecords and tightBounds + assert!(stats.column_by_name("numRecords").is_some()); + assert!(stats.column_by_name("tightBounds").is_some()); + + // Should not have nullCount, minValues, maxValues + assert!(stats.column_by_name("nullCount").is_none()); + assert!(stats.column_by_name("minValues").is_none()); + assert!(stats.column_by_name("maxValues").is_none()); + } + + #[test] + fn test_collect_stats_string_truncation() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + + // Create a string longer than 32 characters + let long_string = "a".repeat(50); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(StringArray::from(vec![long_string.as_str()]))], + ) + .unwrap(); + + let stats = collect_stats(&batch, &["text".to_string()]).unwrap(); + + let min_values = stats + .column_by_name("minValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let text_min = min_values + .column_by_name("text") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Should be truncated to 32 chars + assert_eq!(text_min.value(0).len(), 32); + } + + #[test] + fn test_collect_stats_nested_struct() { + // Schema: { nested: { a: int64, b: string } } + let nested_fields = Fields::from(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, true), + ]); + let schema = Arc::new(Schema::new(vec![Field::new( + "nested", + DataType::Struct(nested_fields.clone()), + false, + )])); + + // Build 
nested struct data + let a_array = Arc::new(Int64Array::from(vec![10, 5, 20])); + let b_array = Arc::new(StringArray::from(vec![Some("zebra"), Some("apple"), None])); + let nested_struct = StructArray::try_new( + nested_fields, + vec![a_array as ArrayRef, b_array as ArrayRef], + None, + ) + .unwrap(); + + let batch = + RecordBatch::try_new(schema, vec![Arc::new(nested_struct) as ArrayRef]).unwrap(); + + let stats = collect_stats(&batch, &["nested".to_string()]).unwrap(); + + // Check minValues.nested.a = 5, minValues.nested.b = "apple" + let min_values = stats + .column_by_name("minValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let nested_min = min_values + .column_by_name("nested") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let a_min = nested_min + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a_min.value(0), 5); + + let b_min = nested_min + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b_min.value(0), "apple"); + + // Check maxValues.nested.a = 20, maxValues.nested.b = "zebra" + let max_values = stats + .column_by_name("maxValues") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let nested_max = max_values + .column_by_name("nested") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let a_max = nested_max + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a_max.value(0), 20); + + let b_max = nested_max + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b_max.value(0), "zebra"); + } +} diff --git a/kernel/src/engine/sync/parquet.rs b/kernel/src/engine/sync/parquet.rs index 557ae9b226..1c0db5d72a 100644 --- a/kernel/src/engine/sync/parquet.rs +++ b/kernel/src/engine/sync/parquet.rs @@ -79,6 +79,7 @@ impl ParquetHandler for SyncParquetHandler { /// - `location` - The full URL path where the Parquet file should be written /// (e.g., `file:///path/to/file.parquet`). /// - `data` - An iterator of engine data to be written to the Parquet file. + /// - `stats_columns` - Optional column names for which statistics should be collected. 
/// /// # Returns /// @@ -87,6 +88,7 @@ impl ParquetHandler for SyncParquetHandler { &self, location: Url, mut data: Box>> + Send>, + _stats_columns: Option<&[String]>, ) -> DeltaResult<()> { // Convert URL to file path let path = location @@ -115,6 +117,7 @@ impl ParquetHandler for SyncParquetHandler { writer.close()?; // writer must be closed to write footer + // TODO: Implement stats collection for SyncEngine Ok(()) } @@ -174,7 +177,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data))); // Write the file - handler.write_parquet_file(url.clone(), data_iter).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter, None) + .unwrap(); // Verify the file exists assert!(file_path.exists()); @@ -295,7 +300,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data))); // Write the file - handler.write_parquet_file(url.clone(), data_iter).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter, None) + .unwrap(); // Verify the file exists assert!(file_path.exists()); @@ -370,7 +377,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data1))); // Write the first file - handler.write_parquet_file(url.clone(), data_iter1).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter1, None) + .unwrap(); assert!(file_path.exists()); // Create second data set with different data @@ -386,7 +395,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data2))); // Overwrite with second file (overwrite=true) - handler.write_parquet_file(url.clone(), data_iter2).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter2, None) + .unwrap(); // Read back and verify it contains the second data set let file = File::open(&file_path).unwrap(); @@ -445,7 +456,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data1))); // Write the first file - handler.write_parquet_file(url.clone(), data_iter1).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter1, None) + .unwrap(); assert!(file_path.exists()); // Create second data set @@ -461,7 +474,9 @@ mod tests { > = Box::new(std::iter::once(Ok(engine_data2))); // Write again - should overwrite successfully (new behavior always overwrites) - handler.write_parquet_file(url.clone(), data_iter2).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter2, None) + .unwrap(); // Verify the file was overwritten with the new data let file = File::open(&file_path).unwrap(); @@ -537,7 +552,9 @@ mod tests { > = Box::new(batches.into_iter()); // Write the file - handler.write_parquet_file(url.clone(), data_iter).unwrap(); + handler + .write_parquet_file(url.clone(), data_iter, None) + .unwrap(); // Verify the file exists assert!(file_path.exists()); diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index d27a1c915a..72a0f8dd0a 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -778,9 +778,10 @@ pub trait ParquetHandler: AsAny { predicate: Option, ) -> DeltaResult; - /// Write data to a Parquet file at the specified URL. + /// Write data to a Parquet file at the specified URL, collecting statistics. /// - /// This method writes the provided `data` to a Parquet file at the given `url`. + /// This method writes the provided `data` to a Parquet file at the given `url`, + /// and collects statistics (min, max, null count) for the specified columns. /// /// This will overwrite the file if it already exists. /// @@ -789,6 +790,7 @@ pub trait ParquetHandler: AsAny { /// - `url` - The full URL path where the Parquet file should be written /// (e.g., `s3://bucket/path/file.parquet`). 
/// - `data` - An iterator of engine data to be written to the Parquet file. + /// - `stats_columns` - Optional column names for which statistics should be collected. /// /// # Returns /// @@ -797,6 +799,7 @@ pub trait ParquetHandler: AsAny { &self, location: url::Url, data: Box>> + Send>, + stats_columns: Option<&[String]>, ) -> DeltaResult<()>; /// Read the footer metadata from a Parquet file without reading the data. diff --git a/kernel/src/scan/data_skipping/stats_schema.rs b/kernel/src/scan/data_skipping/stats_schema.rs index e18423b467..ff4fde8b69 100644 --- a/kernel/src/scan/data_skipping/stats_schema.rs +++ b/kernel/src/scan/data_skipping/stats_schema.rs @@ -133,6 +133,120 @@ pub(crate) fn expected_stats_schema( StructType::try_new(fields) } +/// Returns the list of column names that should have statistics collected. +/// +/// This extracts just the column names without building the full stats schema, +/// making it more efficient when only the column list is needed. +#[allow(unused)] +pub(crate) fn stats_column_names( + physical_file_schema: &Schema, + table_properties: &TableProperties, +) -> Vec { + let mut filter = StatsColumnFilter::new(table_properties); + let mut columns = Vec::new(); + filter.collect_columns(physical_file_schema, &mut columns); + columns +} + +/// Handles column filtering logic for statistics based on table properties. +/// +/// Filters columns according to: +/// * `dataSkippingStatsColumns` - explicit list of columns to include (takes precedence) +/// * `dataSkippingNumIndexedCols` - number of leaf columns to include (default 32) +struct StatsColumnFilter { + n_columns: Option, + added_columns: u64, + column_names: Option>, + path: Vec, +} + +impl StatsColumnFilter { + fn new(props: &TableProperties) -> Self { + // If data_skipping_stats_columns is specified, it takes precedence + // over data_skipping_num_indexed_cols, even if that is also specified. + if let Some(column_names) = &props.data_skipping_stats_columns { + Self { + n_columns: None, + added_columns: 0, + column_names: Some(column_names.clone()), + path: Vec::new(), + } + } else { + let n_cols = props + .data_skipping_num_indexed_cols + .unwrap_or(DataSkippingNumIndexedCols::NumColumns(32)); + Self { + n_columns: Some(n_cols), + added_columns: 0, + column_names: None, + path: Vec::new(), + } + } + } + + /// Collects column names that should have statistics. + fn collect_columns(&mut self, schema: &Schema, result: &mut Vec) { + for field in schema.fields() { + self.collect_field(field, result); + } + } + + fn collect_field(&mut self, field: &StructField, result: &mut Vec) { + if self.at_column_limit() { + return; + } + + self.path.push(field.name.clone()); + + match field.data_type() { + DataType::Struct(struct_type) => { + for child in struct_type.fields() { + self.collect_field(child, result); + } + } + _ => { + if self.should_include_current() { + result.push(ColumnName::new(&self.path)); + self.added_columns += 1; + } + } + } + + self.path.pop(); + } + + /// Returns true if the column limit has been reached. + fn at_column_limit(&self) -> bool { + matches!( + self.n_columns, + Some(DataSkippingNumIndexedCols::NumColumns(n)) if self.added_columns >= n + ) + } + + /// Returns true if the current path should be included based on column_names config. + fn should_include_current(&self) -> bool { + self.column_names + .as_ref() + .map(|ns| should_include_column(&ColumnName::new(&self.path), ns)) + .unwrap_or(true) + } + + /// Enters a field path for filtering decisions. 
+ fn enter_field(&mut self, name: &str) { + self.path.push(name.to_string()); + } + + /// Exits the current field path. + fn exit_field(&mut self) { + self.path.pop(); + } + + /// Records that a leaf column was included. + fn record_included(&mut self) { + self.added_columns += 1; + } +} + /// Transforms a schema to make all fields nullable. /// Used for stats schemas where stats may not be available for all columns. pub(crate) struct NullableStatsTransform; @@ -178,44 +292,19 @@ impl<'a> SchemaTransform<'a> for NullCountStatsTransform { /// Base stats schema in this case refers the subsets of fields in the table schema /// that may be considered for stats collection. Depending on the type of stats - min/max/nullcount/... - /// additional transformations may be applied. +/// Transforms a schema to filter columns for statistics based on table properties. /// -/// The concrete shape of the schema depends on the table configuration. -/// * `dataSkippingStatsColumns` - used to explicitly specify the columns -/// to be used for data skipping statistics. (takes precedence) -/// * `dataSkippingNumIndexedCols` - used to specify the number of columns -/// to be used for data skipping statistics. Defaults to 32. -/// -/// All fields are nullable. +/// All fields in the output are nullable. #[allow(unused)] struct BaseStatsTransform { - n_columns: Option, - added_columns: u64, - column_names: Option>, - path: Vec, + filter: StatsColumnFilter, } impl BaseStatsTransform { #[allow(unused)] fn new(props: &TableProperties) -> Self { - // If data_skipping_stats_columns is specified, it takes precedence - // over data_skipping_num_indexed_cols, even if that is also specified. - if let Some(column_names) = &props.data_skipping_stats_columns { - Self { - n_columns: None, - added_columns: 0, - column_names: Some(column_names.clone()), - path: Vec::new(), - } - } else { - let n_cols = props - .data_skipping_num_indexed_cols - .unwrap_or(DataSkippingNumIndexedCols::NumColumns(32)); - Self { - n_columns: Some(n_cols), - added_columns: 0, - column_names: None, - path: Vec::new(), - } + Self { + filter: StatsColumnFilter::new(props), } } } @@ -224,34 +313,22 @@ impl<'a> SchemaTransform<'a> for BaseStatsTransform { fn transform_struct_field(&mut self, field: &'a StructField) -> Option> { use Cow::*; - // Check if the number of columns is set and if the added columns exceed the limit - // In the constructor we assert this will always be None if column_names are specified - if let Some(DataSkippingNumIndexedCols::NumColumns(n_cols)) = self.n_columns { - if self.added_columns >= n_cols { - return None; - } + if self.filter.at_column_limit() { + return None; } - self.path.push(field.name.clone()); + self.filter.enter_field(field.name()); let data_type = field.data_type(); // We always traverse struct fields (they don't count against the column limit), // but we only include leaf fields if they qualify based on column_names config. // When column_names is None, all leaf fields are included (up to n_columns limit). if !matches!(data_type, DataType::Struct(_)) { - let should_include = self - .column_names - .as_ref() - .map(|ns| should_include_column(&ColumnName::new(&self.path), ns)) - .unwrap_or(true); - - if !should_include { - self.path.pop(); + if !self.filter.should_include_current() { + self.filter.exit_field(); return None; } - - // Increment count only for leaf columns - self.added_columns += 1; + self.filter.record_included(); } let field = match self.transform(&field.data_type)? 
{ @@ -264,7 +341,7 @@ impl<'a> SchemaTransform<'a> for BaseStatsTransform { }), }; - self.path.pop(); + self.filter.exit_field(); // exclude struct fields with no children if matches!(field.data_type(), DataType::Struct(dt) if dt.fields().len() == 0) { diff --git a/kernel/src/snapshot.rs b/kernel/src/snapshot.rs index 195b64f31d..6bb9669a74 100644 --- a/kernel/src/snapshot.rs +++ b/kernel/src/snapshot.rs @@ -441,9 +441,11 @@ impl Snapshot { let data_iter = writer.checkpoint_data(engine)?; let state = data_iter.state(); let lazy_data = data_iter.map(|r| r.and_then(|f| f.apply_selection_vector())); - engine - .parquet_handler() - .write_parquet_file(checkpoint_path.clone(), Box::new(lazy_data))?; + engine.parquet_handler().write_parquet_file( + checkpoint_path.clone(), + Box::new(lazy_data), + None, + )?; let file_meta = engine.storage_handler().head(&checkpoint_path)?; diff --git a/kernel/src/table_configuration.rs b/kernel/src/table_configuration.rs index 9080447084..cbbc8f6738 100644 --- a/kernel/src/table_configuration.rs +++ b/kernel/src/table_configuration.rs @@ -13,7 +13,8 @@ use std::sync::Arc; use url::Url; use crate::actions::{Metadata, Protocol}; -use crate::scan::data_skipping::stats_schema::expected_stats_schema; +use crate::expressions::ColumnName; +use crate::scan::data_skipping::stats_schema::{expected_stats_schema, stats_column_names}; use crate::schema::variant_utils::validate_variant_type_feature_support; use crate::schema::{InvariantChecker, SchemaRef, StructType}; use crate::table_features::{ @@ -181,21 +182,38 @@ impl TableConfiguration { #[allow(unused)] #[internal_api] pub(crate) fn expected_stats_schema(&self) -> DeltaResult { + let physical_schema = self.physical_data_schema(); + Ok(Arc::new(expected_stats_schema( + &physical_schema, + self.table_properties(), + )?)) + } + + /// Returns the list of column names that should have statistics collected. + /// + /// This returns the leaf column paths as a flat list of column names + /// (e.g., `["id", "nested.field"]`). + #[allow(unused)] + #[internal_api] + pub(crate) fn stats_column_names(&self) -> Vec { + let physical_schema = self.physical_data_schema(); + stats_column_names(&physical_schema, self.table_properties()) + } + + /// Returns the physical schema for data columns (excludes partition columns). + /// + /// Partition columns are excluded because statistics are only collected for data columns + /// that are physically stored in the parquet files. Partition values are stored in the + /// file path, not in the file content, so they don't have file-level statistics. + fn physical_data_schema(&self) -> StructType { let partition_columns = self.metadata().partition_columns(); let column_mapping_mode = self.column_mapping_mode(); - // Partition columns are excluded because statistics are only collected for data columns - // that are physically stored in the parquet files. Partition values are stored in the - // file path, not in the file content, so they don't have file-level statistics. - let physical_schema = StructType::try_new( + StructType::new_unchecked( self.schema() .fields() .filter(|field| !partition_columns.contains(field.name())) .map(|field| field.make_physical(column_mapping_mode)), - )?; - Ok(Arc::new(expected_stats_schema( - &physical_schema, - self.table_properties(), - )?)) + ) } /// The [`Metadata`] for this table at this version. 
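Taken together, the changes above and in the transaction module below give an engine a path from table configuration to per-file statistics: the write context exposes the selected stats columns, and the default parquet handler collects and embeds stats for exactly those columns. A minimal sketch of that flow follows; the helper function, its import paths, and the `TaskExecutor` bound are assumptions based on the modules touched in this diff, not part of the patch itself.

use std::collections::HashMap;

use delta_kernel::engine::default::executor::TaskExecutor;
use delta_kernel::engine::default::parquet::DefaultParquetHandler;
use delta_kernel::transaction::WriteContext;
use delta_kernel::{DeltaResult, EngineData};

// Hypothetical engine-side helper: write one file and get back add-file metadata
// whose `stats` column covers exactly the columns selected by the table's
// dataSkippingStatsColumns / dataSkippingNumIndexedCols configuration.
async fn write_file_with_stats<E: TaskExecutor>(
    handler: &DefaultParquetHandler<E>,
    ctx: &WriteContext,
    physical_data: Box<dyn EngineData>,
    partition_values: HashMap<String, String>,
) -> DeltaResult<Box<dyn EngineData>> {
    handler
        .write_parquet_file(
            ctx.target_dir(),
            physical_data,
            partition_values,
            // `Some(..)` enables min/max/null-count collection for these columns;
            // `None` falls back to numRecords/tightBounds only.
            Some(ctx.stats_columns()),
        )
        .await
}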
diff --git a/kernel/src/transaction/mod.rs b/kernel/src/transaction/mod.rs
index a3041b1893..77be83585b 100644
--- a/kernel/src/transaction/mod.rs
+++ b/kernel/src/transaction/mod.rs
@@ -790,6 +790,45 @@ impl Transaction {
         &BASE_ADD_FILES_SCHEMA
     }
 
+    /// Returns the expected schema for file statistics.
+    ///
+    /// The schema structure is derived from table configuration:
+    /// - `delta.dataSkippingStatsColumns`: Explicit column list (if set)
+    /// - `delta.dataSkippingNumIndexedCols`: Column count limit (default 32)
+    /// - Partition columns: Always excluded
+    ///
+    /// The returned schema has the following structure:
+    /// ```ignore
+    /// {
+    ///     numRecords: long,
+    ///     nullCount: { ... },   // Nested struct mirroring data schema, all fields LONG
+    ///     minValues: { ... },   // Nested struct, only min/max eligible types
+    ///     maxValues: { ... },   // Nested struct, only min/max eligible types
+    ///     tightBounds: boolean,
+    /// }
+    /// ```
+    ///
+    /// Engines should collect statistics matching this schema structure when writing files.
+    #[allow(unused)]
+    pub fn stats_schema(&self) -> DeltaResult<SchemaRef> {
+        self.read_snapshot
+            .table_configuration()
+            .expected_stats_schema()
+    }
+
+    /// Returns the list of column names that should have statistics collected.
+    ///
+    /// This returns the leaf column paths as a flat list of column names
+    /// (e.g., `["id", "nested.field"]`).
+    ///
+    /// Engines can use this to determine which columns need stats during writes.
+    #[allow(unused)]
+    pub fn stats_columns(&self) -> Vec<ColumnName> {
+        self.read_snapshot
+            .table_configuration()
+            .stats_column_names()
+    }
+
     // Generate the logical-to-physical transform expression which must be evaluated on every data
     // chunk before writing. At the moment, this is a transaction-wide expression.
     fn generate_logical_to_physical(&self) -> Expression {
@@ -820,6 +859,8 @@ impl Transaction {
     // Note: after we introduce metadata updates (modify table schema, etc.), we need to make sure
     // that engines cannot call this method after a metadata change, since the write context could
     // have invalid metadata.
+    // Note: Callers that use get_write_context may be writing data to the table, and they might
+    // have invalid metadata.
     pub fn get_write_context(&self) -> WriteContext {
         let target_dir = self.read_snapshot.table_root();
         let snapshot_schema = self.read_snapshot.schema();
@@ -837,11 +878,19 @@ impl Transaction {
             .cloned();
         let physical_schema = Arc::new(StructType::new_unchecked(physical_fields));
 
+        // Get stats columns from table configuration
+        let stats_columns = self
+            .stats_columns()
+            .into_iter()
+            .map(|c| c.to_string())
+            .collect();
+
         WriteContext::new(
             target_dir.clone(),
             snapshot_schema,
             physical_schema,
             Arc::new(logical_to_physical),
+            stats_columns,
         )
     }
 
@@ -1313,6 +1362,8 @@ pub struct WriteContext {
     logical_schema: SchemaRef,
     physical_schema: SchemaRef,
     logical_to_physical: ExpressionRef,
+    /// Column names that should have statistics collected during writes.
+    stats_columns: Vec<String>,
 }
 
 impl WriteContext {
@@ -1321,12 +1372,14 @@ impl WriteContext {
         logical_schema: SchemaRef,
         physical_schema: SchemaRef,
         logical_to_physical: ExpressionRef,
+        stats_columns: Vec<String>,
     ) -> Self {
         WriteContext {
             target_dir,
             logical_schema,
             physical_schema,
             logical_to_physical,
+            stats_columns,
         }
     }
 
@@ -1346,6 +1399,13 @@ impl WriteContext {
         self.logical_to_physical.clone()
     }
 
+    /// Returns the column names that should have statistics collected during writes.
+    ///
+    /// Based on table configuration (dataSkippingNumIndexedCols, dataSkippingStatsColumns).
+ pub fn stats_columns(&self) -> &[String] { + &self.stats_columns + } + /// Generate a new unique absolute URL for a deletion vector file. /// /// This method generates a unique file name in the table directory. diff --git a/kernel/tests/write.rs b/kernel/tests/write.rs index f47db4b92e..66a3c4cef8 100644 --- a/kernel/tests/write.rs +++ b/kernel/tests/write.rs @@ -393,7 +393,7 @@ async fn test_append() -> Result<(), Box> { "size": size, "modificationTime": 0, "dataChange": true, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":1},\"maxValues\":{\"number\":3},\"tightBounds\":true}" } }), json!({ @@ -403,7 +403,7 @@ async fn test_append() -> Result<(), Box> { "size": size, "modificationTime": 0, "dataChange": true, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":4},\"maxValues\":{\"number\":6},\"tightBounds\":true}" } }), ]; @@ -601,7 +601,7 @@ async fn test_append_partitioned() -> Result<(), Box> { "size": size, "modificationTime": 0, "dataChange": false, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":1},\"maxValues\":{\"number\":3},\"tightBounds\":true}" } }), json!({ @@ -613,7 +613,7 @@ async fn test_append_partitioned() -> Result<(), Box> { "size": size, "modificationTime": 0, "dataChange": false, - "stats": "{\"numRecords\":3}" + "stats": "{\"numRecords\":3,\"nullCount\":{\"number\":0},\"minValues\":{\"number\":4},\"maxValues\":{\"number\":6},\"tightBounds\":true}" } }), ]; @@ -1078,6 +1078,7 @@ async fn test_append_variant() -> Result<(), Box> { write_context.target_dir(), Box::new(ArrowEngineData::new(data.clone())), HashMap::new(), + Some(write_context.stats_columns()), ) .await?; @@ -1251,6 +1252,7 @@ async fn test_shredded_variant_read_rejection() -> Result<(), Box