diff --git a/ffi/src/engine_funcs.rs b/ffi/src/engine_funcs.rs index 03ae289ed..73f418e23 100644 --- a/ffi/src/engine_funcs.rs +++ b/ffi/src/engine_funcs.rs @@ -180,7 +180,7 @@ fn new_expression_evaluator_impl( /// Caller is responsible for passing a valid handle. #[no_mangle] pub unsafe extern "C" fn free_expression_evaluator(evaluator: Handle) { - debug!("engine released evaluator"); + debug!("engine released expression evaluator"); evaluator.drop_handle(); } @@ -189,7 +189,7 @@ pub unsafe extern "C" fn free_expression_evaluator(evaluator: Handle, batch: &mut Handle, evaluator: Handle, @@ -197,11 +197,11 @@ pub unsafe extern "C" fn evaluate( let engine = unsafe { engine.clone_as_arc() }; let batch = unsafe { batch.as_mut() }; let evaluator = unsafe { evaluator.clone_as_arc() }; - let res = evaluate_impl(batch, evaluator.as_ref()); + let res = evaluate_expression_impl(batch, evaluator.as_ref()); res.into_extern_result(&engine.as_ref()) } -fn evaluate_impl( +fn evaluate_expression_impl( batch: &dyn EngineData, evaluator: &dyn ExpressionEvaluator, ) -> DeltaResult> { @@ -219,7 +219,7 @@ mod tests { use std::sync::Arc; #[test] - fn test_new_evaluator() { + fn test_new_expression_evaluator() { let engine = get_default_engine(); let in_schema = Arc::new(StructType::new(vec![StructField::new( "a", diff --git a/ffi/src/expressions/engine.rs b/ffi/src/expressions/engine.rs index f624a6a1a..aa8e84c43 100644 --- a/ffi/src/expressions/engine.rs +++ b/ffi/src/expressions/engine.rs @@ -6,14 +6,19 @@ use crate::{ AllocateErrorFn, EngineIterator, ExternResult, IntoExternResult, KernelStringSlice, ReferenceSet, TryFromStringSlice, }; -use delta_kernel::{ - expressions::{BinaryOperator, ColumnName, Expression, UnaryOperator}, - DeltaResult, +use delta_kernel::expressions::{ + BinaryExpressionOp, BinaryPredicateOp, ColumnName, Expression, Predicate, UnaryPredicateOp, }; +use delta_kernel::DeltaResult; + +pub enum ExpressionOrPredicate { + Expression(Expression), + 
Predicate(Predicate), +} #[derive(Default)] pub struct KernelExpressionVisitorState { - inflight_ids: ReferenceSet, + inflight_ids: ReferenceSet, } /// A predicate that can be used to skip data when scanning. @@ -36,19 +41,40 @@ pub struct EnginePredicate { } fn wrap_expression(state: &mut KernelExpressionVisitorState, expr: impl Into) -> usize { - state.inflight_ids.insert(expr.into()) + state + .inflight_ids + .insert(ExpressionOrPredicate::Expression(expr.into())) +} + +fn wrap_predicate(state: &mut KernelExpressionVisitorState, pred: impl Into) -> usize { + state + .inflight_ids + .insert(ExpressionOrPredicate::Predicate(pred.into())) } pub(crate) fn unwrap_kernel_expression( state: &mut KernelExpressionVisitorState, exprid: usize, ) -> Option { - state.inflight_ids.take(exprid) + match state.inflight_ids.take(exprid)? { + ExpressionOrPredicate::Expression(expr) => Some(expr), + ExpressionOrPredicate::Predicate(pred) => Some(Expression::predicate(pred)), + } +} + +pub(crate) fn unwrap_kernel_predicate( + state: &mut KernelExpressionVisitorState, + predid: usize, +) -> Option { + match state.inflight_ids.take(predid)? 
{ + ExpressionOrPredicate::Expression(expr) => Some(Predicate::from_expr(expr)), + ExpressionOrPredicate::Predicate(pred) => Some(pred), + } } fn visit_expression_binary( state: &mut KernelExpressionVisitorState, - op: BinaryOperator, + op: BinaryExpressionOp, a: usize, b: usize, ) -> usize { @@ -60,71 +86,120 @@ fn visit_expression_binary( } } -fn visit_expression_unary( +fn visit_predicate_binary( state: &mut KernelExpressionVisitorState, - op: UnaryOperator, + op: BinaryPredicateOp, + a: usize, + b: usize, +) -> usize { + let left = unwrap_kernel_expression(state, a); + let right = unwrap_kernel_expression(state, b); + match left.zip(right) { + Some((left, right)) => wrap_predicate(state, Predicate::binary(op, left, right)), + None => 0, // invalid child => invalid node + } +} + +fn visit_predicate_unary( + state: &mut KernelExpressionVisitorState, + op: UnaryPredicateOp, inner_expr: usize, ) -> usize { - unwrap_kernel_expression(state, inner_expr).map_or(0, |expr| { - wrap_expression(state, Expression::unary(op, expr)) - }) + unwrap_kernel_expression(state, inner_expr) + .map_or(0, |expr| wrap_predicate(state, Predicate::unary(op, expr))) } // The EngineIterator is not thread safe, not reentrant, not owned by callee, not freed by callee. 
#[no_mangle] -pub extern "C" fn visit_expression_and( +pub extern "C" fn visit_predicate_and( state: &mut KernelExpressionVisitorState, children: &mut EngineIterator, ) -> usize { - let result = Expression::and_from( - children.flat_map(|child| unwrap_kernel_expression(state, child as usize)), + let result = Predicate::and_from( + children.flat_map(|child| unwrap_kernel_predicate(state, child as usize)), ); - wrap_expression(state, result) + wrap_predicate(state, result) } #[no_mangle] -pub extern "C" fn visit_expression_lt( +pub extern "C" fn visit_expression_plus( state: &mut KernelExpressionVisitorState, a: usize, b: usize, ) -> usize { - visit_expression_binary(state, BinaryOperator::LessThan, a, b) + visit_expression_binary(state, BinaryExpressionOp::Plus, a, b) } #[no_mangle] -pub extern "C" fn visit_expression_le( +pub extern "C" fn visit_expression_minus( state: &mut KernelExpressionVisitorState, a: usize, b: usize, ) -> usize { - visit_expression_binary(state, BinaryOperator::LessThanOrEqual, a, b) + visit_expression_binary(state, BinaryExpressionOp::Minus, a, b) } #[no_mangle] -pub extern "C" fn visit_expression_gt( +pub extern "C" fn visit_expression_multiply( state: &mut KernelExpressionVisitorState, a: usize, b: usize, ) -> usize { - visit_expression_binary(state, BinaryOperator::GreaterThan, a, b) + visit_expression_binary(state, BinaryExpressionOp::Multiply, a, b) } #[no_mangle] -pub extern "C" fn visit_expression_ge( +pub extern "C" fn visit_expression_divide( state: &mut KernelExpressionVisitorState, a: usize, b: usize, ) -> usize { - visit_expression_binary(state, BinaryOperator::GreaterThanOrEqual, a, b) + visit_expression_binary(state, BinaryExpressionOp::Divide, a, b) } #[no_mangle] -pub extern "C" fn visit_expression_eq( +pub extern "C" fn visit_predicate_lt( state: &mut KernelExpressionVisitorState, a: usize, b: usize, ) -> usize { - visit_expression_binary(state, BinaryOperator::Equal, a, b) + visit_predicate_binary(state, 
BinaryPredicateOp::LessThan, a, b) +} + +#[no_mangle] +pub extern "C" fn visit_predicate_le( + state: &mut KernelExpressionVisitorState, + a: usize, + b: usize, +) -> usize { + visit_predicate_binary(state, BinaryPredicateOp::LessThanOrEqual, a, b) +} + +#[no_mangle] +pub extern "C" fn visit_predicate_gt( + state: &mut KernelExpressionVisitorState, + a: usize, + b: usize, +) -> usize { + visit_predicate_binary(state, BinaryPredicateOp::GreaterThan, a, b) +} + +#[no_mangle] +pub extern "C" fn visit_predicate_ge( + state: &mut KernelExpressionVisitorState, + a: usize, + b: usize, +) -> usize { + visit_predicate_binary(state, BinaryPredicateOp::GreaterThanOrEqual, a, b) +} + +#[no_mangle] +pub extern "C" fn visit_predicate_eq( + state: &mut KernelExpressionVisitorState, + a: usize, + b: usize, +) -> usize { + visit_predicate_binary(state, BinaryPredicateOp::Equal, a, b) } /// # Safety @@ -148,19 +223,20 @@ fn visit_expression_column_impl( } #[no_mangle] -pub extern "C" fn visit_expression_not( +pub extern "C" fn visit_predicate_not( state: &mut KernelExpressionVisitorState, - inner_expr: usize, + inner_pred: usize, ) -> usize { - visit_expression_unary(state, UnaryOperator::Not, inner_expr) + unwrap_kernel_predicate(state, inner_pred) + .map_or(0, |pred| wrap_predicate(state, Predicate::not(pred))) } #[no_mangle] -pub extern "C" fn visit_expression_is_null( +pub extern "C" fn visit_predicate_is_null( state: &mut KernelExpressionVisitorState, inner_expr: usize, ) -> usize { - visit_expression_unary(state, UnaryOperator::IsNull, inner_expr) + visit_predicate_unary(state, UnaryPredicateOp::IsNull, inner_expr) } /// # Safety @@ -178,7 +254,7 @@ fn visit_expression_literal_string_impl( state: &mut KernelExpressionVisitorState, value: DeltaResult, ) -> DeltaResult { - Ok(wrap_expression(state, value?)) + Ok(wrap_expression(state, Expression::literal(value?))) } // We need to get parse.expand working to be able to macro everything below, see issue #255 @@ -188,7 +264,7 @@ 
pub extern "C" fn visit_expression_literal_int( state: &mut KernelExpressionVisitorState, value: i32, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } #[no_mangle] @@ -196,7 +272,7 @@ pub extern "C" fn visit_expression_literal_long( state: &mut KernelExpressionVisitorState, value: i64, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } #[no_mangle] @@ -204,7 +280,7 @@ pub extern "C" fn visit_expression_literal_short( state: &mut KernelExpressionVisitorState, value: i16, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } #[no_mangle] @@ -212,7 +288,7 @@ pub extern "C" fn visit_expression_literal_byte( state: &mut KernelExpressionVisitorState, value: i8, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } #[no_mangle] @@ -220,7 +296,7 @@ pub extern "C" fn visit_expression_literal_float( state: &mut KernelExpressionVisitorState, value: f32, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } #[no_mangle] @@ -228,7 +304,7 @@ pub extern "C" fn visit_expression_literal_double( state: &mut KernelExpressionVisitorState, value: f64, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } #[no_mangle] @@ -236,5 +312,5 @@ pub extern "C" fn visit_expression_literal_bool( state: &mut KernelExpressionVisitorState, value: bool, ) -> usize { - wrap_expression(state, value) + wrap_expression(state, Expression::literal(value)) } diff --git a/ffi/src/expressions/kernel.rs b/ffi/src/expressions/kernel.rs index 9a73e9017..3c163ad10 100644 --- a/ffi/src/expressions/kernel.rs +++ b/ffi/src/expressions/kernel.rs @@ -1,12 +1,13 @@ //! Defines [`EngineExpressionVisitor`]. This is a visitor that can be used to convert the kernel's //! [`Expression`] to an engine's expression format. 
-use crate::expressions::SharedExpression; +use crate::expressions::{SharedExpression, SharedPredicate}; use std::ffi::c_void; use crate::{handle::Handle, kernel_string_slice, KernelStringSlice}; use delta_kernel::expressions::{ - ArrayData, BinaryExpression, BinaryOperator, Expression, JunctionExpression, JunctionOperator, - Scalar, StructData, UnaryExpression, UnaryOperator, + ArrayData, BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, + Expression, JunctionPredicate, JunctionPredicateOp, Predicate, Scalar, StructData, + UnaryPredicate, UnaryPredicateOp, }; /// Free the memory the passed SharedExpression @@ -14,7 +15,16 @@ use delta_kernel::expressions::{ /// # Safety /// Engine is responsible for passing a valid SharedExpression #[no_mangle] -pub unsafe extern "C" fn free_kernel_predicate(data: Handle) { +pub unsafe extern "C" fn free_kernel_expression(data: Handle) { + data.drop_handle(); +} + +/// Free the memory the passed SharedPredicate +/// +/// # Safety +/// Engine is responsible for passing a valid SharedPredicate +#[no_mangle] +pub unsafe extern "C" fn free_kernel_predicate(data: Handle) { data.drop_handle(); } @@ -25,7 +35,7 @@ type VisitJunctionFn = extern "C" fn(data: *mut c_void, sibling_list_id: usize, child_list_id: usize); /// The [`EngineExpressionVisitor`] defines a visitor system to allow engines to build their own -/// representation of a kernel expression. +/// representation of a kernel expression or predicate. /// /// The model is list based. When the kernel needs a list, it will ask engine to allocate one of a /// particular size. Once allocated the engine returns an `id`, which can be any integer identifier @@ -208,6 +218,39 @@ pub unsafe extern "C" fn visit_expression_ref( visit_expression_internal(expression, visitor) } +/// Visit the predicate of the passed [`SharedPredicate`] Handle using the provided `visitor`. 
+/// See the documentation of [`EngineExpressionVisitor`] for a description of how this visitor +/// works. +/// +/// This method returns the id that the engine generated for the top level predicate +/// +/// # Safety +/// +/// The caller must pass a valid SharedPredicate Handle and expression visitor +#[no_mangle] +pub unsafe extern "C" fn visit_predicate( + predicate: &Handle, + visitor: &mut EngineExpressionVisitor, +) -> usize { + visit_predicate_internal(predicate.as_ref(), visitor) +} + +/// Visit the predicate of the passed [`Predicate`] pointer using the provided `visitor`. See the +/// documentation of [`EngineExpressionVisitor`] for a description of how this visitor works. +/// +/// This method returns the id that the engine generated for the top level predicate +/// +/// # Safety +/// +/// The caller must pass a valid Predicate pointer and expression visitor +#[no_mangle] +pub unsafe extern "C" fn visit_predicate_ref( + predicate: &Predicate, + visitor: &mut EngineExpressionVisitor, +) -> usize { + visit_predicate_internal(predicate, visitor) +} + macro_rules! 
call { ( $visitor:ident, $visitor_fn:ident $(, $extra_args:expr) *) => { ($visitor.$visitor_fn)($visitor.data $(, $extra_args) *) @@ -266,20 +309,20 @@ fn visit_expression_struct( call!(visitor, visit_struct_expr, sibling_list_id, child_list_id) } -fn visit_expression_junction( +fn visit_predicate_junction( visitor: &mut EngineExpressionVisitor, - op: &JunctionOperator, - exprs: &[Expression], + op: &JunctionPredicateOp, + preds: &[Predicate], sibling_list_id: usize, ) { - let child_list_id = call!(visitor, make_field_list, exprs.len()); - for expr in exprs { - visit_expression_impl(visitor, expr, child_list_id); + let child_list_id = call!(visitor, make_field_list, preds.len()); + for pred in preds { + visit_predicate_impl(visitor, pred, child_list_id); } let visit_fn = match op { - JunctionOperator::And => &visitor.visit_and, - JunctionOperator::Or => &visitor.visit_or, + JunctionPredicateOp::And => &visitor.visit_and, + JunctionPredicateOp::Or => &visitor.visit_or, }; visit_fn(visitor.data, sibling_list_id, child_list_id); } @@ -349,38 +392,61 @@ fn visit_expression_impl( call!(visitor, visit_column, sibling_list_id, name); } Expression::Struct(exprs) => visit_expression_struct(visitor, exprs, sibling_list_id), + Expression::Predicate(pred) => visit_predicate_impl(visitor, pred, sibling_list_id), Expression::Binary(BinaryExpression { op, left, right }) => { let child_list_id = call!(visitor, make_field_list, 2); visit_expression_impl(visitor, left, child_list_id); visit_expression_impl(visitor, right, child_list_id); let op = match op { - BinaryOperator::Plus => visitor.visit_add, - BinaryOperator::Minus => visitor.visit_minus, - BinaryOperator::Multiply => visitor.visit_multiply, - BinaryOperator::Divide => visitor.visit_divide, - BinaryOperator::LessThan => visitor.visit_lt, - BinaryOperator::LessThanOrEqual => visitor.visit_le, - BinaryOperator::GreaterThan => visitor.visit_gt, - BinaryOperator::GreaterThanOrEqual => visitor.visit_ge, - BinaryOperator::Equal 
=> visitor.visit_eq, - BinaryOperator::NotEqual => visitor.visit_ne, - BinaryOperator::Distinct => visitor.visit_distinct, - BinaryOperator::In => visitor.visit_in, - BinaryOperator::NotIn => visitor.visit_not_in, + BinaryExpressionOp::Plus => visitor.visit_add, + BinaryExpressionOp::Minus => visitor.visit_minus, + BinaryExpressionOp::Multiply => visitor.visit_multiply, + BinaryExpressionOp::Divide => visitor.visit_divide, }; op(visitor.data, sibling_list_id, child_list_id); } - Expression::Unary(UnaryExpression { op, expr }) => { - let child_id_list = call!(visitor, make_field_list, 1); - visit_expression_impl(visitor, expr, child_id_list); + } +} + +fn visit_predicate_impl( + visitor: &mut EngineExpressionVisitor, + predicate: &Predicate, + sibling_list_id: usize, +) { + match predicate { + Predicate::BooleanExpression(expr) => visit_expression_impl(visitor, expr, sibling_list_id), + Predicate::Not(pred) => { + let child_list_id = call!(visitor, make_field_list, 1); + visit_predicate_impl(visitor, pred, child_list_id); + call!(visitor, visit_not, sibling_list_id, child_list_id); + } + Predicate::Unary(UnaryPredicate { op, expr }) => { + let child_list_id = call!(visitor, make_field_list, 1); + visit_expression_impl(visitor, expr, child_list_id); let op = match op { - UnaryOperator::Not => visitor.visit_not, - UnaryOperator::IsNull => visitor.visit_is_null, + UnaryPredicateOp::IsNull => visitor.visit_is_null, }; - op(visitor.data, sibling_list_id, child_id_list); + op(visitor.data, sibling_list_id, child_list_id); + } + Predicate::Binary(BinaryPredicate { op, left, right }) => { + let child_list_id = call!(visitor, make_field_list, 2); + visit_expression_impl(visitor, left, child_list_id); + visit_expression_impl(visitor, right, child_list_id); + let op = match op { + BinaryPredicateOp::LessThan => visitor.visit_lt, + BinaryPredicateOp::LessThanOrEqual => visitor.visit_le, + BinaryPredicateOp::GreaterThan => visitor.visit_gt, + 
BinaryPredicateOp::GreaterThanOrEqual => visitor.visit_ge, + BinaryPredicateOp::Equal => visitor.visit_eq, + BinaryPredicateOp::NotEqual => visitor.visit_ne, + BinaryPredicateOp::Distinct => visitor.visit_distinct, + BinaryPredicateOp::In => visitor.visit_in, + BinaryPredicateOp::NotIn => visitor.visit_not_in, + }; + op(visitor.data, sibling_list_id, child_list_id); } - Expression::Junction(JunctionExpression { op, exprs }) => { - visit_expression_junction(visitor, op, exprs, sibling_list_id) + Predicate::Junction(JunctionPredicate { op, preds }) => { + visit_predicate_junction(visitor, op, preds, sibling_list_id) } } } @@ -393,3 +459,9 @@ fn visit_expression_internal( visit_expression_impl(visitor, expression, top_level); top_level } + +fn visit_predicate_internal(predicate: &Predicate, visitor: &mut EngineExpressionVisitor) -> usize { + let top_level = call!(visitor, make_field_list, 1); + visit_predicate_impl(visitor, predicate, top_level); + top_level +} diff --git a/ffi/src/expressions/mod.rs b/ffi/src/expressions/mod.rs index a6756f972..1633e64e3 100644 --- a/ffi/src/expressions/mod.rs +++ b/ffi/src/expressions/mod.rs @@ -1,6 +1,6 @@ //! This module holds functionality for moving expressions across the FFI boundary, both from //! engine to kernel, and from kernel to engine. 
-use delta_kernel::Expression; +use delta_kernel::{Expression, Predicate}; use delta_kernel_ffi_macros::handle_descriptor; pub mod engine; @@ -8,3 +8,6 @@ pub mod kernel; #[handle_descriptor(target=Expression, mutable=false, sized=true)] pub struct SharedExpression; + +#[handle_descriptor(target=Predicate, mutable=false, sized=true)] +pub struct SharedPredicate; diff --git a/ffi/src/scan.rs b/ffi/src/scan.rs index 2cd3b9f6e..682665baf 100644 --- a/ffi/src/scan.rs +++ b/ffi/src/scan.rs @@ -12,7 +12,7 @@ use tracing::debug; use url::Url; use crate::expressions::engine::{ - unwrap_kernel_expression, EnginePredicate, KernelExpressionVisitorState, + unwrap_kernel_predicate, EnginePredicate, KernelExpressionVisitorState, }; use crate::expressions::SharedExpression; use crate::{ @@ -94,7 +94,7 @@ fn scan_impl( if let Some(predicate) = predicate { let mut visitor_state = KernelExpressionVisitorState::default(); let pred_id = (predicate.visitor)(predicate.predicate, &mut visitor_state); - let predicate = unwrap_kernel_expression(&mut visitor_state, pred_id); + let predicate = unwrap_kernel_predicate(&mut visitor_state, pred_id); debug!("Got predicate: {:#?}", predicate); scan_builder = scan_builder.with_predicate(predicate.map(Arc::new)); } diff --git a/ffi/src/test_ffi.rs b/ffi/src/test_ffi.rs index d59d5119a..bcde1ca6c 100644 --- a/ffi/src/test_ffi.rs +++ b/ffi/src/test_ffi.rs @@ -2,18 +2,22 @@ use std::sync::Arc; -use crate::{expressions::SharedExpression, handle::Handle}; +use crate::expressions::{SharedExpression, SharedPredicate}; +use crate::handle::Handle; use delta_kernel::{ - expressions::{column_expr, ArrayData, BinaryOperator, Expression as Expr, Scalar, StructData}, + expressions::{ + column_expr, column_pred, ArrayData, BinaryExpressionOp, BinaryPredicateOp, + Expression as Expr, Predicate as Pred, Scalar, StructData, + }, schema::{ArrayType, DataType, StructField, StructType}, }; -/// Constructs a kernel expression that is passed back as a SharedExpression 
handle. The expected +/// Constructs a kernel expression that is passed back as a [`SharedExpression`] handle. The expected /// output expression can be found in `ffi/tests/test_expression_visitor/expected.txt`. /// /// # Safety /// The caller is responsible for freeing the returned memory, either by calling -/// [`free_kernel_predicate`], or [`Handle::drop_handle`] +/// [`crate::expressions::free_kernel_expression`], or [`crate::handles::Handle::drop_handle`]. #[no_mangle] pub unsafe extern "C" fn get_testing_kernel_expression() -> Handle { let array_type = ArrayType::new( @@ -40,6 +44,7 @@ pub unsafe extern "C" fn get_testing_kernel_expression() -> Handle Handle Handle { + let array_type = ArrayType::new( + DataType::Primitive(delta_kernel::schema::PrimitiveType::Short), + false, + ); + let array_data = ArrayData::new(array_type.clone(), vec![Scalar::Short(5), Scalar::Short(0)]); + + let mut sub_exprs = vec![ + column_pred!("col"), + Pred::literal(true), + Pred::literal(false), + Pred::binary( + BinaryPredicateOp::In, + Expr::literal(10), + Scalar::Array(array_data.clone()), + ), + Pred::binary( + BinaryPredicateOp::NotIn, + Expr::literal(10), + Scalar::Array(array_data), + ), + Pred::or_from(vec![ + Pred::eq(Expr::literal(5), Expr::literal(10)), + Pred::ne(Expr::literal(20), Expr::literal(10)), + ]), + Pred::is_not_null(column_expr!("col")), + ]; + sub_exprs.extend( + [ + BinaryPredicateOp::Equal, + BinaryPredicateOp::NotEqual, + BinaryPredicateOp::LessThan, + BinaryPredicateOp::LessThanOrEqual, + BinaryPredicateOp::GreaterThan, + BinaryPredicateOp::GreaterThanOrEqual, + BinaryPredicateOp::Distinct, + ] + .into_iter() + .map(|op| Pred::binary(op, Expr::literal(0), Expr::literal(0))), + ); + + Arc::new(Pred::and_from(sub_exprs)).into() } diff --git a/kernel/src/actions/set_transaction.rs b/kernel/src/actions/set_transaction.rs index 92c4ae2d8..df478a4ec 100644 --- a/kernel/src/actions/set_transaction.rs +++ b/kernel/src/actions/set_transaction.rs @@ -4,7 +4,7 @@ 
use crate::actions::get_log_txn_schema; use crate::actions::visitors::SetTransactionVisitor; use crate::actions::{SetTransaction, SET_TRANSACTION_NAME}; use crate::snapshot::Snapshot; -use crate::{DeltaResult, Engine, EngineData, Expression as Expr, ExpressionRef, RowVisitor as _}; +use crate::{DeltaResult, Engine, EngineData, Expression as Expr, PredicateRef, RowVisitor as _}; pub(crate) use crate::actions::visitors::SetTransactionMap; @@ -51,7 +51,7 @@ impl SetTransactionScanner { // checkpoint part when patitioned by `add.path` like the Delta spec requires. There's no // point filtering by a particular app id, even if we have one, because app ids are all in // the a single checkpoint part having large min/max range (because they're usually uuids). - static META_PREDICATE: LazyLock> = LazyLock::new(|| { + static META_PREDICATE: LazyLock> = LazyLock::new(|| { Some(Arc::new( Expr::column([SET_TRANSACTION_NAME, "appId"]).is_not_null(), )) diff --git a/kernel/src/engine/arrow_expression/evaluate_expression.rs b/kernel/src/engine/arrow_expression/evaluate_expression.rs index bfb667e90..5bba52445 100644 --- a/kernel/src/engine/arrow_expression/evaluate_expression.rs +++ b/kernel/src/engine/arrow_expression/evaluate_expression.rs @@ -14,8 +14,8 @@ use crate::arrow::error::ArrowError; use crate::engine::arrow_utils::prim_array_cmp; use crate::error::{DeltaResult, Error}; use crate::expressions::{ - BinaryExpression, BinaryOperator, Expression, JunctionExpression, JunctionOperator, Scalar, - UnaryExpression, UnaryOperator, + BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, Expression, + JunctionPredicate, JunctionPredicateOp, Predicate, Scalar, UnaryPredicate, UnaryPredicateOp, }; use crate::schema::DataType; use itertools::Itertools; @@ -27,10 +27,6 @@ fn downcast_to_bool(arr: &dyn Array) -> DeltaResult<&BooleanArray> { .ok_or_else(|| Error::generic("expected boolean array")) } -fn wrap_comparison_result(arr: BooleanArray) -> ArrayRef { - 
Arc::new(arr) as _ -} - trait ProvidesColumnByName { fn column_by_name(&self, name: &str) -> Option<&ArrayRef>; } @@ -87,7 +83,7 @@ pub(crate) fn evaluate_expression( batch: &RecordBatch, result_type: Option<&DataType>, ) -> DeltaResult { - use BinaryOperator::*; + use BinaryExpressionOp::*; use Expression::*; match (expression, result_type) { (Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?), @@ -115,29 +111,63 @@ pub(crate) fn evaluate_expression( (Struct(_), _) => Err(Error::generic( "Data type is required to evaluate struct expressions", )), - (Unary(UnaryExpression { op, expr }), _) => { + (Predicate(pred), None | Some(&DataType::BOOLEAN)) => { + let result = evaluate_predicate(pred, batch)?; + Ok(Arc::new(result)) + } + (Predicate(_), Some(data_type)) => Err(Error::generic(format!( + "Unexpected data type: {data_type:?}" + ))), + (Binary(BinaryExpression { op, left, right }), _) => { + let left_arr = evaluate_expression(left.as_ref(), batch, None)?; + let right_arr = evaluate_expression(right.as_ref(), batch, None)?; + + type Operation = fn(&dyn Datum, &dyn Datum) -> Result; + let eval: Operation = match op { + Plus => add, + Minus => sub, + Multiply => mul, + Divide => div, + }; + + Ok(eval(&left_arr, &right_arr)?) + } + } +} + +pub(crate) fn evaluate_predicate( + predicate: &Predicate, + batch: &RecordBatch, +) -> DeltaResult { + use BinaryPredicateOp::*; + use Predicate::*; + match predicate { + BooleanExpression(expr) => { + // Grr -- there's no way to cast an `Arc` back to its native type, so we + // can't use `Arc::into_inner` here and must unconditionally clone instead. 
+ let arr = evaluate_expression(expr, batch, Some(&DataType::BOOLEAN))?; + Ok(downcast_to_bool(&arr)?.clone()) + } + Not(pred) => Ok(not(&evaluate_predicate(pred, batch)?)?), + Unary(UnaryPredicate { op, expr }) => { let arr = evaluate_expression(expr.as_ref(), batch, None)?; let result = match op { - UnaryOperator::Not => not(downcast_to_bool(&arr)?)?, - UnaryOperator::IsNull => is_null(&arr)?, + UnaryPredicateOp::IsNull => is_null(&arr)?, }; - Ok(Arc::new(result)) + Ok(result) } - ( - Binary(BinaryExpression { - op: In, - left, - right, - }), - _, - ) => match (left.as_ref(), right.as_ref()) { - (Literal(_), Column(_)) => { + Binary(BinaryPredicate { + op: In, + left, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expression::Literal(_), Expression::Column(_)) => { let left_arr = evaluate_expression(left.as_ref(), batch, None)?; let right_arr = evaluate_expression(right.as_ref(), batch, None)?; if let Some(string_arr) = left_arr.as_string_opt::() { if let Some(right_arr) = right_arr.as_list_opt::() { let result = in_list_utf8(string_arr, right_arr)?; - return Ok(wrap_comparison_result(result)); + return Ok(result); } } prim_array_cmp! 
{ @@ -174,70 +204,54 @@ pub(crate) fn evaluate_expression( (ArrowDataType::Decimal256(_, _), Decimal256Type) } } - (Literal(lit), Literal(Scalar::Array(ad))) => { + (Expression::Literal(lit), Expression::Literal(Scalar::Array(ad))) => { #[allow(deprecated)] let exists = ad.array_elements().contains(lit); - Ok(Arc::new(BooleanArray::from(vec![exists]))) + Ok(BooleanArray::from(vec![exists])) } (l, r) => Err(Error::invalid_expression(format!( "Invalid right value for (NOT) IN comparison, left is: {l} right is: {r}" ))), }, - ( - Binary(BinaryExpression { - op: NotIn, - left, - right, - }), - _, - ) => { - let reverse_op = Expression::binary(In, *left.clone(), *right.clone()); - let reverse_expr = evaluate_expression(&reverse_op, batch, None)?; - let result = not(reverse_expr.as_boolean())?; - Ok(wrap_comparison_result(result)) + Binary(BinaryPredicate { + op: NotIn, + left, + right, + }) => { + let reverse_op = Predicate::binary(In, *left.clone(), *right.clone()); + let reverse_pred = evaluate_predicate(&reverse_op, batch)?; + Ok(not(&reverse_pred)?) 
} - (Binary(BinaryExpression { op, left, right }), _) => { + Binary(BinaryPredicate { op, left, right }) => { let left_arr = evaluate_expression(left.as_ref(), batch, None)?; let right_arr = evaluate_expression(right.as_ref(), batch, None)?; - type Operation = fn(&dyn Datum, &dyn Datum) -> Result; + type Operation = fn(&dyn Datum, &dyn Datum) -> Result; let eval: Operation = match op { - Plus => add, - Minus => sub, - Multiply => mul, - Divide => div, - LessThan => |l, r| lt(l, r).map(wrap_comparison_result), - LessThanOrEqual => |l, r| lt_eq(l, r).map(wrap_comparison_result), - GreaterThan => |l, r| gt(l, r).map(wrap_comparison_result), - GreaterThanOrEqual => |l, r| gt_eq(l, r).map(wrap_comparison_result), - Equal => |l, r| eq(l, r).map(wrap_comparison_result), - NotEqual => |l, r| neq(l, r).map(wrap_comparison_result), - Distinct => |l, r| distinct(l, r).map(wrap_comparison_result), + LessThan => |l, r| lt(l, r), + LessThanOrEqual => |l, r| lt_eq(l, r), + GreaterThan => |l, r| gt(l, r), + GreaterThanOrEqual => |l, r| gt_eq(l, r), + Equal => |l, r| eq(l, r), + NotEqual => |l, r| neq(l, r), + Distinct => |l, r| distinct(l, r), // NOTE: [Not]In was already covered above In | NotIn => return Err(Error::generic("Invalid expression given")), }; Ok(eval(&left_arr, &right_arr)?) 
} - (Junction(JunctionExpression { op, exprs }), None | Some(&DataType::BOOLEAN)) => { + Junction(JunctionPredicate { op, preds }) => { type Operation = fn(&BooleanArray, &BooleanArray) -> Result; let (reducer, default): (Operation, _) = match op { - JunctionOperator::And => (and_kleene, true), - JunctionOperator::Or => (or_kleene, false), + JunctionPredicateOp::And => (and_kleene, true), + JunctionPredicateOp::Or => (or_kleene, false), }; - exprs + preds .iter() - .map(|expr| evaluate_expression(expr, batch, result_type)) - .reduce(|l, r| { - let result = reducer(downcast_to_bool(&l?)?, downcast_to_bool(&r?)?)?; - Ok(wrap_comparison_result(result)) - }) - .unwrap_or_else(|| { - evaluate_expression(&Expression::literal(default), batch, result_type) - }) + .map(|pred| evaluate_predicate(pred, batch)) + .reduce(|l, r| Ok(reducer(&l?, &r?)?)) + .unwrap_or_else(|| Ok(BooleanArray::from(vec![default; batch.num_rows()]))) } - (Junction(_), _) => Err(Error::Generic(format!( - "Junction {expression:?} is expected to return boolean results, got {result_type:?}" - ))), } } diff --git a/kernel/src/engine/arrow_expression/mod.rs b/kernel/src/engine/arrow_expression/mod.rs index 019531931..5f75f024c 100644 --- a/kernel/src/engine/arrow_expression/mod.rs +++ b/kernel/src/engine/arrow_expression/mod.rs @@ -15,15 +15,15 @@ use crate::arrow::datatypes::{ use super::arrow_conversion::LIST_ARRAY_ROOT; use crate::engine::arrow_data::ArrowEngineData; use crate::error::{DeltaResult, Error}; -use crate::expressions::{Expression, Scalar}; +use crate::expressions::{Expression, Predicate, Scalar}; use crate::schema::{DataType, PrimitiveType, SchemaRef}; -use crate::{EngineData, EvaluationHandler, ExpressionEvaluator}; +use crate::{EngineData, EvaluationHandler, ExpressionEvaluator, PredicateEvaluator}; use itertools::Itertools; use tracing::debug; use apply_schema::{apply_schema, apply_schema_to}; -use evaluate_expression::evaluate_expression; +use evaluate_expression::{evaluate_expression, 
evaluate_predicate}; mod apply_schema; mod evaluate_expression; @@ -134,11 +134,22 @@ impl EvaluationHandler for ArrowEvaluationHandler { ) -> Arc { Arc::new(DefaultExpressionEvaluator { input_schema: schema, - expression: Box::new(expression), + expression, output_type, }) } + fn new_predicate_evaluator( + &self, + schema: SchemaRef, + predicate: Predicate, + ) -> Arc { + Arc::new(DefaultPredicateEvaluator { + input_schema: schema, + predicate, + }) + } + /// Create a single-row array with all-null leaf values. Note that if a nested struct is /// included in the `output_type`, the entire struct will be NULL (instead of a not-null struct /// with NULL fields). @@ -156,16 +167,13 @@ impl EvaluationHandler for ArrowEvaluationHandler { #[derive(Debug)] pub struct DefaultExpressionEvaluator { input_schema: SchemaRef, - expression: Box, + expression: Expression, output_type: DataType, } impl ExpressionEvaluator for DefaultExpressionEvaluator { fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult> { - debug!( - "Arrow evaluator evaluating: {:#?}", - self.expression.as_ref() - ); + debug!("Arrow evaluator evaluating: {:#?}", self.expression); let batch = batch .any_ref() .downcast_ref::() @@ -192,3 +200,37 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator { Ok(Box::new(ArrowEngineData::new(batch))) } } + +#[derive(Debug)] +pub struct DefaultPredicateEvaluator { + input_schema: SchemaRef, + predicate: Predicate, +} + +impl PredicateEvaluator for DefaultPredicateEvaluator { + fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult> { + debug!("Arrow evaluator evaluating: {:#?}", self.predicate); + let batch = batch + .any_ref() + .downcast_ref::() + .ok_or_else(|| Error::engine_data_type("ArrowEngineData"))? 
+ .record_batch(); + let _input_schema: ArrowSchema = self.input_schema.as_ref().try_into()?; + // TODO: make sure we have matching schemas for validation + // if batch.schema().as_ref() != &input_schema { + // return Err(Error::Generic(format!( + // "input schema does not match batch schema: {:?} != {:?}", + // input_schema, + // batch.schema() + // ))); + // }; + let array = evaluate_predicate(&self.predicate, batch)?; + let schema = ArrowSchema::new(vec![ArrowField::new( + "output", + ArrowDataType::Boolean, + true, + )]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?; + Ok(Box::new(ArrowEngineData::new(batch))) + } +} diff --git a/kernel/src/engine/arrow_expression/tests.rs b/kernel/src/engine/arrow_expression/tests.rs index bbe079064..b878b82ec 100644 --- a/kernel/src/engine/arrow_expression/tests.rs +++ b/kernel/src/engine/arrow_expression/tests.rs @@ -12,6 +12,9 @@ use crate::schema::{ArrayType, StructField, StructType}; use crate::DataType as DeltaDataTypes; use crate::EvaluationHandlerExtension as _; +use Expression as Expr; +use Predicate as Pred; + #[test] fn test_array_column() { let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]); @@ -24,17 +27,25 @@ fn test_array_column() { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); - let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item")); + let not_op = Pred::binary( + BinaryPredicateOp::NotIn, + Expr::literal(5), + column_expr!("item"), + ); - let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item")); + let in_op = Pred::binary( + BinaryPredicateOp::In, + Expr::literal(5), + column_expr!("item"), + ); - let result = evaluate_expression(&not_op, &batch, None).unwrap(); + let result = evaluate_predicate(&not_op, &batch).unwrap(); let expected = BooleanArray::from(vec![true, false, true]); - assert_eq!(result.as_ref(),
&expected); + assert_eq!(result, expected); - let in_result = evaluate_expression(&in_op, &batch, None).unwrap(); - let in_expected = BooleanArray::from(vec![false, true, false]); - assert_eq!(in_result.as_ref(), &in_expected); + let result = evaluate_predicate(&in_op, &batch).unwrap(); + let expected = BooleanArray::from(vec![false, true, false]); + assert_eq!(result, expected); } #[test] @@ -44,9 +55,13 @@ fn test_bad_right_type_array() { let schema = Schema::new([field.clone()]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); - let in_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item")); + let in_op = Pred::binary( + BinaryPredicateOp::NotIn, + Expr::literal(5), + column_expr!("item"), + ); - let in_result = evaluate_expression(&in_op, &batch, None); + let in_result = evaluate_predicate(&in_op, &batch); assert!(in_result.is_err()); assert_eq!( @@ -61,18 +76,18 @@ fn test_literal_type_array() { let schema = Schema::new([field.clone()]); let batch = RecordBatch::new_empty(Arc::new(schema)); - let in_op = Expression::binary( - BinaryOperator::NotIn, - 5, + let in_op = Pred::binary( + BinaryPredicateOp::NotIn, + Expr::literal(5), Scalar::Array(ArrayData::new( ArrayType::new(DeltaDataTypes::INTEGER, false), vec![Scalar::Integer(1), Scalar::Integer(2)], )), ); - let in_result = evaluate_expression(&in_op, &batch, None).unwrap(); + let in_result = evaluate_predicate(&in_op, &batch).unwrap(); let in_expected = BooleanArray::from(vec![true]); - assert_eq!(in_result.as_ref(), &in_expected); + assert_eq!(in_result, in_expected); } #[test] @@ -87,13 +102,13 @@ fn test_invalid_array_sides() { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); - let in_op = Expression::binary( - BinaryOperator::NotIn, + let in_op = Pred::binary( + BinaryPredicateOp::NotIn, column_expr!("item"), 
column_expr!("item"), ); - let in_result = evaluate_expression(&in_op, &batch, None); + let in_result = evaluate_predicate(&in_op, &batch); assert!(in_result.is_err()); assert_eq!( @@ -114,17 +129,25 @@ fn test_str_arrays() { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); - let str_not_op = Expression::binary(BinaryOperator::NotIn, "bye", column_expr!("item")); + let str_not_op = Pred::binary( + BinaryPredicateOp::NotIn, + Expr::literal("bye"), + column_expr!("item"), + ); - let str_in_op = Expression::binary(BinaryOperator::In, "hi", column_expr!("item")); + let str_in_op = Pred::binary( + BinaryPredicateOp::In, + Expr::literal("hi"), + column_expr!("item"), + ); - let result = evaluate_expression(&str_in_op, &batch, None).unwrap(); + let result = evaluate_predicate(&str_in_op, &batch).unwrap(); let expected = BooleanArray::from(vec![true, true, true]); - assert_eq!(result.as_ref(), &expected); + assert_eq!(result, expected); - let in_result = evaluate_expression(&str_not_op, &batch, None).unwrap(); - let in_expected = BooleanArray::from(vec![false, false, false]); - assert_eq!(in_result.as_ref(), &in_expected); + let result = evaluate_predicate(&str_not_op, &batch).unwrap(); + let expected = BooleanArray::from(vec![false, false, false]); + assert_eq!(result, expected); } #[test] @@ -166,23 +189,23 @@ fn test_binary_op_scalar() { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); let column = column_expr!("a"); - let expression = column.clone().add(1); + let expression = column.clone().add(Expr::literal(1)); let results = evaluate_expression(&expression, &batch, None).unwrap(); let expected = Arc::new(Int32Array::from(vec![2, 3, 4])); assert_eq!(results.as_ref(), expected.as_ref()); - let expression = column.clone().sub(1); + let expression = column.clone().sub(Expr::literal(1)); let results = 
evaluate_expression(&expression, &batch, None).unwrap(); let expected = Arc::new(Int32Array::from(vec![0, 1, 2])); assert_eq!(results.as_ref(), expected.as_ref()); - let expression = column.clone().mul(2); + let expression = column.clone().mul(Expr::literal(2)); let results = evaluate_expression(&expression, &batch, None).unwrap(); let expected = Arc::new(Int32Array::from(vec![2, 4, 6])); assert_eq!(results.as_ref(), expected.as_ref()); // TODO handle type casting - let expression = column.div(1); + let expression = column.div(Expr::literal(1)); let results = evaluate_expression(&expression, &batch, None).unwrap(); let expected = Arc::new(Int32Array::from(vec![1, 2, 3])); assert_eq!(results.as_ref(), expected.as_ref()) @@ -226,35 +249,35 @@ fn test_binary_cmp() { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); let column = column_expr!("a"); - let expression = column.clone().lt(2); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = column.clone().lt_eq(2); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = column.clone().gt(2); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![false, false, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = column.clone().gt_eq(2); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![false, true, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = column.clone().eq(2); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = 
Arc::new(BooleanArray::from(vec![false, true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = column.clone().ne(2); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false, true])); - assert_eq!(results.as_ref(), expected.as_ref()); + let predicate = column.clone().lt(Expr::literal(2)); + let results = evaluate_predicate(&predicate, &batch).unwrap(); + let expected = BooleanArray::from(vec![true, false, false]); + assert_eq!(results, expected); + + let predicate = column.clone().le(Expr::literal(2)); + let results = evaluate_predicate(&predicate, &batch).unwrap(); + let expected = BooleanArray::from(vec![true, true, false]); + assert_eq!(results, expected); + + let predicate = column.clone().gt(Expr::literal(2)); + let results = evaluate_predicate(&predicate, &batch).unwrap(); + let expected = BooleanArray::from(vec![false, false, true]); + assert_eq!(results, expected); + + let predicate = column.clone().ge(Expr::literal(2)); + let results = evaluate_predicate(&predicate, &batch).unwrap(); + let expected = BooleanArray::from(vec![false, true, true]); + assert_eq!(results, expected); + + let predicate = column.clone().eq(Expr::literal(2)); + let results = evaluate_predicate(&predicate, &batch).unwrap(); + let expected = BooleanArray::from(vec![false, true, false]); + assert_eq!(results, expected); + + let predicate = column.clone().ne(Expr::literal(2)); + let results = evaluate_predicate(&predicate, &batch).unwrap(); + let expected = BooleanArray::from(vec![true, false, true]); + assert_eq!(results, expected); } #[test] @@ -271,32 +294,28 @@ fn test_logical() { ], ) .unwrap(); - let column_a = column_expr!("a"); - let column_b = column_expr!("b"); - - let expression = Expression::and(column_a.clone(), column_b.clone()); - let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); - let expected = 
Arc::new(BooleanArray::from(vec![false, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Expression::and(column_a.clone(), true); - let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Expression::or(column_a.clone(), column_b); - let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Expression::or(column_a.clone(), false); - let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); + let column_a = column_pred!("a"); + let column_b = column_pred!("b"); + + let pred = Pred::and(column_a.clone(), column_b.clone()); + let results = evaluate_predicate(&pred, &batch).unwrap(); + let expected = BooleanArray::from(vec![false, false]); + assert_eq!(results, expected); + + let pred = Pred::and(column_a.clone(), Pred::literal(true)); + let results = evaluate_predicate(&pred, &batch).unwrap(); + let expected = BooleanArray::from(vec![true, false]); + assert_eq!(results, expected); + + let pred = Pred::or(column_a.clone(), column_b); + let results = evaluate_predicate(&pred, &batch).unwrap(); + let expected = BooleanArray::from(vec![true, true]); + assert_eq!(results, expected); + + let pred = Pred::or(column_a.clone(), Pred::literal(false)); + let results = evaluate_predicate(&pred, &batch).unwrap(); + let expected = BooleanArray::from(vec![true, false]); + assert_eq!(results, expected); } #[test] diff --git a/kernel/src/engine/arrow_utils.rs b/kernel/src/engine/arrow_utils.rs index 
749f1399c..807d51d40 100644 --- a/kernel/src/engine/arrow_utils.rs +++ b/kernel/src/engine/arrow_utils.rs @@ -41,7 +41,7 @@ macro_rules! prim_array_cmp { .ok_or(Error::invalid_expression( format!("Cannot cast to list array: {}", $right_arr.data_type())) )?; - crate::arrow::compute::kernels::comparison::in_list(prim_array, list_array).map(wrap_comparison_result) + crate::arrow::compute::kernels::comparison::in_list(prim_array, list_array) } )+ _ => Err(ArrowError::CastError( diff --git a/kernel/src/engine/default/json.rs b/kernel/src/engine/default/json.rs index decef41e3..533321f87 100644 --- a/kernel/src/engine/default/json.rs +++ b/kernel/src/engine/default/json.rs @@ -22,8 +22,7 @@ use crate::engine::arrow_utils::parse_json as arrow_parse_json; use crate::engine::arrow_utils::to_json_bytes; use crate::schema::SchemaRef; use crate::{ - DeltaResult, EngineData, Error, ExpressionRef, FileDataReadResultIterator, FileMeta, - JsonHandler, + DeltaResult, EngineData, Error, FileDataReadResultIterator, FileMeta, JsonHandler, PredicateRef, }; const DEFAULT_BUFFER_SIZE: usize = 1000; @@ -96,7 +95,7 @@ impl JsonHandler for DefaultJsonHandler { &self, files: &[FileMeta], physical_schema: SchemaRef, - _predicate: Option, + _predicate: Option, ) -> DeltaResult { if files.is_empty() { return Ok(Box::new(std::iter::empty())); diff --git a/kernel/src/engine/default/mod.rs b/kernel/src/engine/default/mod.rs index 24c6dddf7..76f97550d 100644 --- a/kernel/src/engine/default/mod.rs +++ b/kernel/src/engine/default/mod.rs @@ -38,7 +38,7 @@ pub struct DefaultEngine { storage: Arc>, json: Arc>, parquet: Arc>, - expression: Arc, + evaluation: Arc, } impl DefaultEngine { @@ -84,7 +84,7 @@ impl DefaultEngine { task_executor, )), object_store, - expression: Arc::new(ArrowEvaluationHandler {}), + evaluation: Arc::new(ArrowEvaluationHandler {}), } } @@ -121,7 +121,7 @@ impl DefaultEngine { impl Engine for DefaultEngine { fn evaluation_handler(&self) -> Arc { - self.expression.clone() + 
self.evaluation.clone() } fn storage_handler(&self) -> Arc { diff --git a/kernel/src/engine/default/parquet.rs b/kernel/src/engine/default/parquet.rs index 8636b3d9f..fbd2dfb7a 100644 --- a/kernel/src/engine/default/parquet.rs +++ b/kernel/src/engine/default/parquet.rs @@ -24,8 +24,8 @@ use crate::engine::default::executor::TaskExecutor; use crate::engine::parquet_row_group_skipping::ParquetRowGroupSkipping; use crate::schema::SchemaRef; use crate::{ - DeltaResult, EngineData, Error, ExpressionRef, FileDataReadResultIterator, FileMeta, - ParquetHandler, + DeltaResult, EngineData, Error, FileDataReadResultIterator, FileMeta, ParquetHandler, + PredicateRef, }; #[derive(Debug)] @@ -178,7 +178,7 @@ impl ParquetHandler for DefaultParquetHandler { &self, files: &[FileMeta], physical_schema: SchemaRef, - predicate: Option, + predicate: Option, ) -> DeltaResult { if files.is_empty() { return Ok(Box::new(std::iter::empty())); @@ -221,7 +221,7 @@ struct ParquetOpener { // projection: Arc<[usize]>, batch_size: usize, table_schema: SchemaRef, - predicate: Option, + predicate: Option, limit: Option, store: Arc, } @@ -230,7 +230,7 @@ impl ParquetOpener { pub(crate) fn new( batch_size: usize, table_schema: SchemaRef, - predicate: Option, + predicate: Option, store: Arc, ) -> Self { Self { @@ -292,7 +292,7 @@ impl FileOpener for ParquetOpener { /// Implements [`FileOpener`] for a opening a parquet file from a presigned URL struct PresignedUrlOpener { batch_size: usize, - predicate: Option, + predicate: Option, limit: Option, table_schema: SchemaRef, client: reqwest::Client, @@ -302,7 +302,7 @@ impl PresignedUrlOpener { pub(crate) fn new( batch_size: usize, schema: SchemaRef, - predicate: Option, + predicate: Option, ) -> Self { Self { batch_size, diff --git a/kernel/src/engine/parquet_row_group_skipping.rs b/kernel/src/engine/parquet_row_group_skipping.rs index 2464ca455..9231d235b 100644 --- a/kernel/src/engine/parquet_row_group_skipping.rs +++ 
b/kernel/src/engine/parquet_row_group_skipping.rs @@ -1,5 +1,5 @@ //! An implementation of parquet row group skipping using data skipping predicates over footer stats. -use crate::expressions::{ColumnName, Expression, Scalar}; +use crate::expressions::{ColumnName, Predicate, Scalar}; use crate::kernel_predicates::parquet_stats_skipping::ParquetStatsProvider; use crate::parquet::arrow::arrow_reader::ArrowReaderBuilder; use crate::parquet::file::metadata::RowGroupMetaData; @@ -17,10 +17,10 @@ mod tests; pub(crate) trait ParquetRowGroupSkipping { /// Instructs the parquet reader to perform row group skipping, eliminating any row group whose /// stats prove that none of the group's rows can satisfy the given `predicate`. - fn with_row_group_filter(self, predicate: &Expression) -> Self; + fn with_row_group_filter(self, predicate: &Predicate) -> Self; } impl ParquetRowGroupSkipping for ArrowReaderBuilder { - fn with_row_group_filter(self, predicate: &Expression) -> Self { + fn with_row_group_filter(self, predicate: &Predicate) -> Self { let indices = self .metadata() .row_groups() @@ -46,7 +46,7 @@ struct RowGroupFilter<'a> { impl<'a> RowGroupFilter<'a> { /// Creates a new row group filter for the given row group and predicate. - fn new(row_group: &'a RowGroupMetaData, predicate: &Expression) -> Self { + fn new(row_group: &'a RowGroupMetaData, predicate: &Predicate) -> Self { Self { row_group, field_indices: compute_field_indices(row_group.schema_descr().columns(), predicate), @@ -54,7 +54,7 @@ impl<'a> RowGroupFilter<'a> { } /// Applies a filtering predicate to a row group. Return value false means to skip it. 
- fn apply(row_group: &'a RowGroupMetaData, predicate: &Expression) -> bool { + fn apply(row_group: &'a RowGroupMetaData, predicate: &Predicate) -> bool { use crate::kernel_predicates::KernelPredicateEvaluator as _; RowGroupFilter::new(row_group, predicate).eval_sql_where(predicate) != Some(false) } @@ -216,19 +216,19 @@ impl ParquetStatsProvider for RowGroupFilter<'_> { } } -/// Given a filter expression of interest and a set of parquet column descriptors, build a column -> -/// index mapping for columns the expression references. This ensures O(1) lookup times, for an -/// overall O(n) cost to evaluate an expression tree with n nodes. +/// Given a predicate of interest and a set of parquet column descriptors, build a column -> +/// index mapping for columns the predicate references. This ensures O(1) lookup times, for an +/// overall O(n) cost to evaluate a predicate tree with n nodes. pub(crate) fn compute_field_indices( fields: &[ColumnDescPtr], - expression: &Expression, + predicate: &Predicate, ) -> HashMap { // Build up a set of requested column paths, then take each found path as the corresponding map // key (avoids unnecessary cloning). // // NOTE: If a requested column was not available, it is silently ignored. These missing columns // are implied all-null, so we will infer their min/max stats as NULL and nullcount == rowcount. 
- let mut requested_columns = expression.references(); + let mut requested_columns = predicate.references(); fields .iter() .enumerate() diff --git a/kernel/src/engine/parquet_row_group_skipping/tests.rs b/kernel/src/engine/parquet_row_group_skipping/tests.rs index 1ad2208db..058268df6 100644 --- a/kernel/src/engine/parquet_row_group_skipping/tests.rs +++ b/kernel/src/engine/parquet_row_group_skipping/tests.rs @@ -1,8 +1,8 @@ use super::*; -use crate::expressions::{column_expr, column_name}; +use crate::expressions::{column_name, column_pred}; use crate::kernel_predicates::DataSkippingPredicateEvaluator as _; use crate::parquet::arrow::arrow_reader::ArrowReaderMetadata; -use crate::Expression; +use crate::Predicate; use std::fs::File; /// Performs an exhaustive set of reads against a specially crafted parquet file. @@ -39,23 +39,23 @@ fn test_get_stat_values() { let file = File::open("./tests/data/parquet_row_group_skipping/part-00000-b92e017a-50ba-4676-8322-48fc371c2b59-c000.snappy.parquet").unwrap(); let metadata = ArrowReaderMetadata::load(&file, Default::default()).unwrap(); - // The expression doesn't matter -- it just needs to mention all the columns we care about. - let columns = Expression::and_from(vec![ - column_expr!("varlen.utf8"), - column_expr!("numeric.ints.int64"), - column_expr!("numeric.ints.int32"), - column_expr!("numeric.ints.int16"), - column_expr!("numeric.ints.int8"), - column_expr!("numeric.floats.float32"), - column_expr!("numeric.floats.float64"), - column_expr!("bool"), - column_expr!("varlen.binary"), - column_expr!("numeric.decimals.decimal32"), - column_expr!("numeric.decimals.decimal64"), - column_expr!("numeric.decimals.decimal128"), - column_expr!("chrono.date32"), - column_expr!("chrono.timestamp"), - column_expr!("chrono.timestamp_ntz"), + // The predicate doesn't matter -- it just needs to mention all the columns we care about. 
+ let columns = Predicate::and_from(vec![ + column_pred!("varlen.utf8"), + column_pred!("numeric.ints.int64"), + column_pred!("numeric.ints.int32"), + column_pred!("numeric.ints.int16"), + column_pred!("numeric.ints.int8"), + column_pred!("numeric.floats.float32"), + column_pred!("numeric.floats.float64"), + column_pred!("bool"), + column_pred!("varlen.binary"), + column_pred!("numeric.decimals.decimal32"), + column_pred!("numeric.decimals.decimal64"), + column_pred!("numeric.decimals.decimal128"), + column_pred!("chrono.date32"), + column_pred!("chrono.timestamp"), + column_pred!("chrono.timestamp_ntz"), ]); let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns); diff --git a/kernel/src/engine/sync/json.rs b/kernel/src/engine/sync/json.rs index 3f630b7f2..19434122f 100644 --- a/kernel/src/engine/sync/json.rs +++ b/kernel/src/engine/sync/json.rs @@ -11,8 +11,7 @@ use crate::engine::arrow_utils::parse_json as arrow_parse_json; use crate::engine::arrow_utils::to_json_bytes; use crate::schema::SchemaRef; use crate::{ - DeltaResult, EngineData, Error, ExpressionRef, FileDataReadResultIterator, FileMeta, - JsonHandler, + DeltaResult, EngineData, Error, FileDataReadResultIterator, FileMeta, JsonHandler, PredicateRef, }; pub(crate) struct SyncJsonHandler; @@ -21,7 +20,7 @@ fn try_create_from_json( file: File, _schema: SchemaRef, arrow_schema: ArrowSchemaRef, - _predicate: Option, + _predicate: Option, ) -> DeltaResult>> { let json = ReaderBuilder::new(arrow_schema) .build(BufReader::new(file))? 
@@ -34,7 +33,7 @@ impl JsonHandler for SyncJsonHandler { &self, files: &[FileMeta], schema: SchemaRef, - predicate: Option, + predicate: Option, ) -> DeltaResult { read_files(files, schema, predicate, try_create_from_json) } diff --git a/kernel/src/engine/sync/mod.rs b/kernel/src/engine/sync/mod.rs index 0c119396e..df2408797 100644 --- a/kernel/src/engine/sync/mod.rs +++ b/kernel/src/engine/sync/mod.rs @@ -3,8 +3,8 @@ use super::arrow_expression::ArrowEvaluationHandler; use crate::engine::arrow_data::ArrowEngineData; use crate::{ - DeltaResult, Engine, Error, EvaluationHandler, ExpressionRef, FileDataReadResultIterator, - FileMeta, JsonHandler, ParquetHandler, SchemaRef, StorageHandler, + DeltaResult, Engine, Error, EvaluationHandler, FileDataReadResultIterator, FileMeta, + JsonHandler, ParquetHandler, PredicateRef, SchemaRef, StorageHandler, }; use crate::arrow::datatypes::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; @@ -60,12 +60,12 @@ impl Engine for SyncEngine { fn read_files( files: &[FileMeta], schema: SchemaRef, - predicate: Option, + predicate: Option, mut try_create_from_file: F, ) -> DeltaResult where I: Iterator> + Send + 'static, - F: FnMut(File, SchemaRef, ArrowSchemaRef, Option) -> DeltaResult + F: FnMut(File, SchemaRef, ArrowSchemaRef, Option) -> DeltaResult + Send + 'static, { diff --git a/kernel/src/engine/sync/parquet.rs b/kernel/src/engine/sync/parquet.rs index 48010af30..65b7c0f73 100644 --- a/kernel/src/engine/sync/parquet.rs +++ b/kernel/src/engine/sync/parquet.rs @@ -8,7 +8,7 @@ use crate::engine::arrow_data::ArrowEngineData; use crate::engine::arrow_utils::{fixup_parquet_read, generate_mask, get_requested_indices}; use crate::engine::parquet_row_group_skipping::ParquetRowGroupSkipping; use crate::schema::SchemaRef; -use crate::{DeltaResult, ExpressionRef, FileDataReadResultIterator, FileMeta, ParquetHandler}; +use crate::{DeltaResult, FileDataReadResultIterator, FileMeta, ParquetHandler, PredicateRef}; pub(crate) struct 
SyncParquetHandler; @@ -16,7 +16,7 @@ fn try_create_from_parquet( file: File, schema: SchemaRef, _arrow_schema: ArrowSchemaRef, - predicate: Option, + predicate: Option, ) -> DeltaResult>> { let metadata = ArrowReaderMetadata::load(&file, Default::default())?; let parquet_schema = metadata.schema(); @@ -37,7 +37,7 @@ impl ParquetHandler for SyncParquetHandler { &self, files: &[FileMeta], schema: SchemaRef, - predicate: Option, + predicate: Option, ) -> DeltaResult { read_files(files, schema, predicate, try_create_from_parquet) } diff --git a/kernel/src/expressions/column_names.rs b/kernel/src/expressions/column_names.rs index 0ea7a7067..bf76e8005 100644 --- a/kernel/src/expressions/column_names.rs +++ b/kernel/src/expressions/column_names.rs @@ -420,6 +420,16 @@ macro_rules! __column_expr { #[doc(inline)] pub use __column_expr as column_expr; +#[macro_export] +#[doc(hidden)] +macro_rules! __column_pred { + ( $($name:tt)* ) => { + $crate::expressions::Predicate::from($crate::__column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __column_pred as column_pred; + #[macro_export] #[doc(hidden)] macro_rules! 
__joined_column_expr { diff --git a/kernel/src/expressions/literal_expression_transform.rs b/kernel/src/expressions/literal_expression_transform.rs index 2d2276c11..7219ad523 100644 --- a/kernel/src/expressions/literal_expression_transform.rs +++ b/kernel/src/expressions/literal_expression_transform.rs @@ -191,11 +191,13 @@ mod tests { use paste::paste; + use Expression as Expr; + // helper to take values/schema to pass to `create_one` and assert the result = expected fn assert_single_row_transform( values: &[Scalar], schema: SchemaRef, - expected: Result, + expected: Result, ) { let mut schema_transform = LiteralExpressionTransform::new(values); let datatype = schema.into(); @@ -221,15 +223,14 @@ mod tests { "col_1", DeltaDataTypes::INTEGER, )])); - let expected = Expression::null_literal(schema.clone().into()); + let expected = Expr::null_literal(schema.clone().into()); assert_single_row_transform(values, schema, Ok(expected)); let schema = Arc::new(StructType::new([StructField::nullable( "col_1", DeltaDataTypes::INTEGER, )])); - let expected = - Expression::struct_from(vec![Expression::null_literal(DeltaDataTypes::INTEGER)]); + let expected = Expr::struct_from(vec![Expr::null_literal(DeltaDataTypes::INTEGER)]); assert_single_row_transform(values, schema, Ok(expected)); } @@ -283,9 +284,9 @@ mod tests { ]), ), ])); - let expected = Expression::struct_from(vec![ - Expression::struct_from(vec![Expression::literal(1), Expression::literal(2)]), - Expression::struct_from(vec![Expression::literal(3), Expression::literal(4)]), + let expected = Expr::struct_from(vec![ + Expr::struct_from(vec![Expr::literal(1), Expr::literal(2)]), + Expr::struct_from(vec![Expr::literal(3), Expr::literal(4)]), ]); assert_single_row_transform(values, schema, Ok(expected)); } @@ -327,16 +328,16 @@ mod tests { let expected_result = match expected { Expected::Noop => { - let nested_struct = Expression::struct_from(vec![ - Expression::literal(values[0].clone()), - 
Expression::literal(values[1].clone()), + let nested_struct = Expr::struct_from(vec![ + Expr::literal(values[0].clone()), + Expr::literal(values[1].clone()), ]); - Ok(Expression::struct_from([nested_struct])) + Ok(Expr::struct_from([nested_struct])) } - Expected::Null => Ok(Expression::null_literal(schema.clone().into())), + Expected::Null => Ok(Expr::null_literal(schema.clone().into())), Expected::NullStruct => { - let nested_null = Expression::null_literal(field_x.data_type().clone()); - Ok(Expression::struct_from([nested_null])) + let nested_null = Expr::null_literal(field_x.data_type().clone()); + Ok(Expr::struct_from([nested_null])) } Expected::Error => Err(()), }; diff --git a/kernel/src/expressions/mod.rs b/kernel/src/expressions/mod.rs index ccc450008..acb74df70 100644 --- a/kernel/src/expressions/mod.rs +++ b/kernel/src/expressions/mod.rs @@ -1,33 +1,39 @@ //! Definitions and functions to create and manipulate kernel expressions -use std::borrow::Cow; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use itertools::Itertools; pub use self::column_names::{ - column_expr, column_name, joined_column_expr, joined_column_name, ColumnName, + column_expr, column_name, column_pred, joined_column_expr, joined_column_name, ColumnName, }; pub use self::scalars::{ArrayData, Scalar, StructData}; +use self::transforms::{ExpressionTransform as _, GetColumnReferences}; use crate::DataType; mod column_names; +pub(crate) mod literal_expression_transform; mod scalars; +pub mod transforms; -pub(crate) mod literal_expression_transform; +pub type ExpressionRef = std::sync::Arc; +pub type PredicateRef = std::sync::Arc; +//////////////////////////////////////////////////////////////////////// +// Operators +//////////////////////////////////////////////////////////////////////// + +/// A unary predicate operator. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum UnaryPredicateOp { + /// Unary Is Null + IsNull, +} + +/// A binary predicate operator. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -/// A binary operator. -pub enum BinaryOperator { - /// Arithmetic Plus - Plus, - /// Arithmetic Minus - Minus, - /// Arithmetic Multiply - Multiply, - /// Arithmetic Divide - Divide, +pub enum BinaryPredicateOp { /// Comparison Less Than LessThan, /// Comparison Less Than Or Equal @@ -48,127 +54,66 @@ pub enum BinaryOperator { NotIn, } -impl BinaryOperator { - /// True if this is a comparison for which NULL input always produces NULL output - pub(crate) fn is_null_intolerant_comparison(&self) -> bool { - use BinaryOperator::*; - match self { - Plus | Minus | Multiply | Divide => false, // not a comparison - LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => true, - Equal | NotEqual => true, - Distinct | In | NotIn => false, // tolerates NULL input - } - } - - /// Returns `<op2>` (if any) such that `B <op2> A` is equivalent to `A <op> B`. - pub(crate) fn commute(&self) -> Option { - use BinaryOperator::*; - match self { - GreaterThan => Some(LessThan), - GreaterThanOrEqual => Some(LessThanOrEqual), - LessThan => Some(GreaterThan), - LessThanOrEqual => Some(GreaterThanOrEqual), - Equal | NotEqual | Distinct | Plus | Multiply => Some(*self), - In | NotIn | Minus | Divide => None, // not commutative - } - } +/// A binary expression operator. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BinaryExpressionOp { + /// Arithmetic Plus + Plus, + /// Arithmetic Minus + Minus, + /// Arithmetic Multiply + Multiply, + /// Arithmetic Divide + Divide, } +/// A junction (AND/OR) predicate operator.
#[derive(Debug, Clone, Copy, PartialEq)] -pub enum JunctionOperator { +pub enum JunctionPredicateOp { /// Conjunction And, /// Disjunction Or, } -impl JunctionOperator { - pub(crate) fn invert(&self) -> JunctionOperator { - use JunctionOperator::*; - match self { - And => Or, - Or => And, - } - } -} - -impl Display for BinaryOperator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use BinaryOperator::*; - match self { - Plus => write!(f, "+"), - Minus => write!(f, "-"), - Multiply => write!(f, "*"), - Divide => write!(f, "/"), - LessThan => write!(f, "<"), - LessThanOrEqual => write!(f, "<="), - GreaterThan => write!(f, ">"), - GreaterThanOrEqual => write!(f, ">="), - Equal => write!(f, "="), - NotEqual => write!(f, "!="), - // TODO(roeap): AFAIK DISTINCT does not have a commonly used operator symbol - // so ideally this would not be used as we use Display for rendering expressions - // in our code we take care of this, but theirs might not ... - Distinct => write!(f, "DISTINCT"), - In => write!(f, "IN"), - NotIn => write!(f, "NOT IN"), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq)] -/// A unary operator. -pub enum UnaryOperator { - /// Unary Not - Not, - /// Unary Is Null - IsNull, -} - -pub type ExpressionRef = std::sync::Arc; +//////////////////////////////////////////////////////////////////////// +// Expressions and predicates +//////////////////////////////////////////////////////////////////////// #[derive(Clone, Debug, PartialEq)] -pub struct UnaryExpression { +pub struct UnaryPredicate { /// The operator. - pub op: UnaryOperator, - /// The expression. + pub op: UnaryPredicateOp, + /// The input expression. pub expr: Box, } -impl UnaryExpression { - fn new(op: UnaryOperator, expr: impl Into) -> Self { - let expr = Box::new(expr.into()); - Self { op, expr } - } + +#[derive(Clone, Debug, PartialEq)] +pub struct BinaryPredicate { + /// The operator. + pub op: BinaryPredicateOp, + /// The left-hand side of the operation. 
+ pub left: Box, + /// The right-hand side of the operation. + pub right: Box, } #[derive(Clone, Debug, PartialEq)] pub struct BinaryExpression { /// The operator. - pub op: BinaryOperator, + pub op: BinaryExpressionOp, /// The left-hand side of the operation. pub left: Box, /// The right-hand side of the operation. pub right: Box, } -impl BinaryExpression { - fn new(op: BinaryOperator, left: impl Into, right: impl Into) -> Self { - let left = Box::new(left.into()); - let right = Box::new(right.into()); - Self { op, left, right } - } -} #[derive(Clone, Debug, PartialEq)] -pub struct JunctionExpression { +pub struct JunctionPredicate { /// The operator. - pub op: JunctionOperator, - /// The expressions. - pub exprs: Vec, -} -impl JunctionExpression { - fn new(op: JunctionOperator, exprs: Vec) -> Self { - Self { op, exprs } - } + pub op: JunctionPredicateOp, + /// The input predicates. + pub preds: Vec, } /// A SQL expression. @@ -182,67 +127,119 @@ pub enum Expression { Literal(Scalar), /// A column reference by name. Column(ColumnName), + /// A predicate treated as a boolean expression + Predicate(Box), /// A struct computed from a Vec of expressions Struct(Vec), + /// An expression that takes two expressions as input. + Binary(BinaryExpression), +} + +/// A SQL predicate. +/// +/// These predicates do not track or validate data types, other than the type +/// of literals. It is up to the predicate evaluator to validate the +/// predicate against a schema and add appropriate casts as required. +#[derive(Debug, Clone, PartialEq)] +pub enum Predicate { + /// A boolean-valued expression, useful for e.g. `AND(, )`. + BooleanExpression(Expression), + /// Boolean inversion (true <-> false) + /// + /// NOTE: NOT is not a normal unary predicate, because it requires a predicate as input (not an + /// expression), and is never directly evaluated. Instead, observing that all predicates are + /// invertible, NOT is always pushed down into its child predicate, inverting it. 
For example, + /// `NOT (a < b)` pushes down and inverts `<` to `>=`, producing `a >= b`. + Not(Box), /// A unary operation. - Unary(UnaryExpression), + Unary(UnaryPredicate), /// A binary operation. - Binary(BinaryExpression), + Binary(BinaryPredicate), /// A junction operation (AND/OR). - Junction(JunctionExpression), - // TODO: support more expressions, such as IS IN, LIKE, etc. + Junction(JunctionPredicate), } -impl> From for Expression { - fn from(value: T) -> Self { - Self::literal(value) +//////////////////////////////////////////////////////////////////////// +// Struct/Enum impls +//////////////////////////////////////////////////////////////////////// + +impl BinaryPredicateOp { + /// True if this is a comparison for which NULL input always produces NULL output + pub(crate) fn is_null_intolerant(&self) -> bool { + use BinaryPredicateOp::*; + match self { + LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => true, + Equal | NotEqual => true, + Distinct | In | NotIn => false, // tolerates NULL input + } } -} -impl From for Expression { - fn from(value: ColumnName) -> Self { - Self::Column(value) + /// Returns `` (if any) such that `B A` is equivalent to `A B`. 
+ pub(crate) fn commute(&self) -> Option { + use BinaryPredicateOp::*; + match self { + GreaterThan => Some(LessThan), + GreaterThanOrEqual => Some(LessThanOrEqual), + LessThan => Some(GreaterThan), + LessThanOrEqual => Some(GreaterThanOrEqual), + Equal | NotEqual | Distinct => Some(*self), + In | NotIn => None, // not commutative + } } } -impl Display for Expression { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use Expression::*; +impl JunctionPredicateOp { + pub(crate) fn invert(&self) -> JunctionPredicateOp { + use JunctionPredicateOp::*; match self { - Literal(l) => write!(f, "{l}"), - Column(name) => write!(f, "Column({name})"), - Struct(exprs) => write!( - f, - "Struct({})", - &exprs.iter().map(|e| format!("{e}")).join(", ") - ), - Binary(BinaryExpression { - op: BinaryOperator::Distinct, - left, - right, - }) => write!(f, "DISTINCT({left}, {right})"), - Binary(BinaryExpression { op, left, right }) => write!(f, "{left} {op} {right}"), - Unary(UnaryExpression { op, expr }) => match op { - UnaryOperator::Not => write!(f, "NOT {expr}"), - UnaryOperator::IsNull => write!(f, "{expr} IS NULL"), - }, - Junction(JunctionExpression { op, exprs }) => { - let exprs = &exprs.iter().map(|e| format!("{e}")).join(", "); - let op = match op { - JunctionOperator::And => "AND", - JunctionOperator::Or => "OR", - }; - write!(f, "{op}({exprs})") - } + And => Or, + Or => And, } } } +impl UnaryPredicate { + fn new(op: UnaryPredicateOp, expr: impl Into) -> Self { + let expr = Box::new(expr.into()); + Self { op, expr } + } +} + +impl BinaryExpression { + fn new( + op: BinaryExpressionOp, + left: impl Into, + right: impl Into, + ) -> Self { + let left = Box::new(left.into()); + let right = Box::new(right.into()); + Self { op, left, right } + } +} + +impl BinaryPredicate { + fn new( + op: BinaryPredicateOp, + left: impl Into, + right: impl Into, + ) -> Self { + let left = Box::new(left.into()); + let right = Box::new(right.into()); + Self { op, left, right } + } +} + 
+impl JunctionPredicate { + fn new(op: JunctionPredicateOp, preds: Vec) -> Self { + Self { op, preds } + } +} + impl Expression { /// Returns a set of columns referenced by this expression. pub fn references(&self) -> HashSet<&ColumnName> { let mut references = GetColumnReferences::default(); - let _ = references.transform(self); + let _ = references.transform_expr(self); references.into_inner() } @@ -264,20 +261,67 @@ impl Expression { Self::Literal(Scalar::Null(data_type)) } + /// Wraps a predicate as a boolean-valued expression + pub fn predicate(value: Predicate) -> Self { + match value { + Predicate::BooleanExpression(expr) => expr, + _ => Self::Predicate(value.into()), + } + } + /// Create a new struct expression pub fn struct_from(exprs: impl IntoIterator) -> Self { Self::Struct(exprs.into_iter().collect()) } - /// Creates a new unary expression OP expr - pub fn unary(op: UnaryOperator, expr: impl Into) -> Self { - let expr = Box::new(expr.into()); - Self::Unary(UnaryExpression { op, expr }) + /// Create a new predicate `self IS NULL` + pub fn is_null(self) -> Predicate { + Predicate::is_null(self) + } + + /// Create a new predicate `self IS NOT NULL` + pub fn is_not_null(self) -> Predicate { + Predicate::is_not_null(self) + } + + /// Create a new predicate `self == other` + pub fn eq(self, other: impl Into) -> Predicate { + Predicate::eq(self, other) + } + + /// Create a new predicate `self != other` + pub fn ne(self, other: impl Into) -> Predicate { + Predicate::ne(self, other) + } + + /// Create a new predicate `self <= other` + pub fn le(self, other: impl Into) -> Predicate { + Predicate::le(self, other) + } + + /// Create a new predicate `self < other` + pub fn lt(self, other: impl Into) -> Predicate { + Predicate::lt(self, other) + } + + /// Create a new predicate `self >= other` + pub fn ge(self, other: impl Into) -> Predicate { + Predicate::ge(self, other) + } + + /// Create a new predicate `self > other` + pub fn gt(self, other: impl Into) -> 
Predicate { + Predicate::gt(self, other) + } + + /// Create a new predicate `DISTINCT(self, other)` + pub fn distinct(self, other: impl Into) -> Predicate { + Predicate::distinct(self, other) } /// Creates a new binary expression lhs OP rhs pub fn binary( - op: BinaryOperator, + op: BinaryExpressionOp, lhs: impl Into, rhs: impl Into, ) -> Self { @@ -287,427 +331,290 @@ impl Expression { right: Box::new(rhs.into()), }) } +} + +impl Predicate { + /// Returns a set of columns referenced by this predicate. + pub fn references(&self) -> HashSet<&ColumnName> { + let mut references = GetColumnReferences::default(); + let _ = references.transform_pred(self); + references.into_inner() + } - /// Creates a new junction expression OP(exprs...) - pub fn junction(op: JunctionOperator, exprs: impl IntoIterator) -> Self { - let exprs = exprs.into_iter().collect(); - Self::Junction(JunctionExpression { op, exprs }) + /// Creates a new boolean column reference. See also [`Expression::column`]. + pub fn column(field_names: impl IntoIterator) -> Predicate + where + ColumnName: FromIterator, + { + Self::from_expr(ColumnName::new(field_names)) } - /// Creates a new expression AND(exprs...) - pub fn and_from(exprs: impl IntoIterator) -> Self { - Self::junction(JunctionOperator::And, exprs) + /// Create a new literal boolean value + pub const fn literal(value: bool) -> Self { + Self::BooleanExpression(Expression::Literal(Scalar::Boolean(value))) } - /// Creates a new expression OR(exprs...) 
- pub fn or_from(exprs: impl IntoIterator) -> Self { - Self::junction(JunctionOperator::Or, exprs) + /// Creates a NULL literal boolean value + pub const fn null_literal() -> Self { + Self::BooleanExpression(Expression::Literal(Scalar::Null(DataType::BOOLEAN))) } - /// Logical NOT (boolean inversion) - pub fn not(expr: impl Into) -> Self { - Self::unary(UnaryOperator::Not, expr.into()) + /// Converts a boolean-valued expression into a predicate + pub fn from_expr(expr: impl Into) -> Self { + match expr.into() { + Expression::Predicate(p) => *p, + expr => Predicate::BooleanExpression(expr), + } } - /// Create a new expression `self IS NULL` - pub fn is_null(self) -> Self { - Self::unary(UnaryOperator::IsNull, self) + /// Logical NOT (boolean inversion) + pub fn not(pred: impl Into) -> Self { + Self::Not(pred.into().into()) } - /// Create a new expression `self IS NOT NULL` - pub fn is_not_null(self) -> Self { - Self::not(Self::is_null(self)) + /// Create a new predicate `self IS NULL` + pub fn is_null(expr: impl Into) -> Predicate { + Self::unary(UnaryPredicateOp::IsNull, expr) } - /// Create a new expression `self == other` - pub fn eq(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::Equal, self, other) + /// Create a new predicate `self IS NOT NULL` + pub fn is_not_null(expr: impl Into) -> Predicate { + Self::not(Self::is_null(expr)) } - /// Create a new expression `self != other` - pub fn ne(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::NotEqual, self, other) + /// Create a new predicate `self == other` + pub fn eq(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::Equal, a, b) } - /// Create a new expression `self <= other` - pub fn le(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::LessThanOrEqual, self, other) + /// Create a new predicate `self != other` + pub fn ne(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::NotEqual, a, b) } - /// Create a new expression 
`self < other` - pub fn lt(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::LessThan, self, other) + /// Create a new predicate `self <= other` + pub fn le(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::LessThanOrEqual, a, b) } - /// Create a new expression `self >= other` - pub fn ge(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::GreaterThanOrEqual, self, other) + /// Create a new predicate `self < other` + pub fn lt(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::LessThan, a, b) } - /// Create a new expression `self > other` - pub fn gt(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::GreaterThan, self, other) + /// Create a new predicate `self >= other` + pub fn ge(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::GreaterThanOrEqual, a, b) } - /// Create a new expression `self >= other` - pub fn gt_eq(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::GreaterThanOrEqual, self, other) + /// Create a new predicate `self > other` + pub fn gt(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::GreaterThan, a, b) } - /// Create a new expression `self <= other` - pub fn lt_eq(self, other: impl Into) -> Self { - Self::binary(BinaryOperator::LessThanOrEqual, self, other) + /// Create a new predicate `DISTINCT(self, other)` + pub fn distinct(a: impl Into, b: impl Into) -> Self { + Self::binary(BinaryPredicateOp::Distinct, a, b) } - /// Create a new expression `self AND other` + /// Create a new predicate `self AND other` pub fn and(a: impl Into, b: impl Into) -> Self { Self::and_from([a.into(), b.into()]) } - /// Create a new expression `self OR other` + /// Create a new predicate `self OR other` pub fn or(a: impl Into, b: impl Into) -> Self { Self::or_from([a.into(), b.into()]) } - /// Create a new expression `DISTINCT(self, other)` - pub fn distinct(self, other: impl Into) -> Self { - 
Self::binary(BinaryOperator::Distinct, self, other) + /// Creates a new predicate AND(preds...) + pub fn and_from(preds: impl IntoIterator) -> Self { + Self::junction(JunctionPredicateOp::And, preds) } -} -/// Generic framework for recursive bottom-up expression transforms. Transformations return -/// `Option` with the following semantics: -/// -/// * `Some(Cow::Owned)` -- The input was transformed and the parent should be updated with it. -/// * `Some(Cow::Borrowed)` -- The input was not transformed. -/// * `None` -- The input was filtered out and the parent should be updated to not reference it. -/// -/// The transform can start from the generic [`Self::transform`], or directly from a specific -/// expression variant (e.g. [`Self::transform_binary`] to start with [`BinaryExpression`]). -/// -/// The provided `transform_xxx` methods all default to no-op (returning their input as -/// `Some(Cow::Borrowed)`), and implementations should selectively override specific `transform_xxx` -/// methods as needed for the task at hand. -/// -/// The provided `recurse_into_xxx` methods encapsulate the boilerplate work of recursing into the -/// children of each expression variant. Implementations can call these as needed but will generally -/// not need to override them. -pub trait ExpressionTransform<'a> { - /// Called for each literal encountered during the expression traversal. - fn transform_literal(&mut self, value: &'a Scalar) -> Option> { - Some(Cow::Borrowed(value)) - } - - /// Called for each column reference encountered during the expression traversal. - fn transform_column(&mut self, name: &'a ColumnName) -> Option> { - Some(Cow::Borrowed(name)) - } - - /// Called for the expression list of each [`Expression::Struct`] encountered during the - /// traversal. Implementations can call [`Self::recurse_into_struct`] if they wish to - /// recursively transform child expressions. 
- fn transform_struct( - &mut self, - fields: &'a Vec, - ) -> Option>> { - self.recurse_into_struct(fields) - } - - /// Called for each [`UnaryExpression`] encountered during the traversal. Implementations can - /// call [`Self::recurse_into_unary`] if they wish to recursively transform the child. - fn transform_unary(&mut self, expr: &'a UnaryExpression) -> Option> { - self.recurse_into_unary(expr) - } - - /// Called for each [`BinaryExpression`] encountered during the traversal. Implementations can - /// call [`Self::recurse_into_binary`] if they wish to recursively transform the children. - fn transform_binary( - &mut self, - expr: &'a BinaryExpression, - ) -> Option> { - self.recurse_into_binary(expr) - } - - /// Called for each [`JunctionExpression`] encountered during the traversal. Implementations can - /// call [`Self::recurse_into_junction`] if they wish to recursively transform the children. - fn transform_junction( - &mut self, - expr: &'a JunctionExpression, - ) -> Option> { - self.recurse_into_junction(expr) - } - - /// General entry point for transforming an expression. This method will dispatch to the - /// specific transform for each expression variant. Also invoked internally in order to recurse - /// on the child(ren) of non-leaf variants. - fn transform(&mut self, expr: &'a Expression) -> Option> { - use Cow::*; - let expr = match expr { - Expression::Literal(s) => match self.transform_literal(s)? { - Owned(s) => Owned(Expression::Literal(s)), - Borrowed(_) => Borrowed(expr), - }, - Expression::Column(c) => match self.transform_column(c)? { - Owned(c) => Owned(Expression::Column(c)), - Borrowed(_) => Borrowed(expr), - }, - Expression::Struct(s) => match self.transform_struct(s)? { - Owned(s) => Owned(Expression::Struct(s)), - Borrowed(_) => Borrowed(expr), - }, - Expression::Unary(u) => match self.transform_unary(u)? 
{ - Owned(u) => Owned(Expression::Unary(u)), - Borrowed(_) => Borrowed(expr), - }, - Expression::Binary(b) => match self.transform_binary(b)? { - Owned(b) => Owned(Expression::Binary(b)), - Borrowed(_) => Borrowed(expr), - }, - Expression::Junction(j) => match self.transform_junction(j)? { - Owned(j) => Owned(Expression::Junction(j)), - Borrowed(_) => Borrowed(expr), - }, - }; - Some(expr) - } - - /// Recursively transforms a struct's child expressions. Returns `None` if all children were - /// removed, `Some(Cow::Owned)` if at least one child was changed or removed, and - /// `Some(Cow::Borrowed)` otherwise. - fn recurse_into_struct( - &mut self, - fields: &'a Vec, - ) -> Option>> { - let mut num_borrowed = 0; - let new_fields: Vec<_> = fields - .iter() - .filter_map(|f| self.transform(f)) - .inspect(|f| { - if matches!(f, Cow::Borrowed(_)) { - num_borrowed += 1; - } - }) - .collect(); - - if new_fields.is_empty() { - None // all fields filtered out - } else if num_borrowed < fields.len() { - // At least one field was changed or filtered out, so make a new field list - let fields = new_fields.into_iter().map(|f| f.into_owned()).collect(); - Some(Cow::Owned(fields)) - } else { - Some(Cow::Borrowed(fields)) - } + /// Creates a new predicate OR(preds...) + pub fn or_from(preds: impl IntoIterator) -> Self { + Self::junction(JunctionPredicateOp::Or, preds) } - /// Recursively transforms a unary expression's child. Returns `None` if the child was removed, - /// `Some(Cow::Owned)` if the child was changed, and `Some(Cow::Borrowed)` otherwise. - fn recurse_into_unary(&mut self, u: &'a UnaryExpression) -> Option> { - use Cow::*; - let u = match self.transform(&u.expr)? { - Owned(expr) => Owned(UnaryExpression::new(u.op, expr)), - Borrowed(_) => Borrowed(u), - }; - Some(u) - } - - /// Recursively transforms a binary expression's children. 
Returns `None` if at least one child - /// was removed, `Some(Cow::Owned)` if at least one child changed, and `Some(Cow::Borrowed)` - /// otherwise. - fn recurse_into_binary( - &mut self, - b: &'a BinaryExpression, - ) -> Option> { - use Cow::*; - let left = self.transform(&b.left)?; - let right = self.transform(&b.right)?; - let b = match (&left, &right) { - (Borrowed(_), Borrowed(_)) => Borrowed(b), - _ => Owned(BinaryExpression::new( - b.op, - left.into_owned(), - right.into_owned(), - )), - }; - Some(b) - } - - /// Recursively transforms a junction expression's children. Returns `None` if all children were - /// removed, `Some(Cow::Owned)` if at least one child was changed or removed, and - /// `Some(Cow::Borrowed)` otherwise. - fn recurse_into_junction( - &mut self, - j: &'a JunctionExpression, - ) -> Option> { - use Cow::*; - let j = match self.recurse_into_struct(&j.exprs)? { - Owned(exprs) => Owned(JunctionExpression::new(j.op, exprs)), - Borrowed(_) => Borrowed(j), - }; - Some(j) + /// Creates a new unary predicate OP expr + pub fn unary(op: UnaryPredicateOp, expr: impl Into) -> Self { + let expr = Box::new(expr.into()); + Self::Unary(UnaryPredicate { op, expr }) } -} -impl> std::ops::Add for Expression { - type Output = Self; + /// Creates a new binary predicate lhs OP rhs + pub fn binary( + op: BinaryPredicateOp, + lhs: impl Into, + rhs: impl Into, + ) -> Self { + Self::Binary(BinaryPredicate { + op, + left: Box::new(lhs.into()), + right: Box::new(rhs.into()), + }) + } - fn add(self, rhs: R) -> Self::Output { - Self::binary(BinaryOperator::Plus, self, rhs) + /// Creates a new junction predicate OP(preds...) 
+ pub fn junction(op: JunctionPredicateOp, preds: impl IntoIterator) -> Self { + let preds = preds.into_iter().collect(); + Self::Junction(JunctionPredicate { op, preds }) } } -impl> std::ops::Sub for Expression { - type Output = Self; +//////////////////////////////////////////////////////////////////////// +// Trait impls +//////////////////////////////////////////////////////////////////////// - fn sub(self, rhs: R) -> Self { - Self::binary(BinaryOperator::Minus, self, rhs) +impl Display for BinaryExpressionOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use BinaryExpressionOp::*; + match self { + Plus => write!(f, "+"), + Minus => write!(f, "-"), + Multiply => write!(f, "*"), + Divide => write!(f, "/"), + } } } -impl> std::ops::Mul for Expression { - type Output = Self; - - fn mul(self, rhs: R) -> Self { - Self::binary(BinaryOperator::Multiply, self, rhs) +impl Display for BinaryPredicateOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use BinaryPredicateOp::*; + match self { + LessThan => write!(f, "<"), + LessThanOrEqual => write!(f, "<="), + GreaterThan => write!(f, ">"), + GreaterThanOrEqual => write!(f, ">="), + Equal => write!(f, "="), + NotEqual => write!(f, "!="), + // TODO(roeap): AFAIK DISTINCT does not have a commonly used operator symbol + // so ideally this would not be used as we use Display for rendering expressions + // in our code we take care of this, but theirs might not ... 
+ Distinct => write!(f, "DISTINCT"), + In => write!(f, "IN"), + NotIn => write!(f, "NOT IN"), + } } } -impl> std::ops::Div for Expression { - type Output = Self; - - fn div(self, rhs: R) -> Self { - Self::binary(BinaryOperator::Divide, self, rhs) +impl Display for Expression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use Expression::*; + match self { + Literal(l) => write!(f, "{l}"), + Column(name) => write!(f, "Column({name})"), + Predicate(p) => write!(f, "{p}"), + Struct(exprs) => write!( + f, + "Struct({})", + &exprs.iter().map(|e| format!("{e}")).join(", ") + ), + Binary(BinaryExpression { op, left, right }) => write!(f, "{left} {op} {right}"), + } } } -/// Retrieves the set of column names referenced by an expression. -#[derive(Default)] -pub(crate) struct GetColumnReferences<'a> { - references: HashSet<&'a ColumnName>, +impl Display for Predicate { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use Predicate::*; + match self { + BooleanExpression(expr) => write!(f, "{expr}"), + Not(pred) => write!(f, "NOT({pred})"), + Binary(BinaryPredicate { + op: BinaryPredicateOp::Distinct, + left, + right, + }) => write!(f, "DISTINCT({left}, {right})"), + Binary(BinaryPredicate { op, left, right }) => write!(f, "{left} {op} {right}"), + Unary(UnaryPredicate { op, expr }) => match op { + UnaryPredicateOp::IsNull => write!(f, "{expr} IS NULL"), + }, + Junction(JunctionPredicate { op, preds }) => { + let preds = &preds.iter().map(|p| format!("{p}")).join(", "); + let op = match op { + JunctionPredicateOp::And => "AND", + JunctionPredicateOp::Or => "OR", + }; + write!(f, "{op}({preds})") + } + } + } } -impl<'a> GetColumnReferences<'a> { - pub(crate) fn into_inner(self) -> HashSet<&'a ColumnName> { - self.references +impl From for Expression { + fn from(value: Scalar) -> Self { + Self::literal(value) } } -impl<'a> ExpressionTransform<'a> for GetColumnReferences<'a> { - fn transform_column(&mut self, name: &'a ColumnName) -> Option> { - 
self.references.insert(name); - Some(Cow::Borrowed(name)) +impl From for Expression { + fn from(value: ColumnName) -> Self { + Self::Column(value) } } -/// An expression "transform" that doesn't actually change the expression at all. Instead, it -/// measures the maximum depth of a expression, with a depth limit to prevent stack overflow. Useful -/// for verifying that a expression has reasonable depth before attempting to work with it. -pub struct ExpressionDepthChecker { - depth_limit: usize, - max_depth_seen: usize, - current_depth: usize, - call_count: usize, +impl From for Expression { + fn from(value: Predicate) -> Self { + Self::predicate(value) + } } -impl ExpressionDepthChecker { - /// Depth-checks the given expression against a given depth limit. The return value is the - /// largest depth seen, which is capped at one more than the depth limit (indicating the - /// recursion was terminated). - pub fn check(expr: &Expression, depth_limit: usize) -> usize { - Self::check_with_call_count(expr, depth_limit).0 - } - - // Exposed for testing - fn check_with_call_count(expr: &Expression, depth_limit: usize) -> (usize, usize) { - let mut checker = Self { - depth_limit, - max_depth_seen: 0, - current_depth: 0, - call_count: 0, - }; - checker.transform(expr); - (checker.max_depth_seen, checker.call_count) - } - - // Triggers the requested recursion only doing so would not exceed the depth limit. 
- fn depth_limited<'a, T: Clone + std::fmt::Debug>( - &mut self, - recurse: impl FnOnce(&mut Self, &'a T) -> Option>, - arg: &'a T, - ) -> Option> { - self.call_count += 1; - if self.max_depth_seen < self.current_depth { - self.max_depth_seen = self.current_depth; - if self.depth_limit < self.current_depth { - tracing::warn!( - "Max expression depth {} exceeded by {arg:?}", - self.depth_limit - ); - } - } - if self.max_depth_seen <= self.depth_limit { - self.current_depth += 1; - let _ = recurse(self, arg); - self.current_depth -= 1; - } - None + +impl From for Predicate { + fn from(value: ColumnName) -> Self { + Self::from_expr(value) } } -impl<'a> ExpressionTransform<'a> for ExpressionDepthChecker { - fn transform_struct( - &mut self, - fields: &'a Vec, - ) -> Option>> { - self.depth_limited(Self::recurse_into_struct, fields) + +impl> std::ops::Add for Expression { + type Output = Self; + + fn add(self, rhs: R) -> Self::Output { + Self::binary(BinaryExpressionOp::Plus, self, rhs) } +} + +impl> std::ops::Sub for Expression { + type Output = Self; - fn transform_unary(&mut self, expr: &'a UnaryExpression) -> Option> { - self.depth_limited(Self::recurse_into_unary, expr) + fn sub(self, rhs: R) -> Self { + Self::binary(BinaryExpressionOp::Minus, self, rhs) } +} + +impl> std::ops::Mul for Expression { + type Output = Self; - fn transform_binary( - &mut self, - expr: &'a BinaryExpression, - ) -> Option> { - self.depth_limited(Self::recurse_into_binary, expr) + fn mul(self, rhs: R) -> Self { + Self::binary(BinaryExpressionOp::Multiply, self, rhs) } +} + +impl> std::ops::Div for Expression { + type Output = Self; - fn transform_junction( - &mut self, - expr: &'a JunctionExpression, - ) -> Option> { - self.depth_limited(Self::recurse_into_junction, expr) + fn div(self, rhs: R) -> Self { + Self::binary(BinaryExpressionOp::Divide, self, rhs) } } #[cfg(test)] mod tests { - use super::{column_expr, Expression as Expr, ExpressionDepthChecker}; + use super::{column_expr, 
column_pred, Expression as Expr, Predicate as Pred}; #[test] fn test_expression_format() { - let col_ref = column_expr!("x"); let cases = [ - (col_ref.clone(), "Column(x)"), - (col_ref.clone().eq(2), "Column(x) = 2"), - ((col_ref.clone() - 4).lt(10), "Column(x) - 4 < 10"), - ((col_ref.clone() + 4) / 10 * 42, "Column(x) + 4 / 10 * 42"), + (column_expr!("x"), "Column(x)"), ( - Expr::and(col_ref.clone().gt_eq(2), col_ref.clone().lt_eq(10)), - "AND(Column(x) >= 2, Column(x) <= 10)", + (column_expr!("x") + Expr::literal(4)) / Expr::literal(10) * Expr::literal(42), + "Column(x) + 4 / 10 * 42", ), ( - Expr::and_from([ - col_ref.clone().gt_eq(2), - col_ref.clone().lt_eq(10), - col_ref.clone().lt_eq(100), - ]), - "AND(Column(x) >= 2, Column(x) <= 10, Column(x) <= 100)", + Expr::struct_from([column_expr!("x"), Expr::literal(2), Expr::literal(10)]), + "Struct(Column(x), 2, 10)", ), - ( - Expr::or(col_ref.clone().gt(2), col_ref.clone().lt(10)), - "OR(Column(x) > 2, Column(x) < 10)", - ), - (col_ref.eq("foo"), "Column(x) = 'foo'"), ]; for (expr, expected) in cases { @@ -717,87 +624,45 @@ mod tests { } #[test] - fn test_depth_checker() { - let expr = Expr::and_from([ - Expr::struct_from([ - Expr::and_from([ - Expr::lt(Expr::literal(10), column_expr!("x")), - Expr::or_from([Expr::literal(true), column_expr!("b")]), - ]), - Expr::literal(true), - Expr::not(Expr::literal(true)), - ]), - Expr::and_from([ - Expr::not(column_expr!("b")), - Expr::gt(Expr::literal(10), column_expr!("x")), - Expr::or_from([ - Expr::and_from([Expr::not(Expr::literal(true)), Expr::literal(10)]), - Expr::literal(10), + fn test_predicate_format() { + let cases = [ + (column_pred!("x"), "Column(x)"), + (column_expr!("x").eq(Expr::literal(2)), "Column(x) = 2"), + ( + (column_expr!("x") - Expr::literal(4)).lt(Expr::literal(10)), + "Column(x) - 4 < 10", + ), + ( + Pred::and( + column_expr!("x").ge(Expr::literal(2)), + column_expr!("x").le(Expr::literal(10)), + ), + "AND(Column(x) >= 2, Column(x) <= 10)", + ), + 
( + Pred::and_from([ + column_expr!("x").ge(Expr::literal(2)), + column_expr!("x").le(Expr::literal(10)), + column_expr!("x").le(Expr::literal(100)), ]), - Expr::literal(true), - ]), - Expr::ne( - Expr::literal(true), - Expr::and_from([Expr::literal(true), column_expr!("b")]), + "AND(Column(x) >= 2, Column(x) <= 10, Column(x) <= 100)", + ), + ( + Pred::or( + column_expr!("x").gt(Expr::literal(2)), + column_expr!("x").lt(Expr::literal(10)), + ), + "OR(Column(x) > 2, Column(x) < 10)", + ), + ( + column_expr!("x").eq(Expr::literal("foo")), + "Column(x) = 'foo'", ), - ]); - - // Similar to ExpressionDepthChecker::check, but also returns call count - let check_with_call_count = - |depth_limit| ExpressionDepthChecker::check_with_call_count(&expr, depth_limit); - - // NOTE: The checker ignores leaf nodes! - - // AND - // * STRUCT - // * AND >LIMIT< - // * NOT - // * AND - // * NE - assert_eq!(check_with_call_count(1), (2, 6)); - - // AND - // * STRUCT - // * AND - // * LT >LIMIT< - // * OR - // * NOT - // * AND - // * NE - assert_eq!(check_with_call_count(2), (3, 8)); - - // AND - // * STRUCT - // * AND - // * LT - // * OR - // * NOT - // * AND - // * NOT - // * GT - // * OR - // * AND - // * NOT >LIMIT< - // * NE - assert_eq!(check_with_call_count(3), (4, 13)); - - // Depth limit not hit (full traversal required) - - // AND - // * STRUCT - // * AND - // * LT - // * OR - // * NOT - // * AND - // * NOT - // * GT - // * OR - // * AND - // * NOT - // * NE - // * AND - assert_eq!(check_with_call_count(4), (4, 14)); - assert_eq!(check_with_call_count(5), (4, 14)); + ]; + + for (pred, expected) in cases { + let result = format!("{}", pred); + assert_eq!(result, expected); + } } } diff --git a/kernel/src/expressions/scalars.rs b/kernel/src/expressions/scalars.rs index 7eda73f7d..e660ff87e 100644 --- a/kernel/src/expressions/scalars.rs +++ b/kernel/src/expressions/scalars.rs @@ -503,8 +503,8 @@ impl PrimitiveType { mod tests { use std::f32::consts::PI; - use 
crate::expressions::{column_expr, BinaryOperator}; - use crate::Expression; + use crate::expressions::{column_expr, BinaryPredicateOp}; + use crate::{Expression as Expr, Predicate as Pred}; use super::*; @@ -596,10 +596,10 @@ mod tests { }); let column = column_expr!("item"); - let array_op = Expression::binary(BinaryOperator::In, 10, array.clone()); - let array_not_op = Expression::binary(BinaryOperator::NotIn, 10, array); - let column_op = Expression::binary(BinaryOperator::In, PI, column.clone()); - let column_not_op = Expression::binary(BinaryOperator::NotIn, "Cool", column); + let array_op = Pred::binary(BinaryPredicateOp::In, Expr::literal(10), array.clone()); + let array_not_op = Pred::binary(BinaryPredicateOp::NotIn, Expr::literal(10), array); + let column_op = Pred::binary(BinaryPredicateOp::In, Expr::literal(PI), column.clone()); + let column_not_op = Pred::binary(BinaryPredicateOp::NotIn, Expr::literal("Cool"), column); assert_eq!(&format!("{}", array_op), "10 IN (1, 2, 3)"); assert_eq!(&format!("{}", array_not_op), "10 NOT IN (1, 2, 3)"); assert_eq!(&format!("{}", column_op), "3.1415927 IN Column(item)"); diff --git a/kernel/src/expressions/transforms.rs b/kernel/src/expressions/transforms.rs new file mode 100644 index 000000000..521f67439 --- /dev/null +++ b/kernel/src/expressions/transforms.rs @@ -0,0 +1,542 @@ +use std::borrow::Cow; +use std::collections::HashSet; + +use crate::expressions::{ + BinaryExpression, BinaryPredicate, ColumnName, Expression, JunctionPredicate, Predicate, + Scalar, UnaryPredicate, +}; + +/// Generic framework for recursive bottom-up transforms of expressions and +/// predicates. Transformations return `Option` with the following semantics: +/// +/// * `Some(Cow::Owned)` -- The input was transformed and the parent should be updated with it. +/// * `Some(Cow::Borrowed)` -- The input was not transformed. +/// * `None` -- The input was filtered out and the parent should be updated to not reference it. 
+/// +/// The transform can start from the generic [`Self::transform_expr`] or [`Self::transform_pred`], +/// or directly from a specific expression/predicate variant (e.g. [`Self::transform_expr_column`] +/// for [`ColumnName`], [`Self::transform_pred_unary`] for [`UnaryPredicate`]). +/// +/// The provided `transform_xxx` methods all default to no-op (returning their input as +/// `Some(Cow::Borrowed)`), and implementations should selectively override specific `transform_xxx` +/// methods as needed for the task at hand. +/// +/// The provided `recurse_into_xxx` methods encapsulate the boilerplate work of recursing into the +/// children of each expression or predicate variant. Implementations can call these as needed but +/// will generally not need to override them. +pub trait ExpressionTransform<'a> { + /// Called for each literal encountered during the expression traversal. + fn transform_expr_literal(&mut self, value: &'a Scalar) -> Option> { + Some(Cow::Borrowed(value)) + } + + /// Called for each column reference encountered during the expression traversal. + fn transform_expr_column(&mut self, name: &'a ColumnName) -> Option> { + Some(Cow::Borrowed(name)) + } + + /// Called for the expression list of each [`Expression::Struct`] encountered during the + /// traversal. Implementations can call [`Self::recurse_into_expr_struct`] if they wish to + /// recursively transform the child expressions. + fn transform_expr_struct( + &mut self, + fields: &'a Vec, + ) -> Option>> { + self.recurse_into_expr_struct(fields) + } + + /// Called for the child predicate of each [`Expression::Predicate`] encountered during the + /// traversal. Implementations can call [`Self::recurse_into_expr_pred`] if they wish to + /// recursively transform the child predicate. + fn transform_expr_pred(&mut self, pred: &'a Predicate) -> Option> { + self.recurse_into_expr_pred(pred) + } + + /// Called for the child predicate of each [`Predicate::Not`] encountered during the + /// traversal.
Implementations can call [`Self::recurse_into_pred_not`] if they wish to + /// recursively transform the child expression. + fn transform_pred_not(&mut self, pred: &'a Predicate) -> Option> { + self.recurse_into_pred_not(pred) + } + + /// Called for each [`UnaryPredicate`] encountered during the traversal. Implementations can + /// call [`Self::recurse_into_pred_unary`] if they wish to recursively transform the child. + fn transform_pred_unary( + &mut self, + pred: &'a UnaryPredicate, + ) -> Option> { + self.recurse_into_pred_unary(pred) + } + + /// Called for each [`BinaryExpression`] encountered during the traversal. Implementations can + /// call [`Self::recurse_into_expr_binary`] if they wish to recursively transform the children. + fn transform_expr_binary( + &mut self, + expr: &'a BinaryExpression, + ) -> Option> { + self.recurse_into_expr_binary(expr) + } + + /// Called for each [`BinaryPredicate`] encountered during the traversal. Implementations can + /// call [`Self::recurse_into_pred_binary`] if they wish to recursively transform the children. + fn transform_pred_binary( + &mut self, + pred: &'a BinaryPredicate, + ) -> Option> { + self.recurse_into_pred_binary(pred) + } + + /// Called for each [`JunctionPredicate`] encountered during the traversal. Implementations can + /// call [`Self::recurse_into_pred_junction`] if they wish to recursively transform the children. + fn transform_pred_junction( + &mut self, + pred: &'a JunctionPredicate, + ) -> Option> { + self.recurse_into_pred_junction(pred) + } + + /// General entry point for transforming an expression. This method will dispatch to the + /// specific transform for each expression variant. Also invoked internally in order to recurse + /// on the child(ren) of non-leaf variants. + fn transform_expr(&mut self, expr: &'a Expression) -> Option> { + use Cow::*; + let expr = match expr { + Expression::Literal(s) => match self.transform_expr_literal(s)? 
{ + Owned(s) => Owned(Expression::Literal(s)), + Borrowed(_) => Borrowed(expr), + }, + Expression::Column(c) => match self.transform_expr_column(c)? { + Owned(c) => Owned(Expression::Column(c)), + Borrowed(_) => Borrowed(expr), + }, + Expression::Predicate(p) => match self.transform_expr_pred(p)? { + Owned(p) => Owned(p.into()), + Borrowed(_) => Borrowed(expr), + }, + Expression::Struct(s) => match self.transform_expr_struct(s)? { + Owned(s) => Owned(Expression::Struct(s)), + Borrowed(_) => Borrowed(expr), + }, + Expression::Binary(b) => match self.transform_expr_binary(b)? { + Owned(b) => Owned(Expression::Binary(b)), + Borrowed(_) => Borrowed(expr), + }, + }; + Some(expr) + } + + /// General entry point for transforming a predicate. This method will dispatch to the specific + /// transform for each predicate variant. Also invoked internally in order to recurse on the + /// child(ren) of non-leaf variants. + fn transform_pred(&mut self, pred: &'a Predicate) -> Option> { + use Cow::*; + let pred = match pred { + Predicate::BooleanExpression(e) => match self.transform_expr(e)? { + Owned(e) => Owned(Predicate::BooleanExpression(e)), + Borrowed(_) => Borrowed(pred), + }, + Predicate::Not(p) => match self.transform_pred_not(p)? { + Owned(p) => Owned(Predicate::not(p)), + Borrowed(_) => Borrowed(pred), + }, + Predicate::Unary(u) => match self.transform_pred_unary(u)? { + Owned(u) => Owned(Predicate::Unary(u)), + Borrowed(_) => Borrowed(pred), + }, + Predicate::Binary(b) => match self.transform_pred_binary(b)? { + Owned(b) => Owned(Predicate::Binary(b)), + Borrowed(_) => Borrowed(pred), + }, + Predicate::Junction(j) => match self.transform_pred_junction(j)? { + Owned(j) => Owned(Predicate::Junction(j)), + Borrowed(_) => Borrowed(pred), + }, + }; + Some(pred) + } + + /// Recursively transforms a struct's child expressions. 
Returns `None` if all children were + /// removed, `Some(Cow::Owned)` if at least one child was changed or removed, and + /// `Some(Cow::Borrowed)` otherwise. + fn recurse_into_expr_struct( + &mut self, + fields: &'a Vec, + ) -> Option>> { + recurse_into_children(fields, |f| self.transform_expr(f)) + } + + /// Recursively transforms the child of an [`Expression::Predicate`]. Returns `None` if all + /// children were removed, `Some(Cow::Owned)` if at least one child was changed or removed, and + /// `Some(Cow::Borrowed)` otherwise. + fn recurse_into_expr_pred(&mut self, pred: &'a Predicate) -> Option> { + self.transform_pred(pred) + } + + /// Recursively transforms the child of a [`Predicate::Not`] expression. Returns `None` if the + /// child was removed, `Some(Cow::Owned)` if the child was changed, and `Some(Cow::Borrowed)` + /// otherwise. + fn recurse_into_pred_not(&mut self, p: &'a Predicate) -> Option> { + use Cow::*; + let p = match self.transform_pred(p)? { + Owned(pred) => Owned(Predicate::not(pred)), + Borrowed(_) => Borrowed(p), + }; + Some(p) + } + + /// Recursively transforms a unary predicate's child. Returns `None` if the child was removed, + /// `Some(Cow::Owned)` if the child was changed, and `Some(Cow::Borrowed)` otherwise. + fn recurse_into_pred_unary( + &mut self, + u: &'a UnaryPredicate, + ) -> Option> { + use Cow::*; + let u = match self.transform_expr(&u.expr)? { + Owned(expr) => Owned(UnaryPredicate::new(u.op, expr)), + Borrowed(_) => Borrowed(u), + }; + Some(u) + } + + /// Recursively transforms a binary predicate's children. Returns `None` if at least one child + /// was removed, `Some(Cow::Owned)` if at least one child changed, and `Some(Cow::Borrowed)` + /// otherwise. 
+ fn recurse_into_pred_binary( + &mut self, + b: &'a BinaryPredicate, + ) -> Option> { + use Cow::*; + let left = self.transform_expr(&b.left)?; + let right = self.transform_expr(&b.right)?; + let b = match (&left, &right) { + (Borrowed(_), Borrowed(_)) => Borrowed(b), + _ => Owned(BinaryPredicate::new( + b.op, + left.into_owned(), + right.into_owned(), + )), + }; + Some(b) + } + + /// Recursively transforms a binary expression's children. Returns `None` if at least one child + /// was removed, `Some(Cow::Owned)` if at least one child changed, and `Some(Cow::Borrowed)` + /// otherwise. + fn recurse_into_expr_binary( + &mut self, + b: &'a BinaryExpression, + ) -> Option> { + use Cow::*; + let left = self.transform_expr(&b.left)?; + let right = self.transform_expr(&b.right)?; + let b = match (&left, &right) { + (Borrowed(_), Borrowed(_)) => Borrowed(b), + _ => Owned(BinaryExpression::new( + b.op, + left.into_owned(), + right.into_owned(), + )), + }; + Some(b) + } + + /// Recursively transforms a junction predicate's children. Returns `None` if all children were + /// removed, `Some(Cow::Owned)` if at least one child was changed or removed, and + /// `Some(Cow::Borrowed)` otherwise. + fn recurse_into_pred_junction( + &mut self, + j: &'a JunctionPredicate, + ) -> Option> { + use Cow::*; + let j = match recurse_into_children(&j.preds, |p| self.transform_pred(p))? { + Owned(preds) => Owned(JunctionPredicate::new(j.op, preds)), + Borrowed(_) => Borrowed(j), + }; + Some(j) + } +} + +/// Used to recurse into the children of an `Expression::Struct` or `Predicate::Junction`. 
+fn recurse_into_children<'a, T: Clone>( + children: &'a Vec, + recurse_fn: impl FnMut(&'a T) -> Option>, +) -> Option>> { + let mut num_borrowed = 0; + let new_children: Vec<_> = children + .iter() + .filter_map(recurse_fn) + .inspect(|f| { + if matches!(f, Cow::Borrowed(_)) { + num_borrowed += 1; + } + }) + .collect(); + + if new_children.is_empty() { + None // all fields filtered out + } else if num_borrowed < children.len() { + // At least one field was changed or removed, so make a new field list + let children = new_children.into_iter().map(Cow::into_owned).collect(); + Some(Cow::Owned(children)) + } else { + Some(Cow::Borrowed(children)) + } +} + +/// Retrieves the set of column names referenced by an expression. +#[derive(Default)] +pub(crate) struct GetColumnReferences<'a> { + references: HashSet<&'a ColumnName>, +} + +impl<'a> GetColumnReferences<'a> { + pub(crate) fn into_inner(self) -> HashSet<&'a ColumnName> { + self.references + } +} + +impl<'a> ExpressionTransform<'a> for GetColumnReferences<'a> { + fn transform_expr_column(&mut self, name: &'a ColumnName) -> Option> { + self.references.insert(name); + Some(Cow::Borrowed(name)) + } +} + +/// An expression "transform" that doesn't actually change the expression at all. Instead, it +/// measures the maximum depth of an expression, with a depth limit to prevent stack overflow. Useful +/// for verifying that an expression has reasonable depth before attempting to work with it. +pub struct ExpressionDepthChecker { + depth_limit: usize, + max_depth_seen: usize, + current_depth: usize, + call_count: usize, +} + +impl ExpressionDepthChecker { + /// Depth-checks the given expression against a given depth limit. The return value is the + /// largest depth seen, which is capped at one more than the depth limit (indicating the + /// recursion was terminated).
+ pub fn check_expr(expr: &Expression, depth_limit: usize) -> usize { + Self::check_expr_with_call_count(expr, depth_limit).0 + } + + /// Depth-checks the given predicate against a given depth limit. The return value is the + /// largest depth seen, which is capped at one more than the depth limit (indicating the + /// recursion was terminated). + pub fn check_pred(pred: &Predicate, depth_limit: usize) -> usize { + Self::check_pred_with_call_count(pred, depth_limit).0 + } + + // Exposed for testing + fn check_expr_with_call_count(expr: &Expression, depth_limit: usize) -> (usize, usize) { + let mut checker = Self::new(depth_limit); + checker.transform_expr(expr); + (checker.max_depth_seen, checker.call_count) + } + + // Exposed for testing + fn check_pred_with_call_count(pred: &Predicate, depth_limit: usize) -> (usize, usize) { + let mut checker = Self::new(depth_limit); + checker.transform_pred(pred); + (checker.max_depth_seen, checker.call_count) + } + + fn new(depth_limit: usize) -> Self { + Self { + depth_limit, + max_depth_seen: 0, + current_depth: 0, + call_count: 0, + } + } + + // Triggers the requested recursion only if doing so would not exceed the depth limit.
+ fn depth_limited<'a, T: Clone + std::fmt::Debug>( + &mut self, + recurse: impl FnOnce(&mut Self, &'a T) -> Option>, + arg: &'a T, + ) -> Option> { + self.call_count += 1; + if self.max_depth_seen < self.current_depth { + self.max_depth_seen = self.current_depth; + if self.depth_limit < self.current_depth { + tracing::warn!( + "Max expression depth {} exceeded by {arg:?}", + self.depth_limit + ); + } + } + if self.max_depth_seen <= self.depth_limit { + self.current_depth += 1; + let _ = recurse(self, arg); + self.current_depth -= 1; + } + None + } +} + +impl<'a> ExpressionTransform<'a> for ExpressionDepthChecker { + fn transform_expr_struct( + &mut self, + fields: &'a Vec, + ) -> Option>> { + self.depth_limited(Self::recurse_into_expr_struct, fields) + } + + fn transform_expr_pred(&mut self, pred: &'a Predicate) -> Option> { + self.depth_limited(Self::recurse_into_expr_pred, pred) + } + + fn transform_pred_not(&mut self, pred: &'a Predicate) -> Option> { + self.depth_limited(Self::recurse_into_pred_not, pred) + } + + fn transform_pred_unary( + &mut self, + pred: &'a UnaryPredicate, + ) -> Option> { + self.depth_limited(Self::recurse_into_pred_unary, pred) + } + + fn transform_expr_binary( + &mut self, + expr: &'a BinaryExpression, + ) -> Option> { + self.depth_limited(Self::recurse_into_expr_binary, expr) + } + + fn transform_pred_binary( + &mut self, + pred: &'a BinaryPredicate, + ) -> Option> { + self.depth_limited(Self::recurse_into_pred_binary, pred) + } + + fn transform_pred_junction( + &mut self, + pred: &'a JunctionPredicate, + ) -> Option> { + self.depth_limited(Self::recurse_into_pred_junction, pred) + } +} + +#[cfg(test)] +mod tests { + use super::ExpressionDepthChecker; + use crate::expressions::{column_expr, column_pred, Expression as Expr, Predicate as Pred}; + + #[test] + fn test_depth_checker() { + let pred = Pred::or_from([ + Pred::and_from([ + Pred::or( + Pred::lt(Expr::literal(10), column_expr!("x")), + Pred::gt(Expr::literal(20), 
column_expr!("b")), + ), + Pred::literal(true), + Pred::not(Pred::literal(true)), + ]), + Pred::and_from([ + Pred::is_null(column_expr!("b")), + Pred::gt(Expr::literal(10), column_expr!("x")), + Pred::or( + Pred::gt(Expr::literal(5) + Expr::literal(10), Expr::literal(20)), + column_pred!("y"), + ), + Pred::literal(true), + ]), + Pred::ne( + Expr::literal(42), + Expr::struct_from([Expr::literal(10), column_expr!("b")]), + ), + ]); + + // Similar to ExpressionDepthChecker::check_pred, but also returns call count + let check_with_call_count = + |depth_limit| ExpressionDepthChecker::check_pred_with_call_count(&pred, depth_limit); + + // NOTE: The checker ignores leaf nodes! + + // OR + // * AND + // * OR >LIMIT< + // * NOT + // * AND + // * NE + assert_eq!(check_with_call_count(1), (2, 6)); + + // OR + // * AND + // * OR + // * LT >LIMIT< + // * GT + // * NOT + // * AND + // * NE + assert_eq!(check_with_call_count(2), (3, 8)); + + // OR + // * AND + // * OR + // * LT + // * GT + // * NOT + // * AND + // * IS NULL + // * GT + // * OR + // * GT + // * PLUS >LIMIT< + // * NE + assert_eq!(check_with_call_count(3), (4, 13)); + + // Depth limit not hit (full traversal required) + // + // OR + // * AND + // * OR + // * LT + // * GT + // * NOT + // * AND + // * IS_NULL + // * GT + // * OR + // * GT + // * PLUS + // * NE + // * STRUCT + assert_eq!(check_with_call_count(4), (4, 14)); + assert_eq!(check_with_call_count(5), (4, 14)); + + // Check expressions as well + let expr = Expr::from(pred); + let check_with_call_count = + |depth_limit| ExpressionDepthChecker::check_expr_with_call_count(&expr, depth_limit); + + // Adding an `Expression::Predicate` root makes the expression tree exactly one node taller, + // which makes the recursion terminate sooner than previously: + // + // PRED + // * OR + // * AND > LIMIT 1 < + // * OR > LIMIT 2 < + // * LT > LIMIT 3 < + // * GT + // * NOT + // * AND + // * IS_NULL + // * GT + // * OR + // * GT + // * PLUS > LIMIT 4 < + // * NE + // * 
STRUCT + assert_eq!(check_with_call_count(1), (2, 5)); + assert_eq!(check_with_call_count(2), (3, 7)); + assert_eq!(check_with_call_count(3), (4, 9)); + assert_eq!(check_with_call_count(4), (5, 14)); + assert_eq!(check_with_call_count(5), (5, 15)); + assert_eq!(check_with_call_count(6), (5, 15)); + } +} diff --git a/kernel/src/kernel_predicates/mod.rs b/kernel/src/kernel_predicates/mod.rs index 60936a64a..38a619163 100644 --- a/kernel/src/kernel_predicates/mod.rs +++ b/kernel/src/kernel_predicates/mod.rs @@ -1,6 +1,6 @@ use crate::expressions::{ - BinaryExpression, BinaryOperator, ColumnName, Expression as Expr, JunctionExpression, - JunctionOperator, Scalar, UnaryExpression, UnaryOperator, + BinaryPredicate, BinaryPredicateOp, ColumnName, Expression as Expr, JunctionPredicate, + JunctionPredicateOp, Predicate as Pred, Scalar, UnaryPredicate, UnaryPredicateOp, }; use crate::schema::DataType; @@ -12,15 +12,15 @@ pub(crate) mod parquet_stats_skipping; #[cfg(test)] mod tests; -/// Uses kernel (not engine) logic to evaluate an expression tree against column names that resolve -/// as scalars. Useful for testing/debugging but also serves as a reference implementation that +/// Uses kernel (not engine) logic to evaluate a predicate tree against column names that resolve as +/// scalars. Useful for testing/debugging but also serves as a reference implementation that /// documents the expression semantics that kernel relies on for data skipping. /// /// # Inverted expression semantics /// /// Because inversion (`NOT` operator) has special semantics and can often be optimized away by /// pushing it down, most methods take an `inverted` flag. 
That allows operations like -/// [`UnaryOperator::Not`] to simply evaluate their operand with a flipped `inverted` flag, and +/// [`Pred::Not`] to simply evaluate their operand with a flipped `inverted` flag, and /// greatly simplifies the implementations of most operators (other than those which have to /// directly implement NOT semantics, which are unavoidably complex in that regard). /// @@ -30,11 +30,10 @@ mod tests; /// example, [`crate::engine::parquet_stats_skipping::ParquetStatsProvider`] directly evaluates the /// predicate over parquet footer stats and returns boolean results, while /// [`crate::scan::data_skipping::DataSkippingPredicateCreator`] instead transforms the input -/// predicate expression to a data skipping predicate expresion that the engine can evaluated -/// directly against Delta data skipping stats during log replay. Although this approach is harder -/// to read and reason about at first, the majority of expressions can be implemented generically, -/// which greatly reduces redundancy and ensures that all flavors of predicate evaluation have the -/// same semantics. +/// predicate to a data skipping predicate that the engine can evaluate directly against Delta data +/// skipping stats during log replay. Although this approach is harder to read and reason about at +/// first, the majority of predicates can be implemented generically, which greatly reduces +/// redundancy and ensures that all flavors of predicate evaluation have the same semantics. /// /// # NULL and error semantics /// @@ -48,93 +47,118 @@ mod tests; /// rely on nullcount stats for their work (NULL/missing nullcount stats makes them output NULL). /// /// For safety reasons, NULL-checking operations only accept literal and column inputs where -/// stats-based skipping is well-defined.
If an arbitrary data skipping expression evaluates to -/// NULL, there is no way to tell whether the original expression really evaluated to NULL (safe to +/// stats-based skipping is well-defined. If an arbitrary data skipping predicate evaluates to +/// NULL, there is no way to tell whether the original predicate really evaluated to NULL (safe to /// use), or the data skipping version evaluated to NULL due to missing stats (very unsafe to use). /// /// NOTE: The error-handling semantics of this trait's scalar-based predicate evaluation may differ -/// from those of the engine's expression evaluation, because kernel expressions don't include the +/// from those of the engine's predicate evaluation, because kernel predicates don't include the /// necessary type information to reliably detect all type errors. pub(crate) trait KernelPredicateEvaluator { - type Output; - - /// A (possibly inverted) scalar NULL test, e.g. ` IS [NOT] NULL`. - fn eval_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option; + type Predicate: Into; + type Expression; /// A (possibly inverted) boolean scalar value, e.g. `[NOT] `. - fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option; + fn eval_pred_scalar(&self, val: &Scalar, inverted: bool) -> Option; + + /// A (possibly inverted) scalar NULL test, e.g. ` IS [NOT] NULL`. + fn eval_pred_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option; /// A (possibly inverted) NULL check, e.g. ` IS [NOT] NULL`. - fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option; + fn eval_pred_is_null(&self, col: &ColumnName, inverted: bool) -> Option; /// A (possibly inverted) less-than comparison, e.g. ` < `. - fn eval_lt(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option; + fn eval_pred_lt( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option; /// A (possibly inverted) less-than-or-equal comparison, e.g. 
` <= ` - fn eval_le(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option; + fn eval_pred_le( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option; /// A (possibly inverted) equality comparison, e.g. ` = ` or ` != `. /// /// NOTE: Caller is responsible to commute the operation if needed, e.g. ` != ` /// becomes ` != `. - fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option; + fn eval_pred_eq( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option; /// A (possibly inverted) comparison between two scalars, e.g. ` != `. - fn eval_binary_scalars( + fn eval_pred_binary_scalars( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, - ) -> Option; + ) -> Option; /// A (possibly inverted) comparison between two columns, e.g. ` != `. - fn eval_binary_columns( + fn eval_pred_binary_columns( &self, - op: BinaryOperator, + op: BinaryPredicateOp, a: &ColumnName, b: &ColumnName, inverted: bool, - ) -> Option; + ) -> Option; - /// Completes evaluation of a (possibly inverted) junction expression. + /// Completes evaluation of a (possibly inverted) junction predicate. /// /// AND and OR are implemented by first evaluating its (possibly inverted) inputs. This part is - /// always the same, provided by [`eval_junction`]). The results are then combined to become the - /// expression's output in some implementation-defined way (this method). - fn finish_eval_junction( + /// always the same, provided by [`eval_pred_junction`]). The results are then combined to become the + /// predicate's output in some implementation-defined way (this method). + fn finish_eval_pred_junction( &self, - op: JunctionOperator, - exprs: impl IntoIterator>, + op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, - ) -> Option; + ) -> Option; // ==================== PROVIDED METHODS ==================== - /// A (possibly inverted) boolean column access, e.g. `[NOT] `. 
- fn eval_column(&self, col: &ColumnName, inverted: bool) -> Option { - // The expression is equivalent to != FALSE, and the expression NOT is - // equivalent to != TRUE. - self.eval_eq(col, &Scalar::from(inverted), true) + /// Dispatches a (possibly inverted) NOT predicate + fn eval_pred_not(&self, pred: &Pred, inverted: bool) -> Option { + self.eval_pred(pred, !inverted) } - /// Dispatches a (possibly inverted) NOT expression - fn eval_not(&self, expr: &Expr, inverted: bool) -> Option { - self.eval_expr(expr, !inverted) + /// Dispatches a (possibly inverted) boolean expression used as a predicate + fn eval_pred_expr(&self, expr: &Expr, inverted: bool) -> Option { + // Directly evaluate literals and predicates used as expressions. Evaluate columns as + // ` == TRUE`. All other expressions unsupported. + match expr { + Expr::Literal(val) => self.eval_pred_scalar(val, inverted), + Expr::Column(col) => self.eval_pred_eq(col, &Scalar::from(true), inverted), + Expr::Predicate(pred) => self.eval_pred(pred, inverted), + _ => None, + } } /// Dispatches a (possibly inverted) unary expression to each operator's specific implementation. - fn eval_unary(&self, op: UnaryOperator, expr: &Expr, inverted: bool) -> Option { + fn eval_pred_unary( + &self, + op: UnaryPredicateOp, + expr: &Expr, + inverted: bool, + ) -> Option { match op { - UnaryOperator::Not => self.eval_not(expr, inverted), - UnaryOperator::IsNull => match expr { + UnaryPredicateOp::IsNull => match expr { // WARNING: Only literals and columns can be safely null-checked. Attempting to // null-check an expressions such as `a < 10` could wrongly produce FALSE in case // `a` is just plain missing (rather than known to be NULL. A missing-value can // arise e.g. if data skipping encounters a column with missing stats, or if // partition pruning encounters a non-partition column.
- Expr::Literal(val) => self.eval_scalar_is_null(val, inverted), - Expr::Column(col) => self.eval_is_null(col, inverted), + Expr::Literal(val) => self.eval_pred_scalar_is_null(val, inverted), + Expr::Column(col) => self.eval_pred_is_null(col, inverted), _ => { debug!("Unsupported operand: IS [NOT] NULL: {expr:?}"); None @@ -148,48 +172,53 @@ pub(crate) trait KernelPredicateEvaluator { /// /// 1. DISTINCT(, NULL) is equivalent to ` IS NOT NULL` /// 2. DISTINCT(, ) is equivalent to `OR( IS NULL, != )` - fn eval_distinct( + fn eval_pred_distinct( &self, col: &ColumnName, val: &Scalar, inverted: bool, - ) -> Option { + ) -> Option { if let Scalar::Null(_) = val { - self.eval_is_null(col, !inverted) + self.eval_pred_is_null(col, !inverted) } else { let args = [ - self.eval_is_null(col, inverted), - self.eval_eq(col, val, !inverted), + self.eval_pred_is_null(col, inverted), + self.eval_pred_eq(col, val, !inverted), ]; - self.finish_eval_junction(JunctionOperator::Or, args, inverted) + self.finish_eval_pred_junction(JunctionPredicateOp::Or, args, inverted) } } /// A (possibly inverted) IN-list check, e.g. ` [NOT] IN `. /// /// Unsupported by default, but implementations can override it if they wish. - fn eval_in(&self, _col: &ColumnName, _val: &Scalar, _inverted: bool) -> Option { + fn eval_pred_in( + &self, + _col: &ColumnName, + _val: &Scalar, + _inverted: bool, + ) -> Option { None // TODO? } /// Dispatches a (possibly inverted) binary expression to each operator's specific implementation. /// /// NOTE: Only binary operators that produce boolean outputs are supported. - fn eval_binary( + fn eval_pred_binary( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Expr, right: &Expr, inverted: bool, - ) -> Option { - use BinaryOperator::*; + ) -> Option { + use BinaryPredicateOp::*; use Expr::{Column, Literal}; // NOTE: We rely on the literal values to provide logical type hints. 
That means we cannot // perform column-column comparisons, because we cannot infer the logical type to use. let (op, col, val) = match (left, right) { - (Column(a), Column(b)) => return self.eval_binary_columns(op, a, b, inverted), - (Literal(a), Literal(b)) => return self.eval_binary_scalars(op, a, b, inverted), + (Column(a), Column(b)) => return self.eval_pred_binary_columns(op, a, b, inverted), + (Literal(a), Literal(b)) => return self.eval_pred_binary_scalars(op, a, b, inverted), (Literal(val), Column(col)) => (op.commute()?, col, val), (Column(col), Literal(val)) => (op, col, val), _ => { @@ -198,53 +227,52 @@ pub(crate) trait KernelPredicateEvaluator { } }; match op { - Plus | Minus | Multiply | Divide => None, // Unsupported - not boolean output - LessThan => self.eval_lt(col, val, inverted), - GreaterThanOrEqual => self.eval_lt(col, val, !inverted), - LessThanOrEqual => self.eval_le(col, val, inverted), - GreaterThan => self.eval_le(col, val, !inverted), - Equal => self.eval_eq(col, val, inverted), - NotEqual => self.eval_eq(col, val, !inverted), - Distinct => self.eval_distinct(col, val, inverted), - In => self.eval_in(col, val, inverted), - NotIn => self.eval_in(col, val, !inverted), + LessThan => self.eval_pred_lt(col, val, inverted), + GreaterThanOrEqual => self.eval_pred_lt(col, val, !inverted), + LessThanOrEqual => self.eval_pred_le(col, val, inverted), + GreaterThan => self.eval_pred_le(col, val, !inverted), + Equal => self.eval_pred_eq(col, val, inverted), + NotEqual => self.eval_pred_eq(col, val, !inverted), + Distinct => self.eval_pred_distinct(col, val, inverted), + In => self.eval_pred_in(col, val, inverted), + NotIn => self.eval_pred_in(col, val, !inverted), } } - /// Dispatches a junction operation, leveraging each implementation's [`finish_eval_junction`]. - fn eval_junction( + /// Dispatches a predicate junction operation (AND or OR), leveraging each implementation's + /// [`finish_eval_pred_junction`].
+ fn eval_pred_junction( &self, - op: JunctionOperator, - exprs: &[Expr], + op: JunctionPredicateOp, + preds: &[Pred], inverted: bool, - ) -> Option { - let exprs = exprs.iter().map(|expr| self.eval_expr(expr, inverted)); - self.finish_eval_junction(op, exprs, inverted) - } - - /// Dispatches an expression to the specific implementation for each expression variant. - /// - /// NOTE: [`Expression::Struct`] is not supported and always evaluates to `None`. - fn eval_expr(&self, expr: &Expr, inverted: bool) -> Option { - use Expr::*; - match expr { - Literal(val) => self.eval_scalar(val, inverted), - Column(col) => self.eval_column(col, inverted), - Struct(_) => None, // not supported - Unary(UnaryExpression { op, expr }) => self.eval_unary(*op, expr, inverted), - Binary(BinaryExpression { op, left, right }) => { - self.eval_binary(*op, left, right, inverted) + ) -> Option { + let preds = preds.iter().map(|pred| self.eval_pred(pred, inverted)); + self.finish_eval_pred_junction(op, preds, inverted) + } + + /// Dispatches a predicate to the specific implementation for each predicate variant. + fn eval_pred(&self, pred: &Pred, inverted: bool) -> Option { + use Pred::*; + match pred { + BooleanExpression(expr) => self.eval_pred_expr(expr, inverted), + Not(pred) => self.eval_pred_not(pred, inverted), + Unary(UnaryPredicate { op, expr }) => self.eval_pred_unary(*op, expr, inverted), + Binary(BinaryPredicate { op, left, right }) => { + self.eval_pred_binary(*op, left, right, inverted) + } + Junction(JunctionPredicate { op, preds }) => { + self.eval_pred_junction(*op, preds, inverted) } - Junction(JunctionExpression { op, exprs }) => self.eval_junction(*op, exprs, inverted), } } /// Evaluates a (possibly inverted) predicate with SQL WHERE semantics. /// - /// By default, [`eval_expr`] behaves badly for comparisons involving NULL columns (e.g. `a < + /// By default, [`eval_pred`] behaves badly for comparisons involving NULL columns (e.g. 
`a < /// 10` when `a` is NULL), because the comparison correctly evaluates to NULL, but NULL - /// expressions are interpreted as "stats missing" (= cannot skip). This ambiguity can "poison" - /// the entire expression, causing it to return NULL instead of FALSE that would allow skipping: + /// values are interpreted as "stats missing" (= cannot skip). This ambiguity can "poison" + /// the entire predicate, causing it to return NULL instead of FALSE that would allow skipping: /// /// ```text /// WHERE a < 10 -- NULL (can't skip file) @@ -271,7 +299,7 @@ pub(crate) trait KernelPredicateEvaluator { /// ``` /// /// HOWEVER, we cannot safely NULL-check the result of an arbitrary data skipping predicate - /// because an expression will also produce NULL if the value is just plain missing (e.g. data + /// because a predicate will also produce NULL if the value is just plain missing (e.g. data /// skipping over a column that lacks stats), and if that NULL should propagate all the way to /// top-level, it would be wrongly interpreted as FALSE (= skippable). /// @@ -310,11 +338,11 @@ pub(crate) trait KernelPredicateEvaluator { /// /// Any time the push-down reaches an operator that does not support push-down (such as OR), we /// simply drop the NULL check. This way, the top-level NULL check only applies to - /// sub-expressions that can safely implement it, while ignoring other sub-expressions. The - /// unsupported sub-expressions could produce nulls at runtime that prevent skipping, but false + /// sub-predicates that can safely implement it, while ignoring other sub-predicates. The + /// unsupported sub-predicates could produce nulls at runtime that prevent skipping, but false /// positives are OK -- the query will still correctly filter out the unwanted rows that result. 
/// - /// At expression evaluation time, a NULL value of `a` (from our example) would evaluate as: + /// At predicate evaluation time, a NULL value of `a` (from our example) would evaluate as: /// /// ```text /// AND(..., AND(a IS NOT NULL, 10 IS NOT NULL, a < 10), ...) @@ -341,43 +369,43 @@ pub(crate) trait KernelPredicateEvaluator { /// /// WARNING: Not an idempotent transform. If data skipping eval produces a sql predicate, /// evaluating the result with sql semantics has undefined behavior. - fn eval_expr_sql_where(&self, filter: &Expr, inverted: bool) -> Option { - use Expr::*; - match filter { - Junction(JunctionExpression { op, exprs }) => { - // Recursively invoke `eval_expr_sql_where` instead of the usual `eval_expr` for AND/OR. - let exprs = exprs + fn eval_pred_sql_where(&self, pred: &Pred, inverted: bool) -> Option { + use Pred::*; + match pred { + Not(pred) => self.eval_pred_sql_where(pred, !inverted), + BooleanExpression(expr) => match expr { + Expr::Literal(val) if val.is_null() => { + // AND(NULL IS NOT NULL, NULL) = AND(FALSE, NULL) = FALSE + self.eval_pred_scalar(&Scalar::from(false), false) + } + Expr::Column(col) => { + let preds = [ + self.eval_pred_unary(UnaryPredicateOp::IsNull, expr, true), + self.eval_pred_eq(col, &Scalar::from(true), inverted), + ]; + self.finish_eval_pred_junction(JunctionPredicateOp::And, preds, false) + } + Expr::Predicate(pred) => self.eval_pred_sql_where(pred, inverted), + _ => None, + }, + Junction(JunctionPredicate { op, preds }) => { + // Recursively invoke `eval_pred_sql_where` instead of the usual `eval_pred` for AND/OR. 
+ let preds = preds .iter() - .map(|expr| self.eval_expr_sql_where(expr, inverted)); - self.finish_eval_junction(*op, exprs, inverted) - } - Binary(BinaryExpression { op, left, right }) if op.is_null_intolerant_comparison() => { - // Perform a nullsafe comparison instead of the usual `eval_binary` - let exprs = [ - self.eval_unary(UnaryOperator::IsNull, left, true), - self.eval_unary(UnaryOperator::IsNull, right, true), - self.eval_binary(*op, left, right, inverted), - ]; - self.finish_eval_junction(JunctionOperator::And, exprs, false) + .map(|pred| self.eval_pred_sql_where(pred, inverted)); + self.finish_eval_pred_junction(*op, preds, inverted) } - Unary(UnaryExpression { - op: UnaryOperator::Not, - expr, - }) => self.eval_expr_sql_where(expr, !inverted), - Column(col) => { - // Perform a nullsafe comparison instead of the usual `eval_column` - let exprs = [ - self.eval_unary(UnaryOperator::IsNull, filter, true), - self.eval_column(col, inverted), + Binary(BinaryPredicate { op, left, right }) if op.is_null_intolerant() => { + // Perform a nullsafe comparison instead of the usual `eval_pred_binary` + let preds = [ + self.eval_pred_unary(UnaryPredicateOp::IsNull, left, true), + self.eval_pred_unary(UnaryPredicateOp::IsNull, right, true), + self.eval_pred_binary(*op, left, right, inverted), ]; - self.finish_eval_junction(JunctionOperator::And, exprs, false) - } - Literal(val) if val.is_null() => { - // AND(NULL IS NOT NULL, NULL) = AND(FALSE, NULL) = FALSE - self.eval_scalar(&Scalar::from(false), false) + self.finish_eval_pred_junction(JunctionPredicateOp::And, preds, false) } - // Process all remaining expressions normally, because they are not proven safe. Indeed, - // expressions like DISTINCT and IS [NOT] NULL are known-unsafe under SQL semantics: + // Process all remaining predicates normally, because they are not proven safe. 
Indeed, + // predicates like DISTINCT and IS [NOT] NULL are known-unsafe under SQL semantics: // // ``` // x IS NULL # when x really is NULL @@ -396,19 +424,19 @@ pub(crate) trait KernelPredicateEvaluator { // = FALSE // ``` // - _ => self.eval_expr(filter, inverted), + _ => self.eval_pred(pred, inverted), } } - /// A convenient non-inverted wrapper for [`eval_expr`] + /// A convenient non-inverted wrapper for [`eval_pred`] #[cfg(test)] - fn eval(&self, expr: &Expr) -> Option { - self.eval_expr(expr, false) + fn eval(&self, pred: &Pred) -> Option { + self.eval_pred(pred, false) } - /// A convenient non-inverted wrapper for [`eval_expr_sql_where`]. - fn eval_sql_where(&self, expr: &Expr) -> Option { - self.eval_expr_sql_where(expr, false) + /// A convenient non-inverted wrapper for [`eval_pred_sql_where`]. + fn eval_sql_where(&self, pred: &Pred) -> Option { + self.eval_pred_sql_where(pred, false) } } @@ -416,19 +444,19 @@ pub(crate) trait KernelPredicateEvaluator { /// reuse by multiple bool-output predicate evaluator implementations. pub(crate) struct KernelPredicateEvaluatorDefaults; impl KernelPredicateEvaluatorDefaults { - /// Directly null-tests a scalar. See [`KernelPredicateEvaluator::eval_scalar_is_null`]. - pub(crate) fn eval_scalar_is_null(val: &Scalar, inverted: bool) -> Option { - Some(val.is_null() != inverted) - } - - /// Directly evaluates a boolean scalar. See [`KernelPredicateEvaluator::eval_scalar`]. - pub(crate) fn eval_scalar(val: &Scalar, inverted: bool) -> Option { + /// Directly evaluates a boolean scalar. See [`KernelPredicateEvaluator::eval_pred_scalar`]. + pub(crate) fn eval_pred_scalar(val: &Scalar, inverted: bool) -> Option { match val { Scalar::Boolean(val) => Some(*val != inverted), _ => None, } } + /// Directly null-tests a scalar. See [`KernelPredicateEvaluator::eval_pred_scalar_is_null`]. 
+ pub(crate) fn eval_pred_scalar_is_null(val: &Scalar, inverted: bool) -> Option { + Some(val.is_null() != inverted) + } + /// A (possibly inverted) partial comparison of two scalars, leveraging the [`PartialOrd`] /// trait. pub(crate) fn partial_cmp_scalars( @@ -442,14 +470,14 @@ impl KernelPredicateEvaluatorDefaults { Some(matched != inverted) } - /// Directly evaluates a boolean comparison. See [`KernelPredicateEvaluator::eval_binary_scalars`]. - pub(crate) fn eval_binary_scalars( - op: BinaryOperator, + /// Directly evaluates a boolean comparison. See [`KernelPredicateEvaluator::eval_pred_binary_scalars`]. + pub(crate) fn eval_pred_binary_scalars( + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, ) -> Option { - use BinaryOperator::*; + use BinaryPredicateOp::*; match op { Equal => Self::partial_cmp_scalars(Ordering::Equal, left, right, inverted), NotEqual => Self::partial_cmp_scalars(Ordering::Equal, left, right, !inverted), @@ -465,23 +493,23 @@ impl KernelPredicateEvaluatorDefaults { } /// Finishes evaluating a (possibly inverted) junction operation. See - /// [`KernelPredicateEvaluator::finish_eval_junction`]. + /// [`KernelPredicateEvaluator::finish_eval_pred_junction`]. /// /// The inputs were already inverted by the caller, if needed. /// /// With AND (OR), any FALSE (TRUE) input dominates, forcing a FALSE (TRUE) output. If there /// was no dominating input, then any NULL input forces NULL output. Otherwise, return the /// non-dominant value. Inverting the operation also inverts the dominant value. 
- pub(crate) fn finish_eval_junction( - op: JunctionOperator, - exprs: impl IntoIterator>, + pub(crate) fn finish_eval_pred_junction( + op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, ) -> Option { let dominator = match op { - JunctionOperator::And => inverted, - JunctionOperator::Or => !inverted, + JunctionPredicateOp::And => inverted, + JunctionPredicateOp::Or => !inverted, }; - let result = exprs.into_iter().try_fold(false, |found_null, val| { + let result = preds.into_iter().try_fold(false, |found_null, val| { match val { Some(val) if val == dominator => None, // (1) short circuit, dominant found Some(_) => Some(found_null), @@ -548,65 +576,66 @@ impl From for DefaultKernelPredicateEvalu /// to convert column references to scalars, and evaluates the resulting constant expression to /// produce a boolean output. impl KernelPredicateEvaluator for DefaultKernelPredicateEvaluator { - type Output = bool; + type Predicate = bool; + type Expression = Scalar; - fn eval_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { - KernelPredicateEvaluatorDefaults::eval_scalar_is_null(val, inverted) + fn eval_pred_scalar(&self, val: &Scalar, inverted: bool) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_scalar(val, inverted) } - fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { - KernelPredicateEvaluatorDefaults::eval_scalar(val, inverted) + fn eval_pred_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_scalar_is_null(val, inverted) } - fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + fn eval_pred_is_null(&self, col: &ColumnName, inverted: bool) -> Option { let col = self.resolve_column(col)?; - self.eval_scalar_is_null(&col, inverted) + self.eval_pred_scalar_is_null(&col, inverted) } - fn eval_lt(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + fn eval_pred_lt(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> 
Option { let col = self.resolve_column(col)?; - self.eval_binary_scalars(BinaryOperator::LessThan, &col, val, inverted) + self.eval_pred_binary_scalars(BinaryPredicateOp::LessThan, &col, val, inverted) } - fn eval_le(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + fn eval_pred_le(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { let col = self.resolve_column(col)?; - self.eval_binary_scalars(BinaryOperator::LessThanOrEqual, &col, val, inverted) + self.eval_pred_binary_scalars(BinaryPredicateOp::LessThanOrEqual, &col, val, inverted) } - fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + fn eval_pred_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { let col = self.resolve_column(col)?; - self.eval_binary_scalars(BinaryOperator::Equal, &col, val, inverted) + self.eval_pred_binary_scalars(BinaryPredicateOp::Equal, &col, val, inverted) } - fn eval_binary_scalars( + fn eval_pred_binary_scalars( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, - ) -> Option { - KernelPredicateEvaluatorDefaults::eval_binary_scalars(op, left, right, inverted) + ) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_binary_scalars(op, left, right, inverted) } - fn eval_binary_columns( + fn eval_pred_binary_columns( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &ColumnName, right: &ColumnName, inverted: bool, - ) -> Option { + ) -> Option { let left = self.resolve_column(left)?; let right = self.resolve_column(right)?; - self.eval_binary_scalars(op, &left, &right, inverted) + self.eval_pred_binary_scalars(op, &left, &right, inverted) } - fn finish_eval_junction( + fn finish_eval_pred_junction( &self, - op: JunctionOperator, - exprs: impl IntoIterator>, + op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, ) -> Option { - KernelPredicateEvaluatorDefaults::finish_eval_junction(op, exprs, inverted) + 
KernelPredicateEvaluatorDefaults::finish_eval_pred_junction(op, preds, inverted) } } @@ -615,8 +644,12 @@ impl KernelPredicateEvaluator for DefaultKernelPredica /// min/max stats, and NULL checks are converted into comparisons involving the column's nullcount /// and rowcount stats. pub(crate) trait DataSkippingPredicateEvaluator { - /// The output type produced by this expression evaluator - type Output; + /// The expression type produced by this predicate evaluator + type Expression; + + /// The predicate type produced by this predicate evaluator + type Predicate: Into; + /// The type of min and max column stats type TypedStat; /// The type of nullcount and rowcount column stats @@ -634,11 +667,11 @@ pub(crate) trait DataSkippingPredicateEvaluator { /// Retrieves the row count of a column (parquet footers always include this stat). fn get_rowcount_stat(&self) -> Option; - /// See [`KernelPredicateEvaluator::eval_scalar_is_null`] - fn eval_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option; + /// See [`KernelPredicateEvaluator::eval_pred_scalar`] + fn eval_pred_scalar(&self, val: &Scalar, inverted: bool) -> Option; - /// See [`KernelPredicateEvaluator::eval_scalar`] - fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option; + /// See [`KernelPredicateEvaluator::eval_pred_scalar_is_null`] + fn eval_pred_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option; /// For IS NULL (IS NOT NULL), we can only skip the file if all-null (no-null). Any other /// nullcount always forces us to keep the file. @@ -647,24 +680,24 @@ pub(crate) trait DataSkippingPredicateEvaluator { /// all-null or logically no-null, even tho the physical stats indicate a mix of null and /// non-null values. They cannot invalidate a file's physical all-null or non-null status, /// however, so the worst that can happen is we fail to skip an unnecessary file. 
- fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option; + fn eval_pred_is_null(&self, col: &ColumnName, inverted: bool) -> Option; - /// See [`KernelPredicateEvaluator::eval_binary_scalars`] - fn eval_binary_scalars( + /// See [`KernelPredicateEvaluator::eval_pred_binary_scalars`] + fn eval_pred_binary_scalars( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, - ) -> Option; + ) -> Option; - /// See [`KernelPredicateEvaluator::finish_eval_junction`] - fn finish_eval_junction( + /// See [`KernelPredicateEvaluator::finish_eval_pred_junction`] + fn finish_eval_pred_junction( &self, - op: JunctionOperator, - exprs: impl IntoIterator>, + op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, - ) -> Option; + ) -> Option; /// Helper method that performs a (possibly inverted) partial comparison between a typed column /// stat and a scalar. @@ -674,7 +707,7 @@ pub(crate) trait DataSkippingPredicateEvaluator { col: Self::TypedStat, val: &Scalar, inverted: bool, - ) -> Option; + ) -> Option; /// Performs a partial comparison against a column min-stat. See /// [`KernelPredicateEvaluatorDefaults::partial_cmp_scalars`] for details of the comparison semantics. 
@@ -684,7 +717,7 @@ pub(crate) trait DataSkippingPredicateEvaluator { val: &Scalar, ord: Ordering, inverted: bool, - ) -> Option { + ) -> Option { let min = self.get_min_stat(col, &val.data_type())?; self.eval_partial_cmp(ord, min, val, inverted) } @@ -697,13 +730,18 @@ pub(crate) trait DataSkippingPredicateEvaluator { val: &Scalar, ord: Ordering, inverted: bool, - ) -> Option { + ) -> Option { let max = self.get_max_stat(col, &val.data_type())?; self.eval_partial_cmp(ord, max, val, inverted) } - /// See [`KernelPredicateEvaluator::eval_lt`] - fn eval_lt(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + /// See [`KernelPredicateEvaluator::eval_pred_lt`] + fn eval_pred_lt( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { if inverted { // Given `col >= val`: // Skip if `val is greater than _every_ value in [min, max], implies @@ -724,8 +762,13 @@ pub(crate) trait DataSkippingPredicateEvaluator { } } - /// See [`KernelPredicateEvaluator::eval_le`] - fn eval_le(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + /// See [`KernelPredicateEvaluator::eval_pred_le`] + fn eval_pred_le( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { if inverted { // Given `col > val`: // Skip if `val` is not less than _all_ values in [min, max], implies @@ -746,80 +789,101 @@ pub(crate) trait DataSkippingPredicateEvaluator { } } - /// See [`KernelPredicateEvaluator::eval_ge`] - fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { - let (op, exprs) = if inverted { + /// See [`KernelPredicateEvaluator::eval_pred_ge`] + fn eval_pred_eq( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { + let (op, preds) = if inverted { // Column could compare not-equal if min or max value differs from the literal. 
- let exprs = [ + let preds = [ self.partial_cmp_min_stat(col, val, Ordering::Equal, true), self.partial_cmp_max_stat(col, val, Ordering::Equal, true), ]; - (JunctionOperator::Or, exprs) + (JunctionPredicateOp::Or, preds) } else { // Column could compare equal if its min/max values bracket the literal. - let exprs = [ + let preds = [ self.partial_cmp_min_stat(col, val, Ordering::Greater, true), self.partial_cmp_max_stat(col, val, Ordering::Less, true), ]; - (JunctionOperator::And, exprs) + (JunctionPredicateOp::And, preds) }; - self.finish_eval_junction(op, exprs, false) + self.finish_eval_pred_junction(op, preds, false) } } impl KernelPredicateEvaluator for T { - type Output = T::Output; + type Predicate = T::Predicate; + type Expression = T::Expression; - fn eval_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { - self.eval_scalar_is_null(val, inverted) + fn eval_pred_scalar(&self, val: &Scalar, inverted: bool) -> Option { + self.eval_pred_scalar(val, inverted) } - fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { - self.eval_scalar(val, inverted) + fn eval_pred_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { + self.eval_pred_scalar_is_null(val, inverted) } - fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { - self.eval_is_null(col, inverted) + fn eval_pred_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + self.eval_pred_is_null(col, inverted) } - fn eval_lt(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { - self.eval_lt(col, val, inverted) + fn eval_pred_lt( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { + self.eval_pred_lt(col, val, inverted) } - fn eval_le(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { - self.eval_le(col, val, inverted) + fn eval_pred_le( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { + self.eval_pred_le(col, val, inverted) } - fn eval_eq(&self, col: &ColumnName, val: 
&Scalar, inverted: bool) -> Option { - self.eval_eq(col, val, inverted) + fn eval_pred_eq( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { + self.eval_pred_eq(col, val, inverted) } - fn eval_binary_scalars( + fn eval_pred_binary_scalars( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, - ) -> Option { - self.eval_binary_scalars(op, left, right, inverted) + ) -> Option { + self.eval_pred_binary_scalars(op, left, right, inverted) } - fn eval_binary_columns( + fn eval_pred_binary_columns( &self, - _op: BinaryOperator, + _op: BinaryPredicateOp, _a: &ColumnName, _b: &ColumnName, _inverted: bool, - ) -> Option { + ) -> Option { None // Unsupported } - fn finish_eval_junction( + fn finish_eval_pred_junction( &self, - op: JunctionOperator, - exprs: impl IntoIterator>, + op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, - ) -> Option { - self.finish_eval_junction(op, exprs, inverted) + ) -> Option { + self.finish_eval_pred_junction(op, preds, inverted) } } diff --git a/kernel/src/kernel_predicates/parquet_stats_skipping.rs b/kernel/src/kernel_predicates/parquet_stats_skipping.rs index 146c73b1b..6f25c62c3 100644 --- a/kernel/src/kernel_predicates/parquet_stats_skipping.rs +++ b/kernel/src/kernel_predicates/parquet_stats_skipping.rs @@ -1,5 +1,5 @@ //! An implementation of data skipping that leverages parquet stats from the file footer. -use crate::expressions::{BinaryOperator, ColumnName, JunctionOperator, Scalar}; +use crate::expressions::{BinaryPredicateOp, ColumnName, JunctionPredicateOp, Scalar}; use crate::kernel_predicates::{DataSkippingPredicateEvaluator, KernelPredicateEvaluatorDefaults}; use crate::schema::DataType; use std::cmp::Ordering; @@ -30,7 +30,8 @@ pub(crate) trait ParquetStatsProvider { /// Blanket implementation that converts a [`ParquetStatsProvider`] into a /// [`DataSkippingPredicateEvaluator`]. 
impl DataSkippingPredicateEvaluator for T { - type Output = bool; + type Predicate = bool; + type Expression = Scalar; type TypedStat = Scalar; type IntStat = i64; @@ -60,15 +61,15 @@ impl DataSkippingPredicateEvaluator for T { KernelPredicateEvaluatorDefaults::partial_cmp_scalars(ord, &col, val, inverted) } - fn eval_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { - KernelPredicateEvaluatorDefaults::eval_scalar_is_null(val, inverted) + fn eval_pred_scalar(&self, val: &Scalar, inverted: bool) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_scalar(val, inverted) } - fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { - KernelPredicateEvaluatorDefaults::eval_scalar(val, inverted) + fn eval_pred_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_scalar_is_null(val, inverted) } - fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + fn eval_pred_is_null(&self, col: &ColumnName, inverted: bool) -> Option { let safe_to_skip = match inverted { true => self.get_rowcount_stat()?, // all-null false => 0i64, // no-null @@ -76,22 +77,22 @@ impl DataSkippingPredicateEvaluator for T { Some(self.get_nullcount_stat(col)? 
!= safe_to_skip) } - fn eval_binary_scalars( + fn eval_pred_binary_scalars( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, ) -> Option { - KernelPredicateEvaluatorDefaults::eval_binary_scalars(op, left, right, inverted) + KernelPredicateEvaluatorDefaults::eval_pred_binary_scalars(op, left, right, inverted) } - fn finish_eval_junction( + fn finish_eval_pred_junction( &self, - op: JunctionOperator, - exprs: impl IntoIterator>, + op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, ) -> Option { - KernelPredicateEvaluatorDefaults::finish_eval_junction(op, exprs, inverted) + KernelPredicateEvaluatorDefaults::finish_eval_pred_junction(op, preds, inverted) } } diff --git a/kernel/src/kernel_predicates/parquet_stats_skipping/tests.rs b/kernel/src/kernel_predicates/parquet_stats_skipping/tests.rs index a1ad8fc36..44d79e23d 100644 --- a/kernel/src/kernel_predicates/parquet_stats_skipping/tests.rs +++ b/kernel/src/kernel_predicates/parquet_stats_skipping/tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::expressions::{column_expr, Expression as Expr}; +use crate::expressions::{column_expr, Expression as Expr, Predicate as Pred}; use crate::kernel_predicates::KernelPredicateEvaluator as _; use crate::DataType; @@ -43,7 +43,7 @@ impl ParquetStatsProvider for UnimplementedTestFilter { /// Tests apply_junction and apply_scalar #[test] fn test_junctions() { - use JunctionOperator::*; + use JunctionPredicateOp::*; let test_cases = &[ // Every combo of 0, 1 and 2 inputs @@ -93,28 +93,28 @@ fn test_junctions() { let inputs: Vec<_> = inputs .iter() .map(|val| match val { - Some(v) => Expr::literal(v), - None => Expr::null_literal(DataType::BOOLEAN), + Some(v) => Pred::literal(*v), + None => Pred::null_literal(), }) .collect(); expect_eq!( - filter.eval_junction(And, &inputs, false), + filter.eval_pred_junction(And, &inputs, false), *expect_and, "AND({inputs:?})" ); expect_eq!( - filter.eval_junction(Or, &inputs, 
false), + filter.eval_pred_junction(Or, &inputs, false), *expect_or, "OR({inputs:?})" ); expect_eq!( - filter.eval_junction(And, &inputs, true), + filter.eval_pred_junction(And, &inputs, true), expect_and.map(|val| !val), "NOT(AND({inputs:?}))" ); expect_eq!( - filter.eval_junction(Or, &inputs, true), + filter.eval_pred_junction(Or, &inputs, true), expect_or.map(|val| !val), "NOT(OR({inputs:?}))" ); @@ -160,19 +160,19 @@ fn test_eval_binary_comparisons() { const FIFTEEN: Scalar = Scalar::Integer(15); const NULL_VAL: Scalar = Scalar::Null(DataType::INTEGER); - let expressions = [ - Expr::lt(column_expr!("x"), 10), - Expr::le(column_expr!("x"), 10), - Expr::eq(column_expr!("x"), 10), - Expr::ne(column_expr!("x"), 10), - Expr::gt(column_expr!("x"), 10), - Expr::ge(column_expr!("x"), 10), + let predicates = [ + Pred::lt(column_expr!("x"), Expr::literal(10)), + Pred::le(column_expr!("x"), Expr::literal(10)), + Pred::eq(column_expr!("x"), Expr::literal(10)), + Pred::ne(column_expr!("x"), Expr::literal(10)), + Pred::gt(column_expr!("x"), Expr::literal(10)), + Pred::ge(column_expr!("x"), Expr::literal(10)), ]; let do_test = |min: Scalar, max: Scalar, expected: &[Option]| { let filter = MinMaxTestFilter::new(Some(min.clone()), Some(max.clone())); - for (expr, expect) in expressions.iter().zip(expected.iter()) { - expect_eq!(filter.eval(expr), *expect, "{expr:#?} with [{min}..{max}]"); + for (pred, expect) in predicates.iter().zip(expected.iter()) { + expect_eq!(filter.eval(pred), *expect, "{pred:#?} with [{min}..{max}]"); } }; @@ -229,8 +229,8 @@ impl ParquetStatsProvider for NullCountTestFilter { #[test] fn test_eval_is_null() { let expressions = [ - Expr::is_null(column_expr!("x")), - Expr::is_not_null(column_expr!("x")), + Pred::is_null(column_expr!("x")), + Pred::is_not_null(column_expr!("x")), ]; let do_test = |nullcount: i64, expected: &[Option]| { diff --git a/kernel/src/kernel_predicates/tests.rs b/kernel/src/kernel_predicates/tests.rs index 88899c2da..cb158842b 
100644 --- a/kernel/src/kernel_predicates/tests.rs +++ b/kernel/src/kernel_predicates/tests.rs @@ -1,5 +1,8 @@ use super::*; -use crate::expressions::{column_expr, column_name, ArrayData, Expression, StructData}; +use crate::expressions::{ + column_expr, column_name, column_pred, ArrayData, Expression as Expr, Predicate as Pred, + StructData, +}; use crate::schema::ArrayType; use crate::DataType; @@ -41,7 +44,7 @@ fn test_default_eval_scalar() { ]; for (value, inverted, expect) in test_cases.into_iter() { assert_eq!( - KernelPredicateEvaluatorDefaults::eval_scalar(&value, inverted), + KernelPredicateEvaluatorDefaults::eval_pred_scalar(&value, inverted), expect, "value: {value:?} inverted: {inverted}" ); @@ -187,11 +190,11 @@ fn test_default_partial_cmp_scalars() { // Verifies that eval_binary_scalars uses partial_cmp_scalars correctly #[test] fn test_eval_binary_scalars() { - use BinaryOperator::*; + use BinaryPredicateOp::*; let smaller_value = Scalar::Long(1); let larger_value = Scalar::Long(10); for inverted in [true, false] { - let compare = KernelPredicateEvaluatorDefaults::eval_binary_scalars; + let compare = KernelPredicateEvaluatorDefaults::eval_pred_binary_scalars; expect_eq!( compare(Equal, &smaller_value, &smaller_value, inverted), Some(!inverted), @@ -272,12 +275,12 @@ fn test_eval_binary_columns() { let y = column_expr!("y"); for inverted in [true, false] { assert_eq!( - filter.eval_binary(BinaryOperator::Equal, &x, &y, inverted), + filter.eval_pred_binary(BinaryPredicateOp::Equal, &x, &y, inverted), Some(inverted), "x = y (inverted: {inverted})" ); assert_eq!( - filter.eval_binary(BinaryOperator::Equal, &x, &x, inverted), + filter.eval_pred_binary(BinaryPredicateOp::Equal, &x, &x, inverted), Some(!inverted), "x = x (inverted: {inverted})" ); @@ -309,21 +312,20 @@ fn test_eval_junction() { for (inputs, expect_and, expect_or) in test_cases.iter() { let inputs: Vec<_> = inputs .iter() - .cloned() .map(|v| match v { - Some(v) => Expression::literal(v), - 
None => Expression::null_literal(DataType::BOOLEAN), + Some(v) => Pred::literal(*v), + None => Pred::null_literal(), }) .collect(); for inverted in [true, false] { let invert_if_needed = |v: &Option<_>| v.map(|v| v != inverted); expect_eq!( - filter.eval_junction(JunctionOperator::And, &inputs, inverted), + filter.eval_pred_junction(JunctionPredicateOp::And, &inputs, inverted), invert_if_needed(expect_and), "AND({inputs:?}) (inverted: {inverted})" ); expect_eq!( - filter.eval_junction(JunctionOperator::Or, &inputs, inverted), + filter.eval_pred_junction(JunctionPredicateOp::Or, &inputs, inverted), invert_if_needed(expect_or), "OR({inputs:?}) (inverted: {inverted})" ); @@ -339,12 +341,12 @@ fn test_eval_column() { (Scalar::Null(DataType::BOOLEAN), None), (Scalar::from(1), None), ]; - let col = &column_name!("x"); + let col = &column_expr!("x"); for (input, expect) in &test_cases { let filter = DefaultKernelPredicateEvaluator::from(input.clone()); for inverted in [true, false] { expect_eq!( - filter.eval_column(col, inverted), + filter.eval_pred_expr(col, inverted), expect.map(|v| v != inverted), "{input:?} (inverted: {inverted})" ); @@ -362,10 +364,10 @@ fn test_eval_not() { ]; let filter = DefaultKernelPredicateEvaluator::from(UnimplementedColumnResolver); for (input, expect) in test_cases { - let input = input.into(); + let input = Pred::from_expr(input); for inverted in [true, false] { expect_eq!( - filter.eval_not(&input, inverted), + filter.eval_pred_not(&input, inverted), expect.map(|v| v != inverted), "NOT({input:?}) (inverted: {inverted})" ); @@ -375,28 +377,28 @@ fn test_eval_not() { #[test] fn test_eval_is_null() { - use crate::expressions::UnaryOperator::IsNull; + use crate::expressions::UnaryPredicateOp::IsNull; let expr = column_expr!("x"); let filter = DefaultKernelPredicateEvaluator::from(Scalar::from(1)); expect_eq!( - filter.eval_unary(IsNull, &expr, true), + filter.eval_pred_unary(IsNull, &expr, true), Some(true), "x IS NOT NULL" ); expect_eq!( - 
filter.eval_unary(IsNull, &expr, false), + filter.eval_pred_unary(IsNull, &expr, false), Some(false), "x IS NULL" ); - let expr = Expression::literal(1); + let expr = Expr::literal(1); expect_eq!( - filter.eval_unary(IsNull, &expr, true), + filter.eval_pred_unary(IsNull, &expr, true), Some(true), "1 IS NOT NULL" ); expect_eq!( - filter.eval_unary(IsNull, &expr, false), + filter.eval_pred_unary(IsNull, &expr, false), Some(false), "1 IS NULL" ); @@ -410,54 +412,54 @@ fn test_eval_distinct() { let filter = DefaultKernelPredicateEvaluator::from(one.clone()); let col = &column_name!("x"); expect_eq!( - filter.eval_distinct(col, one, true), + filter.eval_pred_distinct(col, one, true), Some(true), "NOT DISTINCT(x, 1) (x = 1)" ); expect_eq!( - filter.eval_distinct(col, one, false), + filter.eval_pred_distinct(col, one, false), Some(false), "DISTINCT(x, 1) (x = 1)" ); expect_eq!( - filter.eval_distinct(col, two, true), + filter.eval_pred_distinct(col, two, true), Some(false), "NOT DISTINCT(x, 2) (x = 1)" ); expect_eq!( - filter.eval_distinct(col, two, false), + filter.eval_pred_distinct(col, two, false), Some(true), "DISTINCT(x, 2) (x = 1)" ); expect_eq!( - filter.eval_distinct(col, null, true), + filter.eval_pred_distinct(col, null, true), Some(false), "NOT DISTINCT(x, NULL) (x = 1)" ); expect_eq!( - filter.eval_distinct(col, null, false), + filter.eval_pred_distinct(col, null, false), Some(true), "DISTINCT(x, NULL) (x = 1)" ); let filter = DefaultKernelPredicateEvaluator::from(null.clone()); expect_eq!( - filter.eval_distinct(col, one, true), + filter.eval_pred_distinct(col, one, true), Some(false), "NOT DISTINCT(x, 1) (x = NULL)" ); expect_eq!( - filter.eval_distinct(col, one, false), + filter.eval_pred_distinct(col, one, false), Some(true), "DISTINCT(x, 1) (x = NULL)" ); expect_eq!( - filter.eval_distinct(col, null, true), + filter.eval_pred_distinct(col, null, true), Some(true), "NOT DISTINCT(x, NULL) (x = NULL)" ); expect_eq!( - filter.eval_distinct(col, null, false), 
+ filter.eval_pred_distinct(col, null, false), Some(false), "DISTINCT(x, NULL) (x = NULL)" ); @@ -467,102 +469,81 @@ fn test_eval_distinct() { // test_eval_binary_scalars. #[test] fn eval_binary() { + use crate::expressions::BinaryPredicateOp; + let col = column_expr!("x"); - let val = Expression::literal(10); + let val = Expr::literal(10); let filter = DefaultKernelPredicateEvaluator::from(Scalar::from(1)); - // unsupported - expect_eq!( - filter.eval_binary(BinaryOperator::Plus, &col, &val, false), - None, - "x + 10" - ); - expect_eq!( - filter.eval_binary(BinaryOperator::Minus, &col, &val, false), - None, - "x - 10" - ); - expect_eq!( - filter.eval_binary(BinaryOperator::Multiply, &col, &val, false), - None, - "x * 10" - ); - expect_eq!( - filter.eval_binary(BinaryOperator::Divide, &col, &val, false), - None, - "x / 10" - ); - - // supported for inverted in [true, false] { expect_eq!( - filter.eval_binary(BinaryOperator::LessThan, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::LessThan, &col, &val, inverted), Some(!inverted), "x < 10 (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::LessThanOrEqual, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::LessThanOrEqual, &col, &val, inverted), Some(!inverted), "x <= 10 (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::Equal, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::Equal, &col, &val, inverted), Some(inverted), "x = 10 (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::NotEqual, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::NotEqual, &col, &val, inverted), Some(!inverted), "x != 10 (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::GreaterThanOrEqual, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::GreaterThanOrEqual, &col, &val, inverted), Some(inverted), "x >= 10 (inverted: {inverted})" ); expect_eq!( - 
filter.eval_binary(BinaryOperator::GreaterThan, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::GreaterThan, &col, &val, inverted), Some(inverted), "x > 10 (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::Distinct, &col, &val, inverted), + filter.eval_pred_binary(BinaryPredicateOp::Distinct, &col, &val, inverted), Some(!inverted), "DISTINCT(x, 10) (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::LessThan, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::LessThan, &val, &col, inverted), Some(inverted), "10 < x (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::LessThanOrEqual, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::LessThanOrEqual, &val, &col, inverted), Some(inverted), "10 <= x (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::Equal, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::Equal, &val, &col, inverted), Some(inverted), "10 = x (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::NotEqual, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::NotEqual, &val, &col, inverted), Some(!inverted), "10 != x (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::GreaterThanOrEqual, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::GreaterThanOrEqual, &val, &col, inverted), Some(!inverted), "10 >= x (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::GreaterThan, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::GreaterThan, &val, &col, inverted), Some(!inverted), "10 > x (inverted: {inverted})" ); expect_eq!( - filter.eval_binary(BinaryOperator::Distinct, &val, &col, inverted), + filter.eval_pred_binary(BinaryPredicateOp::Distinct, &val, &col, inverted), Some(!inverted), "DISTINCT(10, x) (inverted: {inverted})" ); @@ -580,77 +561,94 @@ impl ResolveColumnAsScalar 
for NullColumnResolver { #[test] fn test_sql_where() { let col = &column_expr!("x"); + let col_pred = &column_pred!("x"); const VAL: Expr = Expr::Literal(Scalar::Integer(1)); - const NULL: Expr = Expr::Literal(Scalar::Null(DataType::BOOLEAN)); - const FALSE: Expr = Expr::Literal(Scalar::Boolean(false)); - const TRUE: Expr = Expr::Literal(Scalar::Boolean(true)); + const NULL: Pred = Pred::null_literal(); + const FALSE: Pred = Pred::literal(false); + const TRUE: Pred = Pred::literal(true); let null_filter = DefaultKernelPredicateEvaluator::from(NullColumnResolver); let empty_filter = DefaultKernelPredicateEvaluator::from(EmptyColumnResolver); // Basic sanity check - expect_eq!(null_filter.eval_sql_where(&VAL), None, "WHERE {VAL}"); - expect_eq!(empty_filter.eval_sql_where(&VAL), None, "WHERE {VAL}"); + expect_eq!( + null_filter.eval_sql_where(&Pred::from_expr(VAL)), + None, + "WHERE {VAL}" + ); + expect_eq!( + empty_filter.eval_sql_where(&Pred::from_expr(VAL)), + None, + "WHERE {VAL}" + ); - expect_eq!(null_filter.eval_sql_where(col), Some(false), "WHERE {col}"); - expect_eq!(empty_filter.eval_sql_where(col), None, "WHERE {col}"); + expect_eq!( + null_filter.eval_sql_where(col_pred), + Some(false), + "WHERE {col_pred}" + ); + expect_eq!( + empty_filter.eval_sql_where(col_pred), + None, + "WHERE {col_pred}" + ); // SQL eval does not modify behavior of IS NULL - let expr = &Expr::is_null(col.clone()); - expect_eq!(null_filter.eval_sql_where(expr), Some(true), "{expr}"); + let pred = &Pred::is_null(col.clone()); + expect_eq!(null_filter.eval_sql_where(pred), Some(true), "{pred}"); // NOT a gets skipped when NULL but not when missing - let expr = &Expr::not(col.clone()); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::not(col_pred.clone()); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, 
"{pred}"); // Injected NULL checks only short circuit if inputs are NULL - let expr = &Expr::lt(FALSE, TRUE); - expect_eq!(null_filter.eval_sql_where(expr), Some(true), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), Some(true), "{expr}"); + let pred = &Pred::lt(FALSE, TRUE); + expect_eq!(null_filter.eval_sql_where(pred), Some(true), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), Some(true), "{pred}"); // Contrast normal vs SQL WHERE semantics - comparison - let expr = &Expr::lt(col.clone(), VAL); - expect_eq!(null_filter.eval(expr), None, "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); + let pred = &Pred::lt(col.clone(), VAL); + expect_eq!(null_filter.eval(pred), None, "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); // NULL check produces NULL due to missing column - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); - let expr = &Expr::lt(VAL, col.clone()); - expect_eq!(null_filter.eval(expr), None, "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::lt(VAL, col.clone()); + expect_eq!(null_filter.eval(pred), None, "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); - let expr = &Expr::distinct(VAL, col.clone()); - expect_eq!(null_filter.eval(expr), Some(true), "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(true), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::distinct(VAL, col.clone()); + expect_eq!(null_filter.eval(pred), Some(true), "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(true), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); - let expr = &Expr::distinct(NULL, col.clone()); - 
expect_eq!(null_filter.eval(expr), Some(false), "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::distinct(NULL, col.clone()); + expect_eq!(null_filter.eval(pred), Some(false), "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); // Contrast normal vs SQL WHERE semantics - comparison inside AND - let expr = &Expr::and(TRUE, Expr::lt(col.clone(), VAL)); - expect_eq!(null_filter.eval(expr), None, "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::and(TRUE, Pred::lt(col.clone(), VAL)); + expect_eq!(null_filter.eval(pred), None, "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); // NULL allows static skipping under SQL semantics - let expr = &Expr::and(NULL, Expr::lt(col.clone(), VAL)); - expect_eq!(null_filter.eval(expr), None, "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), Some(false), "{expr}"); + let pred = &Pred::and(NULL, Pred::lt(col.clone(), VAL)); + expect_eq!(null_filter.eval(pred), None, "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), Some(false), "{pred}"); // Contrast normal vs. 
SQL WHERE semantics - comparison inside AND inside AND - let expr = &Expr::and(TRUE, Expr::and(TRUE, Expr::lt(col.clone(), VAL))); - expect_eq!(null_filter.eval(expr), None, "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::and(TRUE, Pred::and(TRUE, Pred::lt(col.clone(), VAL))); + expect_eq!(null_filter.eval(pred), None, "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); // Ditto for comparison inside OR inside AND - let expr = &Expr::or(FALSE, Expr::and(TRUE, Expr::lt(col.clone(), VAL))); - expect_eq!(null_filter.eval(expr), None, "{expr}"); - expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); - expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); + let pred = &Pred::or(FALSE, Pred::and(TRUE, Pred::lt(col.clone(), VAL))); + expect_eq!(null_filter.eval(pred), None, "{pred}"); + expect_eq!(null_filter.eval_sql_where(pred), Some(false), "{pred}"); + expect_eq!(empty_filter.eval_sql_where(pred), None, "{pred}"); } diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 2e3d3e2bb..57af8c80f 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -100,7 +100,7 @@ internal_mod!(pub(crate) mod log_segment); pub use delta_kernel_derive; pub use engine_data::{EngineData, RowVisitor}; pub use error::{DeltaResult, Error}; -pub use expressions::{Expression, ExpressionRef}; +pub use expressions::{Expression, ExpressionRef, Predicate, PredicateRef}; pub use table::Table; use expressions::literal_expression_transform::LiteralExpressionTransform; @@ -330,11 +330,23 @@ impl AsAny for T { pub trait ExpressionEvaluator: AsAny { /// Evaluate the expression on a given EngineData. /// - /// Contains one value for each row of the input. + /// Produces one value for each row of the input. 
/// The data type of the output is same as the type output of the expression this evaluator is using. fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult>; } +/// Trait for implementing a Predicate evaluator. +/// +/// It contains one Predicate which can be evaluated on multiple ColumnarBatches. +/// Connectors can implement this trait to optimize the evaluation using the +/// connector specific capabilities. +pub trait PredicateEvaluator: AsAny { + /// Evaluate the predicate on a given EngineData. + /// + /// Produces one boolean value for each row of the input. + fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult>; +} + /// Provides expression evaluation capability to Delta Kernel. /// /// Delta Kernel can use this handler to evaluate predicate on partition filters, @@ -358,6 +370,21 @@ pub trait EvaluationHandler: AsAny { output_type: DataType, ) -> Arc; + /// Create a [`PredicateEvaluator`] that can evaluate the given [`Predicate`] on columnar + /// batches with the given [`Schema`] to produce a column of boolean results. + /// + /// # Parameters + /// + /// - `schema`: Schema of the input data. + /// - `predicate`: Predicate to evaluate. + /// + /// [`Schema`]: crate::schema::StructType + fn new_predicate_evaluator( + &self, + schema: SchemaRef, + predicate: Predicate, + ) -> Arc; + /// Create a single-row all-null-value [`EngineData`] with the schema specified by /// `output_schema`. // NOTE: we should probably allow DataType instead of SchemaRef, but can expand that in the @@ -454,7 +481,7 @@ pub trait JsonHandler: AsAny { &self, files: &[FileMeta], physical_schema: SchemaRef, - predicate: Option, + predicate: Option, ) -> DeltaResult; /// Atomically (!) write a single JSON file. 
Each row of the input data should be written as a @@ -505,7 +532,7 @@ pub trait ParquetHandler: AsAny { &self, files: &[FileMeta], physical_schema: SchemaRef, - predicate: Option, + predicate: Option, ) -> DeltaResult; } diff --git a/kernel/src/log_segment.rs b/kernel/src/log_segment.rs index fac0ac502..c7d2e2bf9 100644 --- a/kernel/src/log_segment.rs +++ b/kernel/src/log_segment.rs @@ -14,8 +14,8 @@ use crate::schema::SchemaRef; use crate::snapshot::LastCheckpointHint; use crate::utils::require; use crate::{ - DeltaResult, Engine, EngineData, Error, Expression, ExpressionRef, ParquetHandler, RowVisitor, - StorageHandler, Version, + DeltaResult, Engine, EngineData, Error, Expression, ParquetHandler, Predicate, PredicateRef, + RowVisitor, StorageHandler, Version, }; use delta_kernel_derive::internal_api; @@ -209,7 +209,7 @@ impl LogSegment { engine: &dyn Engine, commit_read_schema: SchemaRef, checkpoint_read_schema: SchemaRef, - meta_predicate: Option, + meta_predicate: Option, ) -> DeltaResult, bool)>> + Send> { // `replay` expects commit files to be sorted in descending order, so we reverse the sorted // commit files @@ -245,7 +245,7 @@ impl LogSegment { &self, engine: &dyn Engine, checkpoint_read_schema: SchemaRef, - meta_predicate: Option, + meta_predicate: Option, ) -> DeltaResult, bool)>> + Send> { let need_file_actions = checkpoint_read_schema.contains(ADD_NAME) || checkpoint_read_schema.contains(REMOVE_NAME); @@ -341,7 +341,7 @@ impl LogSegment { log_root: Url, batch: &dyn EngineData, checkpoint_read_schema: SchemaRef, - meta_predicate: Option, + meta_predicate: Option, ) -> DeltaResult>> + Send>> { // Visit the rows of the checkpoint batch to extract sidecar file references let mut visitor = SidecarVisitor::default(); @@ -407,8 +407,8 @@ impl LogSegment { ) -> DeltaResult, bool)>> + Send> { let schema = get_log_schema().project(&[PROTOCOL_NAME, METADATA_NAME])?; // filter out log files that do not contain metadata or protocol information - static 
META_PREDICATE: LazyLock> = LazyLock::new(|| { - Some(Arc::new(Expression::or( + static META_PREDICATE: LazyLock> = LazyLock::new(|| { + Some(Arc::new(Predicate::or( Expression::column([METADATA_NAME, "id"]).is_not_null(), Expression::column([PROTOCOL_NAME, "minReaderVersion"]).is_not_null(), ))) diff --git a/kernel/src/log_segment/tests.rs b/kernel/src/log_segment/tests.rs index d18bcfb3a..a0751c2a2 100644 --- a/kernel/src/log_segment/tests.rs +++ b/kernel/src/log_segment/tests.rs @@ -25,7 +25,7 @@ use crate::scan::test_utils::{ use crate::snapshot::LastCheckpointHint; use crate::utils::test_utils::{assert_batch_matches, Action}; use crate::{ - DeltaResult, Engine as _, EngineData, Expression, ExpressionRef, FileMeta, RowVisitor, + DeltaResult, Engine as _, EngineData, Expression, FileMeta, PredicateRef, RowVisitor, StorageHandler, Table, }; use test_utils::delta_path_for_version; @@ -932,7 +932,7 @@ fn test_reading_sidecar_files_with_predicate() -> DeltaResult<()> { )?; // Filter out sidecar files that do not contain remove actions - let remove_predicate: LazyLock> = LazyLock::new(|| { + let remove_predicate: LazyLock> = LazyLock::new(|| { Some(Arc::new( Expression::column([REMOVE_NAME, "path"]).is_not_null(), )) diff --git a/kernel/src/scan/data_skipping.rs b/kernel/src/scan/data_skipping.rs index a4b1456e1..6287dcb3e 100644 --- a/kernel/src/scan/data_skipping.rs +++ b/kernel/src/scan/data_skipping.rs @@ -8,14 +8,16 @@ use crate::actions::get_log_add_schema; use crate::actions::visitors::SelectionVectorVisitor; use crate::error::DeltaResult; use crate::expressions::{ - column_expr, joined_column_expr, BinaryOperator, ColumnName, Expression as Expr, ExpressionRef, - JunctionOperator, Scalar, + column_expr, joined_column_expr, BinaryPredicateOp, ColumnName, Expression as Expr, + JunctionPredicateOp, Predicate as Pred, PredicateRef, Scalar, }; use crate::kernel_predicates::{ DataSkippingPredicateEvaluator, KernelPredicateEvaluator, KernelPredicateEvaluatorDefaults, 
}; use crate::schema::{DataType, PrimitiveType, SchemaRef, SchemaTransform, StructField, StructType}; -use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler, RowVisitor as _}; +use crate::{ + Engine, EngineData, ExpressionEvaluator, JsonHandler, PredicateEvaluator, RowVisitor as _, +}; #[cfg(test)] mod tests; @@ -35,23 +37,23 @@ mod tests; /// - `AND` is rewritten as a conjunction of the rewritten operands where we just skip operands that /// are not eligible for data skipping. /// - `OR` is rewritten only if all operands are eligible for data skipping. Otherwise, the whole OR -/// expression is dropped. +/// predicate is dropped. #[cfg(test)] -fn as_data_skipping_predicate(expr: &Expr) -> Option { - DataSkippingPredicateCreator.eval(expr) +fn as_data_skipping_predicate(pred: &Pred) -> Option { + DataSkippingPredicateCreator.eval(pred) } /// Like `as_data_skipping_predicate`, but invokes [`KernelPredicateEvaluator::eval_sql_where`] /// instead of [`KernelPredicateEvaluator::eval`]. -fn as_sql_data_skipping_predicate(expr: &Expr) -> Option { - DataSkippingPredicateCreator.eval_sql_where(expr) +fn as_sql_data_skipping_predicate(pred: &Pred) -> Option { + DataSkippingPredicateCreator.eval_sql_where(pred) } pub(crate) struct DataSkippingFilter { stats_schema: SchemaRef, select_stats_evaluator: Arc, - skipping_evaluator: Arc, - filter_evaluator: Arc, + skipping_evaluator: Arc, + filter_evaluator: Arc, json_handler: Arc, } @@ -63,14 +65,11 @@ impl DataSkippingFilter { /// but using an Option lets the engine easily avoid the overhead of applying trivial filters. 
pub(crate) fn new( engine: &dyn Engine, - physical_predicate: Option<(ExpressionRef, SchemaRef)>, + physical_predicate: Option<(PredicateRef, SchemaRef)>, ) -> Option { - static PREDICATE_SCHEMA: LazyLock = LazyLock::new(|| { - DataType::struct_type([StructField::nullable("predicate", DataType::BOOLEAN)]) - }); static STATS_EXPR: LazyLock = LazyLock::new(|| column_expr!("add.stats")); - static FILTER_EXPR: LazyLock = - LazyLock::new(|| column_expr!("predicate").distinct(false)); + static FILTER_PRED: LazyLock = + LazyLock::new(|| column_expr!("predicate").distinct(Expr::literal(false))); let (predicate, referenced_schema) = physical_predicate?; debug!("Creating a data skipping filter for {:#?}", predicate); @@ -140,17 +139,14 @@ impl DataSkippingFilter { DataType::STRING, ); - let skipping_evaluator = engine.evaluation_handler().new_expression_evaluator( + let skipping_evaluator = engine.evaluation_handler().new_predicate_evaluator( stats_schema.clone(), - Expr::struct_from([as_sql_data_skipping_predicate(&predicate)?]), - PREDICATE_SCHEMA.clone(), + as_sql_data_skipping_predicate(&predicate)?, ); - let filter_evaluator = engine.evaluation_handler().new_expression_evaluator( - stats_schema.clone(), - FILTER_EXPR.clone(), - DataType::BOOLEAN, - ); + let filter_evaluator = engine + .evaluation_handler() + .new_predicate_evaluator(stats_schema.clone(), FILTER_PRED.clone()); Some(Self { stats_schema, @@ -197,7 +193,8 @@ impl DataSkippingFilter { struct DataSkippingPredicateCreator; impl DataSkippingPredicateEvaluator for DataSkippingPredicateCreator { - type Output = Expr; + type Predicate = Pred; + type Expression = Expr; type TypedStat = Expr; type IntStat = Expr; @@ -227,51 +224,51 @@ impl DataSkippingPredicateEvaluator for DataSkippingPredicateCreator { col: Expr, val: &Scalar, inverted: bool, - ) -> Option { + ) -> Option { let op = match (ord, inverted) { - (Ordering::Less, false) => BinaryOperator::LessThan, - (Ordering::Less, true) => 
BinaryOperator::GreaterThanOrEqual, - (Ordering::Equal, false) => BinaryOperator::Equal, - (Ordering::Equal, true) => BinaryOperator::NotEqual, - (Ordering::Greater, false) => BinaryOperator::GreaterThan, - (Ordering::Greater, true) => BinaryOperator::LessThanOrEqual, + (Ordering::Less, false) => BinaryPredicateOp::LessThan, + (Ordering::Less, true) => BinaryPredicateOp::GreaterThanOrEqual, + (Ordering::Equal, false) => BinaryPredicateOp::Equal, + (Ordering::Equal, true) => BinaryPredicateOp::NotEqual, + (Ordering::Greater, false) => BinaryPredicateOp::GreaterThan, + (Ordering::Greater, true) => BinaryPredicateOp::LessThanOrEqual, }; - Some(Expr::binary(op, col, val.clone())) + Some(Pred::binary(op, col, val.clone())) } - fn eval_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { - KernelPredicateEvaluatorDefaults::eval_scalar_is_null(val, inverted).map(Expr::literal) + fn eval_pred_scalar(&self, val: &Scalar, inverted: bool) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_scalar(val, inverted).map(Pred::literal) } - fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { - KernelPredicateEvaluatorDefaults::eval_scalar(val, inverted).map(Expr::literal) + fn eval_pred_scalar_is_null(&self, val: &Scalar, inverted: bool) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_scalar_is_null(val, inverted).map(Pred::literal) } - fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + fn eval_pred_is_null(&self, col: &ColumnName, inverted: bool) -> Option { let safe_to_skip = match inverted { true => self.get_rowcount_stat()?, // all-null false => Expr::literal(0i64), // no-null }; - Some(Expr::ne(self.get_nullcount_stat(col)?, safe_to_skip)) + Some(Pred::ne(self.get_nullcount_stat(col)?, safe_to_skip)) } - fn eval_binary_scalars( + fn eval_pred_binary_scalars( &self, - op: BinaryOperator, + op: BinaryPredicateOp, left: &Scalar, right: &Scalar, inverted: bool, - ) -> Option { - 
KernelPredicateEvaluatorDefaults::eval_binary_scalars(op, left, right, inverted) - .map(Expr::literal) + ) -> Option { + KernelPredicateEvaluatorDefaults::eval_pred_binary_scalars(op, left, right, inverted) + .map(Pred::literal) } - fn finish_eval_junction( + fn finish_eval_pred_junction( &self, - mut op: JunctionOperator, - exprs: impl IntoIterator>, + mut op: JunctionPredicateOp, + preds: impl IntoIterator>, inverted: bool, - ) -> Option { + ) -> Option { if inverted { op = op.invert(); } @@ -282,16 +279,16 @@ impl DataSkippingPredicateEvaluator for DataSkippingPredicateCreator { // where FALSE would otherwise be expected. So, we filter out all nulls except the first, // observing that one NULL is enough to produce the correct behavior during predicate eval. let mut keep_null = true; - let exprs: Vec<_> = exprs + let preds: Vec<_> = preds .into_iter() - .flat_map(|e| match e { - Some(expr) => Some(expr), + .flat_map(|p| match p { + Some(pred) => Some(pred), None => keep_null.then(|| { keep_null = false; - Expr::null_literal(DataType::BOOLEAN) + Pred::null_literal() }), }) .collect(); - Some(Expr::junction(op, exprs)) + Some(Pred::junction(op, preds)) } } diff --git a/kernel/src/scan/data_skipping/tests.rs b/kernel/src/scan/data_skipping/tests.rs index 0f29dbd36..2cae5bd1b 100644 --- a/kernel/src/scan/data_skipping/tests.rs +++ b/kernel/src/scan/data_skipping/tests.rs @@ -25,7 +25,7 @@ macro_rules! 
expect_eq { #[test] fn test_eval_is_null() { let col = &column_expr!("x"); - let expressions = [Expr::is_null(col.clone()), Expr::is_not_null(col.clone())]; + let predicates = [Pred::is_null(col.clone()), Pred::is_not_null(col.clone())]; let do_test = |nullcount: i64, expected: &[Option]| { let resolver = HashMap::from_iter([ @@ -33,12 +33,12 @@ fn test_eval_is_null() { (column_name!("nullCount.x"), Scalar::from(nullcount)), ]); let filter = DefaultKernelPredicateEvaluator::from(resolver); - for (expr, expect) in expressions.iter().zip(expected) { - let pred = as_data_skipping_predicate(expr).unwrap(); + for (pred, expect) in predicates.iter().zip(expected) { + let skipping_pred = as_data_skipping_predicate(pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&skipping_pred, false), *expect, - "{expr:#?} became {pred:#?} ({nullcount} nulls)" + "{pred:#?} became {skipping_pred:#?} ({nullcount} nulls)" ); } }; @@ -61,13 +61,13 @@ fn test_eval_binary_comparisons() { let fifteen = &Scalar::from(15); let null = &Scalar::Null(DataType::INTEGER); - let expressions = [ - Expr::lt(col.clone(), ten.clone()), - Expr::le(col.clone(), ten.clone()), - Expr::eq(col.clone(), ten.clone()), - Expr::ne(col.clone(), ten.clone()), - Expr::gt(col.clone(), ten.clone()), - Expr::ge(col.clone(), ten.clone()), + let predicates = [ + Pred::lt(col.clone(), ten.clone()), + Pred::le(col.clone(), ten.clone()), + Pred::eq(col.clone(), ten.clone()), + Pred::ne(col.clone(), ten.clone()), + Pred::gt(col.clone(), ten.clone()), + Pred::ge(col.clone(), ten.clone()), ]; let do_test = |min: &Scalar, max: &Scalar, expected: &[Option]| { @@ -76,12 +76,12 @@ fn test_eval_binary_comparisons() { (column_name!("maxValues.x"), max.clone()), ]); let filter = DefaultKernelPredicateEvaluator::from(resolver); - for (expr, expect) in expressions.iter().zip(expected.iter()) { - let pred = as_data_skipping_predicate(expr).unwrap(); + for (pred, expect) in 
predicates.iter().zip(expected.iter()) { + let skipping_pred = as_data_skipping_predicate(pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&skipping_pred, false), *expect, - "{expr:#?} became {pred:#?} with [{min}..{max}]" + "{pred:#?} became {skipping_pred:#?} with [{min}..{max}]" ); } }; @@ -154,35 +154,35 @@ fn test_eval_junction() { let inputs: Vec<_> = inputs .iter() .map(|val| match val { - Some(v) => Expr::literal(v), - None => Expr::null_literal(DataType::BOOLEAN), + Some(v) => Pred::literal(*v), + None => Pred::null_literal(), }) .collect(); - let expr = Expr::and_from(inputs.clone()); - let pred = as_data_skipping_predicate(&expr).unwrap(); + let pred = Pred::and_from(inputs.clone()); + let pred = as_data_skipping_predicate(&pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&pred, false), *expect_and, "AND({inputs:?})" ); - let expr = Expr::or_from(inputs.clone()); - let pred = as_data_skipping_predicate(&expr).unwrap(); - expect_eq!(filter.eval_expr(&pred, false), *expect_or, "OR({inputs:?})"); + let pred = Pred::or_from(inputs.clone()); + let pred = as_data_skipping_predicate(&pred).unwrap(); + expect_eq!(filter.eval_pred(&pred, false), *expect_or, "OR({inputs:?})"); - let expr = Expr::not(Expr::and_from(inputs.clone())); - let pred = as_data_skipping_predicate(&expr).unwrap(); + let pred = Pred::not(Pred::and_from(inputs.clone())); + let pred = as_data_skipping_predicate(&pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&pred, false), expect_and.map(|val| !val), "NOT AND({inputs:?})" ); - let expr = Expr::not(Expr::or_from(inputs.clone())); - let pred = as_data_skipping_predicate(&expr).unwrap(); + let pred = Pred::not(Pred::or_from(inputs.clone())); + let pred = as_data_skipping_predicate(&pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&pred, false), expect_or.map(|val| !val), "NOT OR({inputs:?})" ); @@ -200,11 +200,11 @@ fn 
test_eval_distinct() { let fifteen = &Scalar::from(15); let null = &Scalar::Null(DataType::INTEGER); - let expressions = [ - Expr::distinct(col.clone(), ten.clone()), - Expr::not(Expr::distinct(col.clone(), ten.clone())), - Expr::distinct(col.clone(), null.clone()), - Expr::not(Expr::distinct(col.clone(), null.clone())), + let predicates = [ + Pred::distinct(col.clone(), ten.clone()), + Pred::not(Pred::distinct(col.clone(), ten.clone())), + Pred::distinct(col.clone(), null.clone()), + Pred::not(Pred::distinct(col.clone(), null.clone())), ]; let do_test = |min: &Scalar, max: &Scalar, nullcount: i64, expected: &[Option]| { @@ -215,12 +215,12 @@ fn test_eval_distinct() { (column_name!("maxValues.x"), max.clone()), ]); let filter = DefaultKernelPredicateEvaluator::from(resolver); - for (expr, expect) in expressions.iter().zip(expected) { - let pred = as_data_skipping_predicate(expr).unwrap(); + for (pred, expect) in predicates.iter().zip(expected) { + let skipping_pred = as_data_skipping_predicate(pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&skipping_pred, false), *expect, - "{expr:#?} became {pred:#?} ({min}..{max}, {nullcount} nulls)" + "{pred:#?} became {skipping_pred:#?} ({min}..{max}, {nullcount} nulls)" ); } }; @@ -257,16 +257,16 @@ fn test_eval_distinct() { fn test_sql_where() { let col = &column_expr!("x"); const VAL: Expr = Expr::Literal(Scalar::Integer(10)); - const NULL: Expr = Expr::Literal(Scalar::Null(DataType::BOOLEAN)); - const FALSE: Expr = Expr::Literal(Scalar::Boolean(false)); - const TRUE: Expr = Expr::Literal(Scalar::Boolean(true)); + const NULL: Pred = Pred::null_literal(); + const FALSE: Pred = Pred::literal(false); + const TRUE: Pred = Pred::literal(true); const ROWCOUNT: i64 = 2; const ALL_NULL: i64 = ROWCOUNT; const SOME_NULL: i64 = 1; const NO_NULL: i64 = 0; let do_test = - |nulls: i64, expr: &Expr, missing: bool, expect: Option, expect_sql: Option| { + |nulls: i64, pred: &Pred, missing: bool, expect: 
Option, expect_sql: Option| { assert!((0..=ROWCOUNT).contains(&nulls)); let (min, max) = if nulls < ROWCOUNT { (Scalar::Integer(5), Scalar::Integer(15)) @@ -287,59 +287,59 @@ fn test_sql_where() { ]) }; let filter = DefaultKernelPredicateEvaluator::from(resolver); - let pred = as_data_skipping_predicate(expr).unwrap(); + let skipping_pred = as_data_skipping_predicate(pred).unwrap(); expect_eq!( - filter.eval_expr(&pred, false), + filter.eval_pred(&skipping_pred, false), expect, - "{expr:#?} became {pred:#?} ({min}..{max}, {nulls} nulls)" + "{pred:#?} became {skipping_pred:#?} ({min}..{max}, {nulls} nulls)" ); - let sql_pred = as_sql_data_skipping_predicate(expr).unwrap(); + let skipping_sql_pred = as_sql_data_skipping_predicate(pred).unwrap(); expect_eq!( - filter.eval_expr(&sql_pred, false), + filter.eval_pred(&skipping_sql_pred, false), expect_sql, - "{expr:#?} became {sql_pred:#?} ({min}..{max}, {nulls} nulls)" + "{pred:#?} became {skipping_sql_pred:#?} ({min}..{max}, {nulls} nulls)" ); }; // Sanity tests -- only all-null columns should behave differently between normal and SQL WHERE. const MISSING: bool = true; const PRESENT: bool = false; - let expr = &Expr::lt(TRUE, FALSE); - do_test(ALL_NULL, expr, MISSING, Some(false), Some(false)); + let pred = &Pred::lt(TRUE, FALSE); + do_test(ALL_NULL, pred, MISSING, Some(false), Some(false)); - let expr = &Expr::is_not_null(col.clone()); - do_test(ALL_NULL, expr, PRESENT, Some(false), Some(false)); - do_test(ALL_NULL, expr, MISSING, None, None); + let pred = &Pred::is_not_null(col.clone()); + do_test(ALL_NULL, pred, PRESENT, Some(false), Some(false)); + do_test(ALL_NULL, pred, MISSING, None, None); // SQL WHERE allows a present-but-all-null column to be pruned, but not a missing column. 
- let expr = &Expr::lt(col.clone(), VAL); - do_test(NO_NULL, expr, PRESENT, Some(true), Some(true)); - do_test(SOME_NULL, expr, PRESENT, Some(true), Some(true)); - do_test(ALL_NULL, expr, PRESENT, None, Some(false)); - do_test(ALL_NULL, expr, MISSING, None, None); + let pred = &Pred::lt(col.clone(), VAL); + do_test(NO_NULL, pred, PRESENT, Some(true), Some(true)); + do_test(SOME_NULL, pred, PRESENT, Some(true), Some(true)); + do_test(ALL_NULL, pred, PRESENT, None, Some(false)); + do_test(ALL_NULL, pred, MISSING, None, None); // Comparison inside AND works - let expr = &Expr::and(TRUE, Expr::lt(VAL, col.clone())); - do_test(ALL_NULL, expr, PRESENT, None, Some(false)); - do_test(ALL_NULL, expr, MISSING, None, None); + let pred = &Pred::and(TRUE, Pred::lt(VAL, col.clone())); + do_test(ALL_NULL, pred, PRESENT, None, Some(false)); + do_test(ALL_NULL, pred, MISSING, None, None); // NULL inside AND allows static skipping under SQL semantics - let expr = &Expr::and(NULL, Expr::lt(col.clone(), VAL)); - do_test(ALL_NULL, expr, PRESENT, None, Some(false)); - do_test(ALL_NULL, expr, MISSING, None, Some(false)); + let pred = &Pred::and(NULL, Pred::lt(col.clone(), VAL)); + do_test(ALL_NULL, pred, PRESENT, None, Some(false)); + do_test(ALL_NULL, pred, MISSING, None, Some(false)); // Comparison inside AND inside AND works - let expr = &Expr::and(TRUE, Expr::and(TRUE, Expr::lt(col.clone(), VAL))); - do_test(ALL_NULL, expr, PRESENT, None, Some(false)); - do_test(ALL_NULL, expr, MISSING, None, None); + let pred = &Pred::and(TRUE, Pred::and(TRUE, Pred::lt(col.clone(), VAL))); + do_test(ALL_NULL, pred, PRESENT, None, Some(false)); + do_test(ALL_NULL, pred, MISSING, None, None); // Comparison inside OR works - let expr = &Expr::or(FALSE, Expr::lt(col.clone(), VAL)); - do_test(ALL_NULL, expr, PRESENT, None, Some(false)); - do_test(ALL_NULL, expr, MISSING, None, None); + let pred = &Pred::or(FALSE, Pred::lt(col.clone(), VAL)); + do_test(ALL_NULL, pred, PRESENT, None, Some(false)); + 
do_test(ALL_NULL, pred, MISSING, None, None); // Comparison inside AND inside OR works - let expr = &Expr::or(FALSE, Expr::and(TRUE, Expr::lt(col.clone(), VAL))); - do_test(ALL_NULL, expr, PRESENT, None, Some(false)); - do_test(ALL_NULL, expr, MISSING, None, None); + let pred = &Pred::or(FALSE, Pred::and(TRUE, Pred::lt(col.clone(), VAL))); + do_test(ALL_NULL, pred, PRESENT, None, Some(false)); + do_test(ALL_NULL, pred, MISSING, None, None); } diff --git a/kernel/src/scan/log_replay.rs b/kernel/src/scan/log_replay.rs index 29df6d8ca..0e4f8527e 100644 --- a/kernel/src/scan/log_replay.rs +++ b/kernel/src/scan/log_replay.rs @@ -8,7 +8,9 @@ use super::data_skipping::DataSkippingFilter; use super::{ScanMetadata, Transform}; use crate::actions::get_log_add_schema; use crate::engine_data::{GetData, RowVisitor, TypedGetData as _}; -use crate::expressions::{column_expr, column_name, ColumnName, Expression, ExpressionRef}; +use crate::expressions::{ + column_expr, column_name, ColumnName, Expression, ExpressionRef, PredicateRef, +}; use crate::kernel_predicates::{DefaultKernelPredicateEvaluator, KernelPredicateEvaluator as _}; use crate::log_replay::{FileActionDeduplicator, FileActionKey, LogReplayProcessor}; use crate::scan::{Scalar, TransformExpr}; @@ -39,7 +41,7 @@ use crate::{DeltaResult, Engine, EngineData, Error, ExpressionEvaluator}; /// vector indicating which rows are valid, and any row-level transformation expressions that need /// to be applied to the selected rows. 
pub(crate) struct ScanLogReplayProcessor { - partition_filter: Option, + partition_filter: Option, data_skipping_filter: Option, add_transform: Arc, logical_schema: SchemaRef, @@ -54,7 +56,7 @@ impl ScanLogReplayProcessor { /// Create a new [`ScanLogReplayProcessor`] instance fn new( engine: &dyn Engine, - physical_predicate: Option<(ExpressionRef, SchemaRef)>, + physical_predicate: Option<(PredicateRef, SchemaRef)>, logical_schema: SchemaRef, transform: Option>, ) -> Self { @@ -82,7 +84,7 @@ struct AddRemoveDedupVisitor<'seen> { selection_vector: Vec, logical_schema: SchemaRef, transform: Option>, - partition_filter: Option, + partition_filter: Option, row_transform_exprs: Vec>, } @@ -100,7 +102,7 @@ impl AddRemoveDedupVisitor<'_> { selection_vector: Vec, logical_schema: SchemaRef, transform: Option>, - partition_filter: Option, + partition_filter: Option, is_log_batch: bool, ) -> AddRemoveDedupVisitor<'_> { AddRemoveDedupVisitor { @@ -389,7 +391,7 @@ pub(crate) fn scan_action_iter( action_iter: impl Iterator, bool)>>, logical_schema: SchemaRef, transform: Option>, - physical_predicate: Option<(ExpressionRef, SchemaRef)>, + physical_predicate: Option<(PredicateRef, SchemaRef)>, ) -> impl Iterator> { ScanLogReplayProcessor::new(engine, physical_predicate, logical_schema, transform) .process_actions_iter(action_iter) @@ -407,7 +409,7 @@ mod tests { run_with_validate_callback, }; use crate::scan::{get_state_info, Scan}; - use crate::Expression; + use crate::Expression as Expr; use crate::{ engine::sync::SyncEngine, schema::{DataType, SchemaRef, StructField, StructType}, @@ -502,17 +504,17 @@ mod tests { fn validate_transform(transform: Option<&ExpressionRef>, expected_date_offset: i32) { assert!(transform.is_some()); - let Expression::Struct(inner) = transform.unwrap().as_ref() else { + let Expr::Struct(inner) = transform.unwrap().as_ref() else { panic!("Transform should always be a struct expr"); }; assert_eq!(inner.len(), 2, "expected two items in transform 
struct"); - let Expression::Column(ref name) = inner[0] else { + let Expr::Column(ref name) = inner[0] else { panic!("Expected first expression to be a column"); }; assert_eq!(name, &column_name!("value"), "First col should be 'value'"); - let Expression::Literal(ref scalar) = inner[1] else { + let Expr::Literal(ref scalar) = inner[1] else { panic!("Expected second expression to be a literal"); }; assert_eq!( diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs index 0df5a1dff..b1a52a30e 100644 --- a/kernel/src/scan/mod.rs +++ b/kernel/src/scan/mod.rs @@ -13,7 +13,8 @@ use crate::actions::deletion_vector::{ }; use crate::actions::{get_log_schema, ADD_NAME, REMOVE_NAME, SIDECAR_NAME}; use crate::engine_data::FilteredEngineData; -use crate::expressions::{ColumnName, Expression, ExpressionRef, ExpressionTransform, Scalar}; +use crate::expressions::transforms::ExpressionTransform; +use crate::expressions::{ColumnName, Expression, ExpressionRef, Predicate, PredicateRef, Scalar}; use crate::kernel_predicates::{DefaultKernelPredicateEvaluator, EmptyColumnResolver}; use crate::log_replay::HasSelectionVector; use crate::scan::state::{DvInfo, Stats}; @@ -36,7 +37,7 @@ pub mod state; pub struct ScanBuilder { snapshot: Arc, schema: Option, - predicate: Option, + predicate: Option, } impl std::fmt::Debug for ScanBuilder { @@ -85,7 +86,7 @@ impl ScanBuilder { /// /// NOTE: The filtering is best-effort and can produce false positives (rows that should should /// have been filtered out but were kept). 
- pub fn with_predicate(mut self, predicate: impl Into>) -> Self { + pub fn with_predicate(mut self, predicate: impl Into>) -> Self { self.predicate = predicate.into(); self } @@ -122,7 +123,7 @@ impl ScanBuilder { #[derive(Clone, Debug, PartialEq)] pub(crate) enum PhysicalPredicate { - Some(ExpressionRef, SchemaRef), + Some(PredicateRef, SchemaRef), StaticSkipAll, None, } @@ -136,7 +137,7 @@ impl PhysicalPredicate { /// NOTE: It is possible the predicate resolves to FALSE even ignoring column references, /// e.g. `col > 10 AND FALSE`. Such predicates can statically skip the whole query. pub(crate) fn try_new( - predicate: &Expression, + predicate: &Predicate, logical_schema: &Schema, ) -> DeltaResult { if can_statically_skip_all_files(predicate) { @@ -169,7 +170,7 @@ impl PhysicalPredicate { let mut apply_mappings = ApplyColumnMappings { column_mappings: get_referenced_fields.column_mappings, }; - if let Some(predicate) = apply_mappings.transform(predicate) { + if let Some(predicate) = apply_mappings.transform_pred(predicate) { Ok(PhysicalPredicate::Some( Arc::new(predicate.into_owned()), Arc::new(schema.into_owned()), @@ -183,7 +184,7 @@ impl PhysicalPredicate { // Evaluates a static data skipping predicate, ignoring any column references, and returns true if // the predicate allows to statically skip all files. Since this is direct evaluation (not an // expression rewrite), we use a `DefaultKernelPredicateEvaluator` with an empty column resolver. -fn can_statically_skip_all_files(predicate: &Expression) -> bool { +fn can_statically_skip_all_files(predicate: &Predicate) -> bool { use crate::kernel_predicates::KernelPredicateEvaluator as _; let evaluator = DefaultKernelPredicateEvaluator::from(EmptyColumnResolver); evaluator.eval_sql_where(predicate) == Some(false) @@ -238,7 +239,7 @@ struct ApplyColumnMappings { impl<'a> ExpressionTransform<'a> for ApplyColumnMappings { // NOTE: We already verified all column references. 
But if the map probe ever did fail, the // transform would just delete any expression(s) that reference the invalid column. - fn transform_column(&mut self, name: &'a ColumnName) -> Option> { + fn transform_expr_column(&mut self, name: &'a ColumnName) -> Option> { self.column_mappings .get(name) .map(|physical_name| Cow::Owned(physical_name.clone())) @@ -395,7 +396,7 @@ impl Scan { } /// Get the predicate [`Expression`] of the scan. - pub fn physical_predicate(&self) -> Option { + pub fn physical_predicate(&self) -> Option { if let PhysicalPredicate::Some(ref predicate, _) = self.physical_predicate { Some(predicate.clone()) } else { @@ -835,7 +836,7 @@ mod tests { use std::path::PathBuf; use crate::engine::sync::SyncEngine; - use crate::expressions::column_expr; + use crate::expressions::{column_expr, column_pred, Expression as Expr, Predicate as Pred}; use crate::schema::{ColumnMetadataKey, PrimitiveType}; use crate::Table; @@ -843,19 +844,19 @@ mod tests { #[test] fn test_static_skipping() { - const NULL: Expression = Expression::null_literal(DataType::BOOLEAN); + const NULL: Pred = Pred::null_literal(); let test_cases = [ - (false, column_expr!("a")), - (true, Expression::literal(false)), - (false, Expression::literal(true)), + (false, column_pred!("a")), + (true, Pred::literal(false)), + (false, Pred::literal(true)), (true, NULL), - (true, Expression::and(column_expr!("a"), false)), - (false, Expression::or(column_expr!("a"), true)), - (false, Expression::or(column_expr!("a"), false)), - (false, Expression::lt(column_expr!("a"), 10)), - (false, Expression::lt(Expression::literal(10), 100)), - (true, Expression::gt(Expression::literal(10), 100)), - (true, Expression::and(NULL, column_expr!("a"))), + (true, Pred::and(column_pred!("a"), Pred::literal(false))), + (false, Pred::or(column_pred!("a"), Pred::literal(true))), + (false, Pred::or(column_pred!("a"), Pred::literal(false))), + (false, Pred::lt(column_expr!("a"), Expr::literal(10))), + (false, 
Pred::lt(Expr::literal(10), Expr::literal(100))), + (true, Pred::gt(Expr::literal(10), Expr::literal(100))), + (true, Pred::and(NULL, column_pred!("a"))), ]; for (should_skip, predicate) in test_cases { assert_eq!( @@ -906,23 +907,20 @@ mod tests { // NOTE: We break several column mapping rules here because they don't matter for this // test. For example, we do not provide field ids, and not all columns have physical names. let test_cases = [ - (Expression::literal(true), Some(PhysicalPredicate::None)), + (Pred::literal(true), Some(PhysicalPredicate::None)), + (Pred::literal(false), Some(PhysicalPredicate::StaticSkipAll)), + (column_pred!("x"), None), // no such column ( - Expression::literal(false), - Some(PhysicalPredicate::StaticSkipAll), - ), - (column_expr!("x"), None), // no such column - ( - column_expr!("a"), + column_pred!("a"), Some(PhysicalPredicate::Some( - column_expr!("a").into(), + column_pred!("a").into(), StructType::new(vec![StructField::nullable("a", DataType::LONG)]).into(), )), ), ( - column_expr!("b"), + column_pred!("b"), Some(PhysicalPredicate::Some( - column_expr!("phys_b").into(), + column_pred!("phys_b").into(), StructType::new(vec![StructField::nullable("phys_b", DataType::LONG) .with_metadata([( ColumnMetadataKey::ColumnMappingPhysicalName.as_ref(), @@ -932,9 +930,9 @@ mod tests { )), ), ( - column_expr!("nested.x"), + column_pred!("nested.x"), Some(PhysicalPredicate::Some( - column_expr!("nested.x").into(), + column_pred!("nested.x").into(), StructType::new(vec![StructField::nullable( "nested", StructType::new(vec![StructField::nullable("x", DataType::LONG)]), @@ -943,9 +941,9 @@ mod tests { )), ), ( - column_expr!("nested.y"), + column_pred!("nested.y"), Some(PhysicalPredicate::Some( - column_expr!("nested.phys_y").into(), + column_pred!("nested.phys_y").into(), StructType::new(vec![StructField::nullable( "nested", StructType::new(vec![StructField::nullable("phys_y", DataType::LONG) @@ -958,9 +956,9 @@ mod tests { )), ), ( - 
column_expr!("mapped.n"), + column_pred!("mapped.n"), Some(PhysicalPredicate::Some( - column_expr!("phys_mapped.phys_n").into(), + column_pred!("phys_mapped.phys_n").into(), StructType::new(vec![StructField::nullable( "phys_mapped", StructType::new(vec![StructField::nullable("phys_n", DataType::LONG) @@ -977,9 +975,9 @@ mod tests { )), ), ( - Expression::and(column_expr!("mapped.n"), true), + Pred::and(column_pred!("mapped.n"), Pred::literal(true)), Some(PhysicalPredicate::Some( - Expression::and(column_expr!("phys_mapped.phys_n"), true).into(), + Pred::and(column_pred!("phys_mapped.phys_n"), Pred::literal(true)).into(), StructType::new(vec![StructField::nullable( "phys_mapped", StructType::new(vec![StructField::nullable("phys_n", DataType::LONG) @@ -996,7 +994,7 @@ mod tests { )), ), ( - Expression::and(column_expr!("mapped.n"), false), + Pred::and(column_pred!("mapped.n"), Pred::literal(false)), Some(PhysicalPredicate::StaticSkipAll), ), ]; @@ -1151,7 +1149,7 @@ mod tests { // Ineffective predicate pushdown attempted, so the one data file should be returned. let int_col = column_expr!("numeric.ints.int32"); - let value = Expression::literal(1000i32); + let value = Expr::literal(1000i32); let predicate = Arc::new(int_col.clone().gt(value.clone())); let scan = snapshot .clone() @@ -1187,7 +1185,7 @@ mod tests { // // WARNING: https://github.com/delta-io/delta-kernel-rs/issues/434 - This // optimization is currently disabled, so the one data file is still returned. 
- let predicate = Arc::new(column_expr!("missing").lt(1000i64)); + let predicate = Arc::new(column_expr!("missing").lt(Expr::literal(1000i64))); let scan = snapshot .clone() .scan_builder() @@ -1198,7 +1196,7 @@ mod tests { assert_eq!(data.len(), 1); // Predicate over a logically missing column fails the scan - let predicate = Arc::new(column_expr!("numeric.ints.invalid").lt(1000)); + let predicate = Arc::new(column_expr!("numeric.ints.invalid").lt(Expr::literal(1000))); snapshot .scan_builder() .with_predicate(predicate) diff --git a/kernel/src/table_changes/log_replay.rs b/kernel/src/table_changes/log_replay.rs index 20fc11c6e..7b48c8c9b 100644 --- a/kernel/src/table_changes/log_replay.rs +++ b/kernel/src/table_changes/log_replay.rs @@ -20,7 +20,7 @@ use crate::table_changes::scan_file::{cdf_scan_row_expression, cdf_scan_row_sche use crate::table_changes::{check_cdf_table_properties, ensure_cdf_read_supported}; use crate::table_properties::TableProperties; use crate::utils::require; -use crate::{DeltaResult, Engine, EngineData, Error, ExpressionRef, RowVisitor}; +use crate::{DeltaResult, Engine, EngineData, Error, PredicateRef, RowVisitor}; use itertools::Itertools; @@ -51,7 +51,7 @@ pub(crate) fn table_changes_action_iter( engine: Arc, commit_files: impl IntoIterator, table_schema: SchemaRef, - physical_predicate: Option<(ExpressionRef, SchemaRef)>, + physical_predicate: Option<(PredicateRef, SchemaRef)>, ) -> DeltaResult>> { let filter = DataSkippingFilter::new(engine.as_ref(), physical_predicate).map(Arc::new); let result = commit_files diff --git a/kernel/src/table_changes/log_replay/tests.rs b/kernel/src/table_changes/log_replay/tests.rs index babdde516..64382ff8c 100644 --- a/kernel/src/table_changes/log_replay/tests.rs +++ b/kernel/src/table_changes/log_replay/tests.rs @@ -3,8 +3,7 @@ use super::TableChangesScanMetadata; use crate::actions::deletion_vector::DeletionVectorDescriptor; use crate::actions::{Add, Cdc, Metadata, Protocol, Remove}; use 
crate::engine::sync::SyncEngine; -use crate::expressions::Scalar; -use crate::expressions::{column_expr, BinaryOperator}; +use crate::expressions::{column_expr, BinaryPredicateOp, Scalar}; use crate::log_segment::LogSegment; use crate::path::ParsedLogPath; use crate::scan::state::DvInfo; @@ -13,7 +12,7 @@ use crate::schema::{DataType, StructField, StructType}; use crate::table_changes::log_replay::LogReplayScanner; use crate::table_features::ReaderFeature; use crate::utils::test_utils::{Action, LocalMockTable}; -use crate::Expression; +use crate::Predicate; use crate::{DeltaResult, Engine, Error, Version}; use itertools::Itertools; @@ -517,8 +516,8 @@ async fn data_skipping_filter() { .await; // Look for actions with id > 4 - let predicate = Expression::binary( - BinaryOperator::GreaterThan, + let predicate = Predicate::binary( + BinaryPredicateOp::GreaterThan, column_expr!("id"), Scalar::from(4), ); diff --git a/kernel/src/table_changes/mod.rs b/kernel/src/table_changes/mod.rs index 86d0f99af..c9fadcc17 100644 --- a/kernel/src/table_changes/mod.rs +++ b/kernel/src/table_changes/mod.rs @@ -18,7 +18,7 @@ //! let schema = table_changes //! .schema() //! .project(&["id", "_commit_version"])?; -//! let predicate = Arc::new(Expression::gt(column_expr!("id"), Scalar::from(10))); +//! let predicate = Arc::new(Predicate::gt(column_expr!("id"), Scalar::from(10))); //! //! // Construct the table changes scan //! 
let table_changes_scan = table_changes diff --git a/kernel/src/table_changes/physical_to_logical.rs b/kernel/src/table_changes/physical_to_logical.rs index a953048a9..8dabf4e28 100644 --- a/kernel/src/table_changes/physical_to_logical.rs +++ b/kernel/src/table_changes/physical_to_logical.rs @@ -19,8 +19,8 @@ fn get_cdf_columns(scan_file: &CdfScanFile) -> DeltaResult Expression::column([CHANGE_TYPE_COL_NAME]), - CdfScanFileType::Add => ADD_CHANGE_TYPE.into(), - CdfScanFileType::Remove => REMOVE_CHANGE_TYPE.into(), + CdfScanFileType::Add => Expression::literal(ADD_CHANGE_TYPE), + CdfScanFileType::Remove => Expression::literal(REMOVE_CHANGE_TYPE), }; let expressions = [ (CHANGE_TYPE_COL_NAME, change_type), @@ -81,7 +81,7 @@ pub(crate) fn scan_file_physical_schema( mod tests { use std::collections::HashMap; - use crate::expressions::{column_expr, Expression, Scalar}; + use crate::expressions::{column_expr, Expression as Expr, Scalar}; use crate::scan::ColumnType; use crate::schema::{DataType, StructField, StructType}; use crate::table_changes::physical_to_logical::physical_to_logical_expr; @@ -119,20 +119,20 @@ mod tests { ]; let phys_to_logical_expr = physical_to_logical_expr(&scan_file, &logical_schema, &all_fields).unwrap(); - let expected_expr = Expression::struct_from([ + let expected_expr = Expr::struct_from([ column_expr!("id"), Scalar::Long(20).into(), expected_expr, - Expression::literal(42i64), + Expr::literal(42i64), Scalar::TimestampNtz(1234000).into(), // Microsecond is 1000x millisecond ]); assert_eq!(phys_to_logical_expr, expected_expr) }; - let cdc_change_type = Expression::column([CHANGE_TYPE_COL_NAME]); - test(CdfScanFileType::Add, ADD_CHANGE_TYPE.into()); - test(CdfScanFileType::Remove, REMOVE_CHANGE_TYPE.into()); + let cdc_change_type = Expr::column([CHANGE_TYPE_COL_NAME]); + test(CdfScanFileType::Add, Expr::literal(ADD_CHANGE_TYPE)); + test(CdfScanFileType::Remove, Expr::literal(REMOVE_CHANGE_TYPE)); test(CdfScanFileType::Cdc, cdc_change_type); } } 
diff --git a/kernel/src/table_changes/scan.rs b/kernel/src/table_changes/scan.rs index b9bed794d..8242c133c 100644 --- a/kernel/src/table_changes/scan.rs +++ b/kernel/src/table_changes/scan.rs @@ -10,7 +10,7 @@ use crate::actions::deletion_vector::split_vector; use crate::scan::state::GlobalScanState; use crate::scan::{ColumnType, PhysicalPredicate, ScanResult}; use crate::schema::{SchemaRef, StructType}; -use crate::{DeltaResult, Engine, ExpressionRef, FileMeta}; +use crate::{DeltaResult, Engine, FileMeta, PredicateRef}; use super::log_replay::{table_changes_action_iter, TableChangesScanMetadata}; use super::physical_to_logical::{physical_to_logical_expr, scan_file_physical_schema}; @@ -62,7 +62,7 @@ pub struct TableChangesScan { /// .schema() /// .project(&["id", "_commit_version"]) /// .unwrap(); -/// let predicate = Arc::new(Expression::gt(column_expr!("id"), Scalar::from(10))); +/// let predicate = Arc::new(Predicate::gt(column_expr!("id"), Scalar::from(10))); /// let scan = table_changes /// .into_scan_builder() /// .with_schema(schema) @@ -73,7 +73,7 @@ pub struct TableChangesScan { pub struct TableChangesScanBuilder { table_changes: Arc, schema: Option, - predicate: Option, + predicate: Option, } impl TableChangesScanBuilder { @@ -103,7 +103,7 @@ impl TableChangesScanBuilder { /// /// NOTE: The filtering is best-effort and can produce false positives (rows that should should /// have been filtered out but were kept). - pub fn with_predicate(mut self, predicate: impl Into>) -> Self { + pub fn with_predicate(mut self, predicate: impl Into>) -> Self { self.predicate = predicate.into(); self } @@ -222,8 +222,8 @@ impl TableChangesScan { &self.logical_schema } - /// Get the predicate [`ExpressionRef`] of the scan. - fn physical_predicate(&self) -> Option { + /// Get the predicate [`PredicateRef`] of the scan. 
+ fn physical_predicate(&self) -> Option { if let PhysicalPredicate::Some(ref predicate, _) = self.physical_predicate { Some(predicate.clone()) } else { @@ -276,7 +276,7 @@ fn read_scan_file( resolved_scan_file: ResolvedCdfScanFile, global_state: &GlobalScanState, all_fields: &[ColumnType], - physical_predicate: Option, + physical_predicate: Option, ) -> DeltaResult>> { let ResolvedCdfScanFile { scan_file, @@ -362,7 +362,7 @@ mod tests { use crate::scan::{ColumnType, PhysicalPredicate}; use crate::schema::{DataType, StructField, StructType}; use crate::table_changes::COMMIT_VERSION_COL_NAME; - use crate::{Expression, Table}; + use crate::{Predicate, Table}; #[test] fn simple_table_changes_scan_builder() { @@ -402,7 +402,7 @@ mod tests { .schema() .project(&["id", COMMIT_VERSION_COL_NAME]) .unwrap(); - let predicate = Arc::new(Expression::gt(column_expr!("id"), Scalar::from(10))); + let predicate = Arc::new(Predicate::gt(column_expr!("id"), Scalar::from(10))); let scan = table_changes .into_scan_builder() .with_schema(schema) diff --git a/kernel/src/table_changes/scan_file.rs b/kernel/src/table_changes/scan_file.rs index c0a8ae490..b9ff2e97d 100644 --- a/kernel/src/table_changes/scan_file.rs +++ b/kernel/src/table_changes/scan_file.rs @@ -230,8 +230,8 @@ pub(crate) fn cdf_scan_row_expression(commit_timestamp: i64, commit_number: i64) column_expr!("cdc.path"), Expression::struct_from([column_expr!("cdc.partitionValues")]), ]), - commit_timestamp.into(), - commit_number.into(), + Expression::literal(commit_timestamp), + Expression::literal(commit_number), ]) } diff --git a/kernel/tests/cdf.rs b/kernel/tests/cdf.rs index 069018951..e2f555c6f 100644 --- a/kernel/tests/cdf.rs +++ b/kernel/tests/cdf.rs @@ -6,7 +6,7 @@ use delta_kernel::engine::sync::SyncEngine; use itertools::Itertools; use delta_kernel::engine::arrow_data::ArrowEngineData; -use delta_kernel::{DeltaResult, Error, ExpressionRef, Table, Version}; +use delta_kernel::{DeltaResult, Error, PredicateRef, Table, 
Version}; mod common; use common::{load_test_data, to_arrow}; @@ -15,7 +15,7 @@ fn read_cdf_for_table( test_name: impl AsRef, start_version: Version, end_version: impl Into>, - predicate: impl Into>, + predicate: impl Into>, ) -> DeltaResult> { let test_dir = load_test_data("tests/data", test_name.as_ref()).unwrap(); let test_path = test_dir.path().join(test_name.as_ref()); diff --git a/kernel/tests/read.rs b/kernel/tests/read.rs index ebd57b4f4..247190bb8 100644 --- a/kernel/tests/read.rs +++ b/kernel/tests/read.rs @@ -8,7 +8,10 @@ use delta_kernel::arrow::datatypes::SchemaRef as ArrowSchemaRef; use delta_kernel::engine::arrow_data::ArrowEngineData; use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor; use delta_kernel::engine::default::DefaultEngine; -use delta_kernel::expressions::{column_expr, BinaryOperator, Expression, ExpressionRef}; +use delta_kernel::expressions::{ + column_expr, column_pred, BinaryPredicateOp, Expression as Expr, ExpressionRef, + Predicate as Pred, +}; use delta_kernel::parquet::file::properties::{EnabledStatistics, WriterProperties}; use delta_kernel::scan::state::{transform_to_logical, DvInfo, Stats}; use delta_kernel::scan::Scan; @@ -251,7 +254,7 @@ async fn stats() -> Result<(), Box> { // // NOTE: For cases that match both batch1 and batch2, we list batch2 first because log replay // returns most recently added files first. 
- use BinaryOperator::{ + use BinaryPredicateOp::{ Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual, }; let test_cases: Vec<(_, i32, _)> = vec![ @@ -287,7 +290,7 @@ async fn stats() -> Result<(), Box> { (NotEqual, 8, vec![&batch2, &batch1]), ]; for (op, value, expected_batches) in test_cases { - let predicate = Expression::binary(op, column_expr!("id"), value); + let predicate = Pred::binary(op, column_expr!("id"), Expr::literal(value)); let scan = snapshot .clone() .scan_builder() @@ -422,7 +425,7 @@ fn read_with_scan_metadata( fn read_table_data( path: &str, select_cols: Option<&[&str]>, - predicate: Option, + predicate: Option, mut expected: Vec, ) -> Result<(), Box> { let path = std::fs::canonicalize(PathBuf::from(path))?; @@ -465,7 +468,7 @@ fn read_table_data( fn read_table_data_str( path: &str, select_cols: Option<&[&str]>, - predicate: Option, + predicate: Option, expected: Vec<&str>, ) -> Result<(), Box> { read_table_data( @@ -584,24 +587,27 @@ fn table_for_letters(letters: &[char]) -> Vec { fn predicate_on_number() -> Result<(), Box> { let cases = vec![ ( - column_expr!("number").lt(4i64), + column_expr!("number").lt(Expr::literal(4i64)), table_for_numbers(vec![1, 2, 3]), ), ( - column_expr!("number").le(4i64), + column_expr!("number").le(Expr::literal(4i64)), table_for_numbers(vec![1, 2, 3, 4]), ), ( - column_expr!("number").gt(4i64), + column_expr!("number").gt(Expr::literal(4i64)), table_for_numbers(vec![5, 6]), ), ( - column_expr!("number").ge(4i64), + column_expr!("number").ge(Expr::literal(4i64)), table_for_numbers(vec![4, 5, 6]), ), - (column_expr!("number").eq(4i64), table_for_numbers(vec![4])), ( - column_expr!("number").ne(4i64), + column_expr!("number").eq(Expr::literal(4i64)), + table_for_numbers(vec![4]), + ), + ( + column_expr!("number").ne(Expr::literal(4i64)), table_for_numbers(vec![1, 2, 3, 5, 6]), ), ]; @@ -619,7 +625,7 @@ fn predicate_on_number() -> Result<(), Box> { #[test] fn predicate_on_letter() -> 
Result<(), Box> { - // Test basic column pruning. Note that the actual expression machinery is already well-tested, + // Test basic column pruning. Note that the actual predicate machinery is already well-tested, // so we're just testing wiring here. let null_row_table: Vec = vec![ "+--------+--------+", @@ -639,21 +645,27 @@ fn predicate_on_letter() -> Result<(), Box> { table_for_letters(&['a', 'b', 'c', 'e']), ), ( - column_expr!("letter").lt("c"), + column_expr!("letter").lt(Expr::literal("c")), table_for_letters(&['a', 'b']), ), ( - column_expr!("letter").le("c"), + column_expr!("letter").le(Expr::literal("c")), table_for_letters(&['a', 'b', 'c']), ), - (column_expr!("letter").gt("c"), table_for_letters(&['e'])), ( - column_expr!("letter").ge("c"), + column_expr!("letter").gt(Expr::literal("c")), + table_for_letters(&['e']), + ), + ( + column_expr!("letter").ge(Expr::literal("c")), table_for_letters(&['c', 'e']), ), - (column_expr!("letter").eq("c"), table_for_letters(&['c'])), ( - column_expr!("letter").ne("c"), + column_expr!("letter").eq(Expr::literal("c")), + table_for_letters(&['c']), + ), + ( + column_expr!("letter").ne(Expr::literal("c")), table_for_letters(&['a', 'b', 'e']), ), ]; @@ -691,27 +703,27 @@ fn predicate_on_letter_and_number() -> Result<(), Box> { let cases = vec![ ( - Expression::or( + Pred::or( // No pruning power - column_expr!("letter").gt("a"), - column_expr!("number").gt(3i64), + column_expr!("letter").gt(Expr::literal("a")), + column_expr!("number").gt(Expr::literal(3i64)), ), full_table, ), ( - Expression::and( - column_expr!("letter").gt("a"), // numbers 2, 3, 5 - column_expr!("number").gt(3i64), // letters a, e + Pred::and( + column_expr!("letter").gt(Expr::literal("a")), // numbers 2, 3, 5 + column_expr!("number").gt(Expr::literal(3i64)), // letters a, e ), table_for_letters(&['e']), ), ( - Expression::and( - column_expr!("letter").gt("a"), // numbers 2, 3, 5 - Expression::or( + Pred::and( + 
column_expr!("letter").gt(Expr::literal("a")), // numbers 2, 3, 5 + Pred::or( // No pruning power - column_expr!("letter").eq("c"), - column_expr!("number").eq(3i64), + column_expr!("letter").eq(Expr::literal("c")), + column_expr!("number").eq(Expr::literal(3i64)), ), ), table_for_letters(&['b', 'c', 'e']), @@ -733,27 +745,27 @@ fn predicate_on_letter_and_number() -> Result<(), Box> { fn predicate_on_number_not() -> Result<(), Box> { let cases = vec![ ( - Expression::not(column_expr!("number").lt(4i64)), + Pred::not(column_expr!("number").lt(Expr::literal(4i64))), table_for_numbers(vec![4, 5, 6]), ), ( - Expression::not(column_expr!("number").le(4i64)), + Pred::not(column_expr!("number").le(Expr::literal(4i64))), table_for_numbers(vec![5, 6]), ), ( - Expression::not(column_expr!("number").gt(4i64)), + Pred::not(column_expr!("number").gt(Expr::literal(4i64))), table_for_numbers(vec![1, 2, 3, 4]), ), ( - Expression::not(column_expr!("number").ge(4i64)), + Pred::not(column_expr!("number").ge(Expr::literal(4i64))), table_for_numbers(vec![1, 2, 3]), ), ( - Expression::not(column_expr!("number").eq(4i64)), + Pred::not(column_expr!("number").eq(Expr::literal(4i64))), table_for_numbers(vec![1, 2, 3, 5, 6]), ), ( - Expression::not(column_expr!("number").ne(4i64)), + Pred::not(column_expr!("number").ne(Expr::literal(4i64))), table_for_numbers(vec![4]), ), ]; @@ -781,9 +793,9 @@ fn predicate_on_number_with_not_null() -> Result<(), Box> read_table_data_str( "./tests/data/basic_partitioned", Some(&["a_float", "number"]), - Some(Expression::and( + Some(Pred::and( column_expr!("number").is_not_null(), - column_expr!("number").lt(Expression::literal(3i64)), + column_expr!("number").lt(Expr::literal(3i64)), )), expected, )?; @@ -860,30 +872,30 @@ fn mixed_not_null() -> Result<(), Box> { fn and_or_predicates() -> Result<(), Box> { let cases = vec![ ( - Expression::and( - column_expr!("number").gt(4i64), - column_expr!("a_float").gt(5.5), + Pred::and( + 
column_expr!("number").gt(Expr::literal(4i64)), + column_expr!("a_float").gt(Expr::literal(5.5)), ), table_for_numbers(vec![6]), ), ( - Expression::and( - column_expr!("number").gt(4i64), - Expression::not(column_expr!("a_float").gt(5.5)), + Pred::and( + column_expr!("number").gt(Expr::literal(4i64)), + Pred::not(column_expr!("a_float").gt(Expr::literal(5.5))), ), table_for_numbers(vec![5]), ), ( - Expression::or( - column_expr!("number").gt(4i64), - column_expr!("a_float").gt(5.5), + Pred::or( + column_expr!("number").gt(Expr::literal(4i64)), + column_expr!("a_float").gt(Expr::literal(5.5)), ), table_for_numbers(vec![5, 6]), ), ( - Expression::or( - column_expr!("number").gt(4i64), - Expression::not(column_expr!("a_float").gt(5.5)), + Pred::or( + column_expr!("number").gt(Expr::literal(4i64)), + Pred::not(column_expr!("a_float").gt(Expr::literal(5.5))), ), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), @@ -903,30 +915,30 @@ fn and_or_predicates() -> Result<(), Box> { fn not_and_or_predicates() -> Result<(), Box> { let cases = vec![ ( - Expression::not(Expression::and( - column_expr!("number").gt(4i64), - column_expr!("a_float").gt(5.5), + Pred::not(Pred::and( + column_expr!("number").gt(Expr::literal(4i64)), + column_expr!("a_float").gt(Expr::literal(5.5)), )), table_for_numbers(vec![1, 2, 3, 4, 5]), ), ( - Expression::not(Expression::and( - column_expr!("number").gt(4i64), - Expression::not(column_expr!("a_float").gt(5.5)), + Pred::not(Pred::and( + column_expr!("number").gt(Expr::literal(4i64)), + Pred::not(column_expr!("a_float").gt(Expr::literal(5.5))), )), table_for_numbers(vec![1, 2, 3, 4, 6]), ), ( - Expression::not(Expression::or( - column_expr!("number").gt(4i64), - column_expr!("a_float").gt(5.5), + Pred::not(Pred::or( + column_expr!("number").gt(Expr::literal(4i64)), + column_expr!("a_float").gt(Expr::literal(5.5)), )), table_for_numbers(vec![1, 2, 3, 4]), ), ( - Expression::not(Expression::or( - column_expr!("number").gt(4i64), - 
Expression::not(column_expr!("a_float").gt(5.5)), + Pred::not(Pred::or( + column_expr!("number").gt(Expr::literal(4i64)), + Pred::not(column_expr!("a_float").gt(Expr::literal(5.5))), )), vec![], ), @@ -944,37 +956,35 @@ fn not_and_or_predicates() -> Result<(), Box> { #[test] fn invalid_skips_none_predicates() -> Result<(), Box> { - let empty_struct = Expression::struct_from(vec![]); + let empty_struct = Expr::struct_from(vec![]); let cases = vec![ - (Expression::literal(false), table_for_numbers(vec![])), + (Pred::literal(false), table_for_numbers(vec![])), ( - Expression::and(column_expr!("number"), false), + Pred::and(column_pred!("number"), Pred::literal(false)), table_for_numbers(vec![]), ), ( - Expression::literal(true), + Pred::literal(true), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::literal(3i64), + Pred::from_expr(Expr::literal(3i64)), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - column_expr!("number").distinct(3i64), + column_expr!("number").distinct(Expr::literal(3i64)), table_for_numbers(vec![1, 2, 4, 5, 6]), ), ( - column_expr!("number").distinct(Expression::null_literal(DataType::LONG)), + column_expr!("number").distinct(Expr::null_literal(DataType::LONG)), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::not(column_expr!("number").distinct(3i64)), + Pred::not(column_expr!("number").distinct(Expr::literal(3i64))), table_for_numbers(vec![3]), ), ( - Expression::not( - column_expr!("number").distinct(Expression::null_literal(DataType::LONG)), - ), + Pred::not(column_expr!("number").distinct(Expr::null_literal(DataType::LONG))), table_for_numbers(vec![]), ), ( @@ -982,15 +992,15 @@ fn invalid_skips_none_predicates() -> Result<(), Box> { table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::not(column_expr!("number").gt(empty_struct.clone())), + Pred::not(column_expr!("number").gt(empty_struct.clone())), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ]; - for (expr, expected) in cases.into_iter() { + for (pred, 
expected) in cases.into_iter() { read_table_data( "./tests/data/basic_partitioned", Some(&["a_float", "number"]), - Some(expr), + Some(pred), expected, )?; } @@ -1016,7 +1026,7 @@ fn with_predicate_and_removes() -> Result<(), Box> { read_table_data_str( "./tests/data/table-with-dv-small/", None, - Some(Expression::gt(column_expr!("value"), 3)), + Some(Pred::gt(column_expr!("value"), Expr::literal(3))), expected, )?; Ok(()) @@ -1059,7 +1069,7 @@ async fn predicate_on_non_nullable_partition_column() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box