Skip to content

Commit

Permalink
feat: ability to write expressions that access map values and expression evaluation on that map access
Browse files Browse the repository at this point in the history
  • Loading branch information
hntd187 committed Sep 23, 2024
1 parent 27b414b commit d2907d9
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 75 deletions.
110 changes: 96 additions & 14 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Expression handling based on arrow-rs compute kernels.
use std::sync::Arc;

use arrow_arith::boolean::{and_kleene, is_not_null, is_null, not, or_kleene};
use arrow_arith::boolean::{and_kleene, is_null, not, or_kleene};
use arrow_arith::numeric::{add, div, mul, sub};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
Expand Down Expand Up @@ -393,13 +393,14 @@ pub(crate) fn evaluate_expression(
}
(MapAccess { source, key }, _) => {
let source = evaluate_expression(source, batch, None)?;
let entries = source.as_map();
let entry_mask = is_not_null(&entries)?;
let filtered_entries = filter(&entries, &entry_mask)?;
let entries = filtered_entries.as_map();
let key_array = StringArray::new_scalar(key); // Keys shouldn't ever be null by definition in arrow, but doing this second filter to be careful
let key_mask = eq(entries.keys(), &key_array)?;
filter(entries, &key_mask).map_err(Error::generic_err)
if let Some(key) = key {
let entries = source.as_map();
let key_array = StringArray::new_scalar(key); // Keys shouldn't ever be null by definition in arrow, but doing this second filter to be careful
let key_mask = eq(entries.keys(), &key_array)?;
filter(entries, &key_mask).map_err(Error::generic_err)
} else {
Ok(source)
}
}
}
}
Expand Down Expand Up @@ -462,16 +463,50 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {

#[cfg(test)]
mod tests {
use std::ops::{Add, Div, Mul, Sub};

use arrow_array::{GenericStringArray, Int32Array};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};

use super::*;
use crate::expressions::*;
use crate::schema::ArrayType;
use crate::DataType as DeltaDataTypes;
use arrow_array::builder::{MapBuilder, MapFieldNames, StringBuilder};
use arrow_array::{GenericStringArray, Int32Array};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};
use std::collections::HashMap;
use std::ops::{Add, Div, Mul, Sub};

fn setup_map_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![Field::new_map(
"test_map",
MAP_ROOT_DEFAULT,
Field::new("keys", DataType::Utf8, false),
Field::new("values", DataType::Utf8, true),
false,
true,
)]))
}

fn setup_map_array(map: HashMap<String, Option<String>>) -> DeltaResult<Arc<MapArray>> {
let mut array_builder = MapBuilder::new(
Some(MapFieldNames {
entry: MAP_ROOT_DEFAULT.to_string(),
..Default::default()
}),
StringBuilder::new(),
StringBuilder::new(),
);

for (k, v) in map {
array_builder.keys().append_value(k);
if let Some(v) = v {
array_builder.values().append_value(v);
} else {
array_builder.values().append_null();
}
array_builder.append(true)?;
}

Ok(Arc::new(array_builder.finish()))
}

#[test]
fn test_array_column() {
Expand Down Expand Up @@ -788,4 +823,51 @@ mod tests {
let expected = Arc::new(BooleanArray::from(vec![true, false]));
assert_eq!(results.as_ref(), expected.as_ref());
}

#[test]
fn test_map_expression_access() -> DeltaResult<()> {
let map_access = Expression::map(Expression::column("test_map"), Some("first_key"));

let schema = setup_map_schema();
let map_values = HashMap::from_iter([
("first_key".to_string(), Some("first".to_string())),
("second_key".to_string(), Some("second_value".to_string())),
]);
let array = setup_map_array(map_values)?;
let expected = HashMap::from_iter([("first_key".to_string(), Some("first".to_string()))]);
let expected_array = setup_map_array(expected)?;

let batch = RecordBatch::try_new(schema.clone(), vec![array])?;
let output = evaluate_expression(&map_access, &batch, None)?;

assert_eq!(output.len(), 1);
assert_eq!(output.as_ref(), expected_array.as_ref());

Ok(())
}

#[test]
fn test_map_expression_eq() -> DeltaResult<()> {
let map_access = Expression::map(Expression::column("test_map"), None);
let predicate_expr = Expression::binary(
BinaryOperator::Equal,
map_access.clone(),
Expression::literal("second_value"),
);

let schema = setup_map_schema();
let map_values = HashMap::from_iter([
("first_key".to_string(), Some("first".to_string())),
("second_key".to_string(), Some("second_value".to_string())),
]);
let array = setup_map_array(map_values)?;

let batch = RecordBatch::try_new(schema.clone(), vec![array])?;
let output = evaluate_expression(&predicate_expr, &batch, None)?;
let expected = Arc::new(BooleanArray::from(vec![false, true]));
assert_eq!(output.len(), 2);
assert_eq!(output.as_ref(), expected.as_ref());

Ok(())
}
}
12 changes: 8 additions & 4 deletions kernel/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ pub enum Expression {
},
MapAccess {
source: Box<Expression>,
key: String,
key: Option<String>,
},
}

