
Commit d6e9c4e

feat: arrow convenience extensions
1 parent e74d18b commit d6e9c4e

7 files changed: +216 -22 lines changed


kernel/examples/read-table-single-threaded/src/main.rs

+2 -21

@@ -2,10 +2,8 @@ use std::collections::HashMap;
 use std::process::ExitCode;
 use std::sync::Arc;
 
-use arrow::compute::filter_record_batch;
-use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::print_batches;
-use delta_kernel::engine::arrow_data::ArrowEngineData;
+use delta_kernel::engine::arrow_extensions::ScanExt;
 use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor;
 use delta_kernel::engine::default::DefaultEngine;
 use delta_kernel::engine::sync::SyncEngine;
@@ -119,24 +117,7 @@ fn try_main() -> DeltaResult<()> {
         .with_schema_opt(read_schema_opt)
         .build()?;
 
-    let batches: Vec<RecordBatch> = scan
-        .execute(engine)?
-        .map(|scan_result| -> DeltaResult<_> {
-            let scan_result = scan_result?;
-            let mask = scan_result.full_mask();
-            let data = scan_result.raw_data?;
-            let record_batch: RecordBatch = data
-                .into_any()
-                .downcast::<ArrowEngineData>()
-                .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))?
-                .into();
-            if let Some(mask) = mask {
-                Ok(filter_record_batch(&record_batch, &mask.into())?)
-            } else {
-                Ok(record_batch)
-            }
-        })
-        .try_collect()?;
+    let batches: Vec<_> = scan.execute_arrow(engine)?.try_collect()?;
     print_batches(&batches)?;
     Ok(())
 }
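
The boilerplate deleted above (downcast to ArrowEngineData, then apply the row mask) is exactly what execute_arrow now performs internally via the TryFrom<ScanResult> impl added further down. As a usage sketch that is not part of this commit (the helper name print_scan_streaming is illustrative), the same iterator can also be consumed one batch at a time instead of being buffered with try_collect:

use std::sync::Arc;

use delta_kernel::arrow::util::pretty::print_batches;
use delta_kernel::engine::arrow_extensions::ScanExt;
use delta_kernel::scan::Scan;
use delta_kernel::{DeltaResult, Engine};

/// Print scan results one batch at a time rather than collecting them all
/// into a Vec first; each iterator item is a DeltaResult<RecordBatch>.
fn print_scan_streaming(scan: &Scan, engine: Arc<dyn Engine>) -> DeltaResult<()> {
    for batch in scan.execute_arrow(engine)? {
        let batch = batch?; // propagate per-batch errors
        print_batches(std::slice::from_ref(&batch))?;
    }
    Ok(())
}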

kernel/src/engine/arrow_extensions/evaluator.rs

+49

@@ -0,0 +1,49 @@
+use crate::arrow::record_batch::RecordBatch;
+
+use crate::{DeltaResult, ExpressionEvaluator};
+
+use super::super::arrow_data::ArrowEngineData;
+
+pub trait ExpressionEvaluatorExt {
+    fn evaluate_arrow(&self, batch: RecordBatch) -> DeltaResult<RecordBatch>;
+}
+
+impl<T: ExpressionEvaluator + ?Sized> ExpressionEvaluatorExt for T {
+    fn evaluate_arrow(&self, batch: RecordBatch) -> DeltaResult<RecordBatch> {
+        let engine_data = ArrowEngineData::new(batch);
+        Ok(ArrowEngineData::try_from_engine_data(T::evaluate(&self, &engine_data)?)?.into())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::ExpressionEvaluatorExt;
+
+    use crate::arrow::array::Int32Array;
+    use crate::arrow::datatypes::{DataType, Field, Schema};
+    use crate::arrow::record_batch::RecordBatch;
+    use crate::engine::arrow_expression::ArrowEvaluationHandler;
+    use crate::expressions::*;
+    use crate::EvaluationHandler;
+
+    #[test_log::test]
+    fn test_evaluate_arrow() {
+        let handler = ArrowEvaluationHandler;
+
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let values = Int32Array::from(vec![1, 2, 3]);
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
+
+        let expression = column_expr!("a");
+        let expr = handler.new_expression_evaluator(
+            Arc::new((&schema).try_into().unwrap()),
+            expression,
+            crate::schema::DataType::INTEGER,
+        );
+
+        let result = expr.evaluate_arrow(batch);
+        assert!(result.is_ok());
+    }
+}
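
For orientation, a sketch of using ExpressionEvaluatorExt from outside the crate. It is not part of this commit and simply mirrors the unit test above; it assumes ArrowEvaluationHandler, the column_expr! macro, and the schema conversion are reachable from a consuming crate at the same paths the test uses:

use std::sync::Arc;

use delta_kernel::arrow::array::Int32Array;
use delta_kernel::arrow::datatypes::{DataType, Field, Schema};
use delta_kernel::arrow::record_batch::RecordBatch;
use delta_kernel::engine::arrow_expression::ArrowEvaluationHandler;
use delta_kernel::engine::arrow_extensions::ExpressionEvaluatorExt;
use delta_kernel::expressions::*;
use delta_kernel::EvaluationHandler;

fn main() {
    // A one-column Arrow batch to evaluate against.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let values = Int32Array::from(vec![1, 2, 3]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();

    // Build an evaluator for the expression `a`, then run it directly on the
    // RecordBatch; evaluate_arrow wraps and unwraps ArrowEngineData internally.
    let handler = ArrowEvaluationHandler;
    let evaluator = handler.new_expression_evaluator(
        Arc::new((&schema).try_into().unwrap()),
        column_expr!("a"),
        delta_kernel::schema::DataType::INTEGER,
    );
    let result = evaluator.evaluate_arrow(batch).unwrap();
    assert_eq!(result.num_rows(), 3);
}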

+5

@@ -0,0 +1,5 @@
+mod evaluator;
+mod scan;
+
+pub use evaluator::*;
+pub use scan::*;

kernel/src/engine/arrow_extensions/scan.rs

+95

@@ -0,0 +1,95 @@
+use std::sync::Arc;
+
+use crate::arrow::array::BooleanArray;
+use crate::arrow::compute::filter_record_batch;
+use crate::arrow::record_batch::RecordBatch;
+use itertools::Itertools;
+
+use crate::scan::{Scan, ScanMetadata, ScanResult};
+use crate::{DeltaResult, Engine, Error, ExpressionRef};
+
+use super::super::arrow_data::ArrowEngineData;
+
+/// [`ScanMetadata`] contains (1) a [`RecordBatch`] specifying data files to be scanned
+/// and (2) a vector of transforms (one transform per scan file) that must be applied to the data read
+/// from those files.
+pub struct ScanMetadataArrow {
+    /// Record batch with one row per file to scan
+    pub scan_files: RecordBatch,
+
+    /// Row-level transformations to apply to data read from files.
+    ///
+    /// Each entry in this vector corresponds to a row in the `scan_files` data. The entry is an
+    /// expression that must be applied to convert the file's data into the logical schema
+    /// expected by the scan:
+    ///
+    /// - `Some(expr)`: Apply this expression to transform the data to match [`Scan::schema()`].
+    /// - `None`: No transformation is needed; the data is already in the correct logical form.
+    ///
+    /// Note: This vector can be indexed by row number.
+    pub scan_file_transforms: Vec<Option<ExpressionRef>>,
+}
+
+impl TryFrom<ScanMetadata> for ScanMetadataArrow {
+    type Error = Error;
+
+    fn try_from(metadata: ScanMetadata) -> Result<Self, Self::Error> {
+        let scan_file_transforms = metadata
+            .scan_file_transforms
+            .into_iter()
+            .enumerate()
+            .filter_map(|(i, v)| metadata.scan_files.selection_vector[i].then_some(v))
+            .collect();
+        let batch = ArrowEngineData::try_from_engine_data(metadata.scan_files.data)?.into();
+        let scan_files = filter_record_batch(
+            &batch,
+            &BooleanArray::from(metadata.scan_files.selection_vector),
+        )?;
+        Ok(ScanMetadataArrow {
+            scan_files,
+            scan_file_transforms,
+        })
+    }
+}
+
+impl TryFrom<ScanResult> for RecordBatch {
+    type Error = Error;
+
+    fn try_from(result: ScanResult) -> Result<Self, Self::Error> {
+        let (mask, data) = (result.full_mask(), result.raw_data?);
+        let record_batch = ArrowEngineData::try_from_engine_data(data)?.into();
+        mask.map(|m| Ok(filter_record_batch(&record_batch, &m.into())?))
+            .unwrap_or(Ok(record_batch))
+    }
+}
+
+pub trait ScanExt {
+    fn scan_metadata_arrow(
+        &self,
+        engine: &dyn Engine,
+    ) -> DeltaResult<impl Iterator<Item = DeltaResult<ScanMetadataArrow>>>;
+
+    fn execute_arrow(
+        &self,
+        engine: Arc<dyn Engine>,
+    ) -> DeltaResult<impl Iterator<Item = DeltaResult<RecordBatch>>>;
+}
+
+impl ScanExt for Scan {
+    fn scan_metadata_arrow(
+        &self,
+        engine: &dyn Engine,
+    ) -> DeltaResult<impl Iterator<Item = DeltaResult<ScanMetadataArrow>>> {
+        Ok(self
+            .scan_metadata(engine)?
+            .map_ok(TryFrom::try_from)
+            .flatten())
+    }
+
+    fn execute_arrow(
+        &self,
+        engine: Arc<dyn Engine>,
+    ) -> DeltaResult<impl Iterator<Item = DeltaResult<RecordBatch>>> {
+        Ok(self.execute(engine)?.map_ok(TryFrom::try_from).flatten())
+    }
+}
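
A consumption sketch, not part of this commit, showing how scan_metadata_arrow pairs with the scan_file_transforms field documented above: entry i of the vector is the optional transform for the file described by row i of scan_files. The helper name describe_scan and its printout are illustrative, the url crate is assumed (as in the tests below), and actually reading each file and applying its transform remains the engine's job:

use std::sync::Arc;

use delta_kernel::engine::arrow_extensions::ScanExt;
use delta_kernel::engine::sync::SyncEngine;
use delta_kernel::{DeltaResult, Table};

/// Report how many files a scan would read and how many of them need a
/// row-level transform applied to reach the scan's logical schema.
fn describe_scan(table_url: url::Url) -> DeltaResult<()> {
    let engine = Arc::new(SyncEngine::new());
    let table = Table::new(table_url);
    let snapshot = table.snapshot(engine.as_ref(), None)?;
    let scan = snapshot.into_scan_builder().build()?;

    for metadata in scan.scan_metadata_arrow(engine.as_ref())? {
        let metadata = metadata?;
        // scan_file_transforms is indexed by row of scan_files.
        let needs_transform = metadata
            .scan_file_transforms
            .iter()
            .filter(|t| t.is_some())
            .count();
        println!(
            "{} file(s) to scan, {} of them with a transform",
            metadata.scan_files.num_rows(),
            needs_transform
        );
    }
    Ok(())
}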

kernel/src/engine/mod.rs

+2

@@ -22,6 +22,8 @@ pub mod sync;
 #[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
 pub mod arrow_data;
 #[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
+pub mod arrow_extensions;
+#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
 pub(crate) mod arrow_get_data;
 #[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
 pub(crate) mod ensure_data_types;

kernel/tests/arrow_extensions.rs

+62

@@ -0,0 +1,62 @@
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use delta_kernel::engine::arrow_extensions::ScanExt;
+use delta_kernel::engine::sync::SyncEngine;
+use delta_kernel::Table;
+use itertools::Itertools;
+
+mod common;
+
+#[test_log::test]
+fn test_scan_metadata_arrow() {
+    let path =
+        std::fs::canonicalize(PathBuf::from("./tests/data/table-without-dv-small/")).unwrap();
+    let url = url::Url::from_directory_path(path).unwrap();
+    let engine = Arc::new(SyncEngine::new());
+
+    let table = Table::new(url);
+    let snapshot = table.snapshot(engine.as_ref(), None).unwrap();
+    let scan = snapshot.into_scan_builder().build().unwrap();
+    let files: Vec<_> = scan
+        .scan_metadata_arrow(engine.as_ref())
+        .unwrap()
+        .try_collect()
+        .unwrap();
+
+    assert_eq!(files.len(), 1);
+    let num_rows = files[0].scan_files.num_rows();
+    assert_eq!(num_rows, 1)
+}
+
+#[test_log::test]
+fn test_execute_arrow() {
+    let path =
+        std::fs::canonicalize(PathBuf::from("./tests/data/table-without-dv-small/")).unwrap();
+    let url = url::Url::from_directory_path(path).unwrap();
+    let engine = Arc::new(SyncEngine::new());
+
+    let table = Table::new(url);
+    let snapshot = table.snapshot(engine.as_ref(), None).unwrap();
+    let scan = snapshot.into_scan_builder().build().unwrap();
+    let files: Vec<_> = scan.execute_arrow(engine).unwrap().try_collect().unwrap();
+
+    let expected = vec![
+        "+-------+",
+        "| value |",
+        "+-------+",
+        "| 0     |",
+        "| 1     |",
+        "| 2     |",
+        "| 3     |",
+        "| 4     |",
+        "| 5     |",
+        "| 6     |",
+        "| 7     |",
+        "| 8     |",
+        "| 9     |",
+        "+-------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &files);
+}

kernel/tests/common/mod.rs

+1 -1

@@ -3,7 +3,7 @@ use delta_kernel::arrow::record_batch::RecordBatch;
 use delta_kernel::arrow::util::pretty::pretty_format_batches;
 use itertools::Itertools;
 
-use crate::ArrowEngineData;
+use delta_kernel::engine::arrow_data::ArrowEngineData;
 use delta_kernel::scan::Scan;
 use delta_kernel::{DeltaResult, Engine, EngineData, Table};
 