diff --git a/kernel/src/engine/default/parquet.rs b/kernel/src/engine/default/parquet.rs
index 8636b3d9f8..88dc9b72ab 100644
--- a/kernel/src/engine/default/parquet.rs
+++ b/kernel/src/engine/default/parquet.rs
@@ -6,6 +6,8 @@ use std::sync::Arc;
 
 use crate::arrow::array::builder::{MapBuilder, MapFieldNames, StringBuilder};
 use crate::arrow::array::{BooleanArray, Int64Array, RecordBatch, StringArray};
+#[cfg(test)]
+use crate::engine_data::FilteredEngineData;
 use crate::parquet::arrow::arrow_reader::{
     ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder,
 };
@@ -110,40 +112,47 @@ impl DefaultParquetHandler {
         self
     }
 
-    // Write `data` to `{path}/<uuid>.parquet` as parquet using ArrowWriter and return the parquet
-    // metadata (where `<uuid>` is a generated UUIDv4).
+    // Writes an iterator of `EngineData` batches to a parquet file at the specified path and
+    // returns the parquet metadata as `DataFileMetadata`.
     //
-    // Note: after encoding the data as parquet, this issues a PUT followed by a HEAD to storage in
+    // Notes:
+    // 1. After encoding the data as parquet, this issues a PUT followed by a HEAD to storage in
     // order to obtain metadata about the object just written.
-    async fn write_parquet(
+    // 2. The schema of all batches must be the same.
+    async fn write_parquet_from_batches(
         &self,
-        path: &url::Url,
-        data: Box<dyn EngineData>,
+        path: url::Url,
+        mut data_iter: impl Iterator<Item = DeltaResult<Box<dyn EngineData>>>,
     ) -> DeltaResult<DataFileMetadata> {
-        let batch: Box<_> = ArrowEngineData::try_from_engine_data(data)?;
-        let record_batch = batch.record_batch();
-
+        // Get the first batch to extract the schema
+        let first_data = data_iter
+            .next()
+            .ok_or_else(|| Error::generic("No data to write"))??;
+        let arrow_engine_data = ArrowEngineData::try_from_engine_data(first_data)?;
+        let first_batch = arrow_engine_data.record_batch();
+        let schema = first_batch.schema();
+
+        // Set up buffer and writer with default properties
         let mut buffer = vec![];
-        let mut writer = ArrowWriter::try_new(&mut buffer, record_batch.schema(), None)?;
-        writer.write(record_batch)?;
+        let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)?;
+
+        // Write the batches to the buffer
+        writer.write(first_batch)?;
+        for batch_result in data_iter {
+            let data = batch_result?;
+            let batch = ArrowEngineData::try_from_engine_data(data)?;
+            let record_batch = batch.record_batch();
+            writer.write(record_batch)?;
+        }
+
         writer.close()?; // writer must be closed to write footer
 
         let size = buffer.len();
-        let name: String = format!("{}.parquet", Uuid::new_v4());
-        // fail if path does not end with a trailing slash
-        if !path.path().ends_with('/') {
-            return Err(Error::generic(format!(
-                "Path must end with a trailing slash: {}",
-                path
-            )));
-        }
-        let path = path.join(&name)?;
+        let file_path = Path::from(path.path());
+        self.store.put(&file_path, buffer.into()).await?;
 
-        self.store
-            .put(&Path::from(path.path()), buffer.into())
-            .await?;
-
-        let metadata = self.store.head(&Path::from(path.path())).await?;
+        // Get the metadata of the written file
+        let metadata = self.store.head(&file_path).await?;
         let modification_time = metadata.last_modified.timestamp_millis();
         if size != metadata.size {
             return Err(Error::generic(format!(
@@ -156,6 +165,29 @@ impl DefaultParquetHandler {
         Ok(DataFileMetadata::new(file_meta))
     }
 
+    // Write `data` to `{path}/<uuid>.parquet` as parquet using ArrowWriter and return the parquet
+    // metadata (where `<uuid>` is a generated UUIDv4).
+    //
+    // This function is a convenience wrapper around `write_parquet_from_batches` that creates a
+    // unique filename for the parquet file. The path must end with a trailing slash.
+    async fn write_parquet(
+        &self,
+        path: &url::Url,
+        data: Box<dyn EngineData>,
+    ) -> DeltaResult<DataFileMetadata> {
+        // fail if path does not end with a trailing slash
+        if !path.path().ends_with('/') {
+            return Err(Error::generic(format!(
+                "Path must end with a trailing slash: {}",
+                path
+            )));
+        }
+        let name: String = format!("{}.parquet", Uuid::new_v4());
+        let path = path.join(&name)?;
+        self.write_parquet_from_batches(path, std::iter::once(Ok(data)))
+            .await
+    }
+
     /// Write `data` to `{path}/<uuid>.parquet` as parquet using ArrowWriter and return the parquet
     /// metadata as an EngineData batch which matches the [write metadata] schema (where `<uuid>` is
     /// a generated UUIDv4).
@@ -171,6 +203,59 @@ impl DefaultParquetHandler {
         let parquet_metadata = self.write_parquet(path, data).await?;
         parquet_metadata.as_record_batch(&partition_values, data_change)
     }
+
+    /// Writes filtered data to a parquet file at the specified path.
+    ///
+    /// This function takes an iterator of [`FilteredEngineData`] objects, each containing a batch of
+    /// [`EngineData`] and a selection vector indicating which rows to keep. It filters each batch
+    /// according to its selection vector and writes all retained rows to a single parquet file at the
+    /// specified path.
+    ///
+    /// Note: Each selection vector's length must match the number of rows in its batch.
+    #[cfg(test)]
+    async fn write_parquet_from_filtered_batches(
+        &self,
+        path: url::Url,
+        filtered_data: impl Iterator<Item = FilteredEngineData>,
+    ) -> DeltaResult<DataFileMetadata> {
+        use arrow_53::compute::filter;
+        use itertools::Itertools;
+
+        // Process each `FilteredEngineData` item, transforming each into filtered `EngineData`
+        let data = filtered_data.map(|batch| -> DeltaResult<Box<dyn EngineData>> {
+            let arrow_engine_data = ArrowEngineData::try_from_engine_data(batch.data)?;
+            let record_batch = arrow_engine_data.record_batch();
+
+            // Check that the selection vector length matches the number of rows
+            if batch.selection_vector.len() != record_batch.num_rows() {
+                return Err(Error::generic(format!(
+                    "Mask length ({}) doesn't match number of rows ({})",
+                    batch.selection_vector.len(),
+                    record_batch.num_rows()
+                )));
+            }
+
+            // Create arrow array from selection vector
+            let selection_array = Arc::new(BooleanArray::from(batch.selection_vector));
+
+            // Filter each column in the record batch using the selection arrow array
+            let filtered_columns: Vec<_> = record_batch
+                .columns()
+                .iter()
+                .map(|col| {
+                    filter(col.as_ref(), &selection_array)
+                        .map_err(|e| Error::generic(format!("Error filtering column: {e}")))
+                })
+                .try_collect()?;
+
+            // Create filtered record batch and convert back to EngineData
+            let filtered_batch = RecordBatch::try_new(record_batch.schema(), filtered_columns)?;
+            Ok(Box::new(ArrowEngineData::new(filtered_batch)))
+        });
+
+        // Write the filtered data to parquet
+        self.write_parquet_from_batches(path, data).await
+    }
 }
 
 impl ParquetHandler for DefaultParquetHandler {
@@ -383,6 +468,16 @@ mod tests {
             .map(Into::into)
     }
 
+    fn make_batch(ids: Vec<i64>, names: Vec<&str>) -> Box<dyn EngineData> {
+        Box::new(ArrowEngineData::new(
+            RecordBatch::try_from_iter(vec![
+                ("id", Arc::new(Int64Array::from(ids)) as Arc<dyn Array>),
+                ("name", Arc::new(StringArray::from(names)) as Arc<dyn Array>),
+            ])
+            .unwrap(),
+        ))
+    }
+
     #[tokio::test]
     async fn test_read_parquet_files() {
         let store = Arc::new(LocalFileSystem::new());
@@ -475,13 +570,7 @@
         let parquet_handler =
             DefaultParquetHandler::new(store.clone(), Arc::new(TokioBackgroundExecutor::new()));
 
-        let data = Box::new(ArrowEngineData::new(
-            RecordBatch::try_from_iter(vec![(
-                "a",
-                Arc::new(Int64Array::from(vec![1, 2, 3])) as Arc<dyn Array>,
-            )])
-            .unwrap(),
-        ));
+        let data = make_batch(vec![1, 2, 3], vec!["a", "b", "c"]);
 
         let write_metadata = parquet_handler
             .write_parquet(&Url::parse("memory:///data/").unwrap(), data)
@@ -559,4 +648,70 @@
             .await
             .is_err());
     }
+
+    #[tokio::test]
+    async fn test_write_multiple_filtered_parquet() -> DeltaResult<()> {
+        let store = Arc::new(InMemory::new());
+        let parquet_handler =
+            DefaultParquetHandler::new(store.clone(), Arc::new(TokioBackgroundExecutor::new()));
+
+        let filtered_data = vec![
+            FilteredEngineData::new(
+                make_batch(vec![1, 2, 3], vec!["a", "b", "c"]),
+                vec![true, false, true], // Keep rows 0 and 2 from batch1
+            ),
+            FilteredEngineData::new(
+                make_batch(vec![4, 5, 6], vec!["d", "e", "f"]),
+                vec![false, true, false], // Keep row 1 from batch2
+            ),
+        ];
+
+        // Write the filtered data
+        let file_meta = parquet_handler
+            .write_parquet_from_filtered_batches(
+                Url::parse("memory:///multiple_filtered_data/")?,
+                filtered_data.into_iter(),
+            )
+            .await?;
+
+        // Read back the written data
+        let schema = Arc::new(
+            RecordBatch::try_from_iter(vec![
+                ("id", Arc::new(Int64Array::from(vec![0])) as Arc<dyn Array>),
+                (
+                    "name",
+                    Arc::new(StringArray::from(vec![""])) as Arc<dyn Array>,
+                ),
+            ])?
+            .schema()
+            .try_into()?,
+        );
+
+        let data: Vec<RecordBatch> = parquet_handler
+            .read_parquet_files(&[file_meta.file_meta.clone()], schema, None)?
+            .map(into_record_batch)
+            .try_collect()?;
+
+        // Verify that we have the expected number of results
+        assert_eq!(data.len(), 1);
+        assert_eq!(data[0].num_rows(), 3);
+
+        // Get the values from the returned arrays
+        let ids = data[0].column(0).as_any().downcast_ref::<Int64Array>();
+        let names = data[0].column(1).as_any().downcast_ref::<StringArray>();
+
+        // Verify the combined values
+        assert_eq!(
+            ids.unwrap().values(),
+            &[1, 3, 5],
+            "Should have the expected ID values"
+        );
+        assert_eq!(
+            names.unwrap().iter().collect::<Vec<_>>(),
+            vec![Some("a"), Some("c"), Some("e")],
+            "Should have the expected name values"
+        );
+
+        Ok(())
+    }
 }
diff --git a/kernel/src/engine_data.rs b/kernel/src/engine_data.rs
index 44ada91e78..f0618350ac 100644
--- a/kernel/src/engine_data.rs
+++ b/kernel/src/engine_data.rs
@@ -21,6 +21,15 @@ pub struct FilteredEngineData {
     pub selection_vector: Vec<bool>,
 }
 
+impl FilteredEngineData {
+    #[cfg(test)]
+    pub(crate) fn new(data: Box<dyn EngineData>, selection_vector: Vec<bool>) -> Self {
+        FilteredEngineData {
+            data,
+            selection_vector,
+        }
+    }
+}
 impl HasSelectionVector for FilteredEngineData {
     /// Returns true if any row in the selection vector is marked as selected
     fn has_selected_rows(&self) -> bool {
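
For reference, the per-column filtering step in `write_parquet_from_filtered_batches` could equivalently be expressed with arrow's `filter_record_batch` kernel, which applies a single `BooleanArray` predicate to every column of a `RecordBatch` in one call. The standalone sketch below is illustrative only and is not part of the diff above; it assumes the plain `arrow` crate rather than the kernel's version-gated re-exports (`crate::arrow` / `arrow_53`), and the helper name `filter_batch` is invented for the example.

// Illustrative sketch only; not part of the patch. Assumes the plain `arrow` crate.
use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray};
use arrow::compute::filter_record_batch;
use arrow::error::ArrowError;

// Keep only the rows of `batch` whose entry in `selection` is true.
fn filter_batch(batch: &RecordBatch, selection: &[bool]) -> Result<RecordBatch, ArrowError> {
    // Mirror the length check performed before filtering in the change above.
    if selection.len() != batch.num_rows() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Mask length ({}) doesn't match number of rows ({})",
            selection.len(),
            batch.num_rows()
        )));
    }
    let predicate = BooleanArray::from(selection.to_vec());
    // Filters all columns of the batch against the predicate at once.
    filter_record_batch(batch, &predicate)
}

fn main() -> Result<(), ArrowError> {
    let batch = RecordBatch::try_from_iter(vec![
        ("id", Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef),
        ("name", Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef),
    ])?;
    // Keep rows 0 and 2, matching the first batch in the new test.
    let filtered = filter_batch(&batch, &[true, false, true])?;
    assert_eq!(filtered.num_rows(), 2);
    Ok(())
}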