Expand Down Expand Up @@ -204,7 +204,11 @@ impl Display for Expression {
}
},
Self::MapAccess { source, key } => {
write!(f, "{}[{}]", *source, key)
write!(f, "{}", *source)?;
if let Some(key) = key {
write!(f, "[{}]", key)?;
}
Ok(())
}
}
}
Expand Down Expand Up @@ -266,10 +270,10 @@ impl Expression {
Self::VariadicOperation { op, exprs }
}

pub fn map(source: impl Into<Expression>, key: impl ToString) -> Self {
pub fn map(source: impl Into<Expression>, key: Option<&str>) -> Self {
Self::MapAccess {
source: Box::new(source.into()),
key: key.to_string(),
key: key.map(Into::into),
}
}

Expand Down
56 changes: 0 additions & 56 deletions kernel/src/scan/data_skipping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,6 @@ impl DataSkippingFilter {
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::arrow_data::ArrowEngineData;
use crate::engine::arrow_expression::evaluate_expression;
use crate::engine::default::executor::tokio::TokioBackgroundExecutor;
use crate::engine::default::DefaultEngine;
use crate::scan::log_replay::SCAN_ROW_SCHEMA;
use crate::{Expression, Table};

#[test]
fn test_rewrite_basic_comparison() {
Expand Down Expand Up @@ -387,54 +381,4 @@ mod tests {
assert_eq!(rewritten, expected)
}
}

#[test]
fn test_data_skipping() -> DeltaResult<()> {
let table = Table::try_from_uri(
r"..\acceptance\tests\dat\out\reader_tests\generated\basic_partitioned\delta",
)?;
let engine = DefaultEngine::try_new(
table.location(),
std::iter::empty::<(String, String)>(),
Arc::new(TokioBackgroundExecutor::new()),
)?;
let snapshot = table.snapshot(&engine, Some(1))?;
let schema = Arc::new(snapshot.schema().clone());
let predicate_expr = Expression::binary(
BinaryOperator::Equal,
Expression::map(
Expression::column("fileConstantValues.partitionValues"),
"letter",
),
Expression::literal("e"),
);
let builder = snapshot
.into_scan_builder()
.with_schema(schema)
.with_predicate(predicate_expr.clone())
.build()?;

// let scan = builder.execute(&engine)?;
// for sf in scan {
// let data = sf.raw_data?;
// let batch = data.as_any().downcast_ref::<ArrowEngineData>().unwrap();
// print_batches(&[batch.record_batch().clone()])?
// }
let scan = builder.scan_data(&engine)?;

for sd in scan {
let sd = sd?;
let data = sd.0.as_any().downcast_ref::<ArrowEngineData>().unwrap();
let result = evaluate_expression(
&predicate_expr,
data.record_batch(),
Some(&DataType::struct_type(
SCAN_ROW_SCHEMA.fields().cloned().collect(),
)),
);
println!("{:?}", result);
}

Ok(())
}
}
2 changes: 1 addition & 1 deletion kernel/src/scan/log_replay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ lazy_static! {
"fileConstantValues",
StructType::new(vec![StructField::new(
"partitionValues",
MapType::new(DataType::STRING, DataType::STRING, false),
MapType::new(DataType::STRING, DataType::STRING, true),
true,
)]),
true
Expand Down

0 comments on commit d2907d9

Please sign in to comment.