Skip to content

Commit

Permalink
Merge pull request #110 from ryan-johnson-databricks/frj-simpler-and
Browse files Browse the repository at this point in the history
Simplify kernel AND/OR logic
  • Loading branch information
ryan-johnson-databricks authored Jan 29, 2024
2 parents da01d18 + 58ec3ba commit d09f814
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 135 deletions.
10 changes: 5 additions & 5 deletions kernel/src/client/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,10 @@ fn evaluate_expression(expression: &Expression, batch: &RecordBatch) -> DeltaRes
})
}
VariadicOperation { op, exprs } => {
let reducer = match op {
VariadicOperator::And => and,
VariadicOperator::Or => or,
type Operation = fn(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>;
let (reducer, default): (Operation, _) = match op {
VariadicOperator::And => (and, true),
VariadicOperator::Or => (or, false),
};
exprs
.iter()
Expand All @@ -192,8 +193,7 @@ fn evaluate_expression(expression: &Expression, batch: &RecordBatch) -> DeltaRes
Ok(reducer(downcast_to_bool(&l?)?, downcast_to_bool(&r?)?)
.map(wrap_comparison_result)?)
})
.transpose()?
.ok_or(Error::Generic("empty expression".to_string()))
.unwrap_or_else(|| evaluate_expression(&Expression::literal(default), batch))
}
}
}
Expand Down
80 changes: 49 additions & 31 deletions kernel/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,15 @@ impl Expression {
Self::Literal(value.into())
}

/// Builds the unary expression `OP expr`.
///
/// `expr` accepts anything convertible into an [`Expression`]; the converted
/// operand is boxed and stored alongside `op`.
pub fn unary(op: UnaryOperator, expr: impl Into<Expression>) -> Self {
    let operand: Expression = expr.into();
    let expr = Box::new(operand);
    Self::UnaryOperation { op, expr }
}

/// Creates a new binary expression lhs OP rhs
pub fn binary(
op: BinaryOperator,
lhs: impl Into<Expression>,
Expand All @@ -181,18 +183,27 @@ impl Expression {
}
}

pub fn variadic(op: VariadicOperator, other: impl IntoIterator<Item = Self>) -> Self {
let mut exprs = other.into_iter().collect::<Vec<_>>();
if exprs.is_empty() {
// TODO this might break if we introduce new variadic operators?
return Self::literal(matches!(op, VariadicOperator::And));
}
if exprs.len() == 1 {
return exprs.pop().unwrap();
}
/// Builds the variadic expression `OP(exprs...)`.
///
/// The operands are collected in iteration order and stored as-is: empty and
/// single-element inputs are not simplified here.
pub fn variadic(op: VariadicOperator, exprs: impl IntoIterator<Item = Self>) -> Self {
    Self::VariadicOperation {
        op,
        exprs: exprs.into_iter().collect(),
    }
}

/// Builds the conjunction `AND(exprs...)` over all of the given operands.
///
/// Delegates to [`Expression::variadic`] with [`VariadicOperator::And`].
pub fn and_from(exprs: impl IntoIterator<Item = Self>) -> Self {
    let op = VariadicOperator::And;
    Self::variadic(op, exprs)
}

/// Builds the disjunction `OR(exprs...)` over all of the given operands.
///
/// Delegates to [`Expression::variadic`] with [`VariadicOperator::Or`].
pub fn or_from(exprs: impl IntoIterator<Item = Self>) -> Self {
    let op = VariadicOperator::Or;
    Self::variadic(op, exprs)
}

/// Wraps `self` in a null check, producing `self IS NULL`.
///
/// Consumes `self` and makes it the operand of a [`UnaryOperator::IsNull`]
/// unary expression.
pub fn is_null(self) -> Self {
    let op = UnaryOperator::IsNull;
    Self::unary(op, self)
}

/// Create a new expression `self == other`
pub fn eq(self, other: Self) -> Self {
Self::binary(BinaryOperator::Equal, self, other)
Expand All @@ -203,11 +214,21 @@ impl Expression {
Self::binary(BinaryOperator::NotEqual, self, other)
}

/// Builds the comparison `self <= other`.
///
/// `self` becomes the left-hand side and `other` the right-hand side of a
/// [`BinaryOperator::LessThanOrEqual`] binary expression.
pub fn le(self, other: Self) -> Self {
    let op = BinaryOperator::LessThanOrEqual;
    Self::binary(op, self, other)
}

/// Builds the comparison `self < other`.
///
/// `self` becomes the left-hand side and `other` the right-hand side of a
/// [`BinaryOperator::LessThan`] binary expression.
pub fn lt(self, other: Self) -> Self {
    let op = BinaryOperator::LessThan;
    Self::binary(op, self, other)
}

/// Builds the comparison `self >= other`.
///
/// `self` becomes the left-hand side and `other` the right-hand side of a
/// [`BinaryOperator::GreaterThanOrEqual`] binary expression.
pub fn ge(self, other: Self) -> Self {
    let op = BinaryOperator::GreaterThanOrEqual;
    Self::binary(op, self, other)
}

/// Create a new expression `self > other`
pub fn gt(self, other: Self) -> Self {
Self::binary(BinaryOperator::GreaterThan, self, other)
Expand All @@ -225,22 +246,12 @@ impl Expression {

/// Create a new expression `self AND other`
pub fn and(self, other: Self) -> Self {
self.and_many([other])
}

/// Create a new expression `self AND others`
pub fn and_many(self, other: impl IntoIterator<Item = Self>) -> Self {
Self::variadic(VariadicOperator::And, std::iter::once(self).chain(other))
}

/// Create a new expression `self AND other`
pub fn or(self, other: Self) -> Self {
self.or_many([other])
Self::and_from([self, other])
}

/// Create a new expression `self OR other`
pub fn or_many(self, other: impl IntoIterator<Item = Self>) -> Self {
Self::variadic(VariadicOperator::Or, std::iter::once(self).chain(other))
pub fn or(self, other: Self) -> Self {
Self::or_from([self, other])
}

fn walk(&self) -> impl Iterator<Item = &Self> + '_ {
Expand All @@ -257,17 +268,23 @@ impl Expression {
Self::UnaryOperation { expr, .. } => {
stack.push(expr);
}
Self::VariadicOperation { op, exprs } => match op {
VariadicOperator::And | VariadicOperator::Or => {
stack.extend(exprs.iter());
}
},
Self::VariadicOperation { exprs, .. } => {
stack.extend(exprs.iter());
}
}
Some(expr)
})
}
}

/// Enables the `!` operator on expressions: `!expr` builds the logical
/// negation `NOT expr`.
impl std::ops::Not for Expression {
    type Output = Self;

    /// Consumes `self` and wraps it in a [`UnaryOperator::Not`] unary expression.
    fn not(self) -> Self {
        let op = UnaryOperator::Not;
        Self::unary(op, self)
    }
}

impl std::ops::Add<Expression> for Expression {
type Output = Self;

Expand All @@ -279,23 +296,23 @@ impl std::ops::Add<Expression> for Expression {
impl std::ops::Sub<Expression> for Expression {
type Output = Self;

fn sub(self, rhs: Expression) -> Self::Output {
fn sub(self, rhs: Expression) -> Self {
Self::binary(BinaryOperator::Minus, self, rhs)
}
}

impl std::ops::Mul<Expression> for Expression {
type Output = Self;

fn mul(self, rhs: Expression) -> Self::Output {
fn mul(self, rhs: Expression) -> Self {
Self::binary(BinaryOperator::Multiply, self, rhs)
}
}

impl std::ops::Div<Expression> for Expression {
type Output = Self;

fn div(self, rhs: Expression) -> Self::Output {
fn div(self, rhs: Expression) -> Self {
Self::binary(BinaryOperator::Divide, self, rhs)
}
}
Expand Down Expand Up @@ -326,7 +343,8 @@ mod tests {
"AND(Column(x) >= 2, Column(x) <= 10)",
),
(
col_ref.clone().gt_eq(Expr::literal(2)).and_many([
Expr::and_from([
col_ref.clone().gt_eq(Expr::literal(2)),
col_ref.clone().lt_eq(Expr::literal(10)),
col_ref.clone().lt_eq(Expr::literal(100)),
]),
Expand Down
129 changes: 30 additions & 99 deletions kernel/src/scan/data_skipping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ fn commute(op: &BinaryOperator) -> Option<BinaryOperator> {
fn as_data_skipping_predicate(expr: &Expr) -> Option<Expr> {
use BinaryOperator::*;
use Expr::*;
use VariadicOperator::*;

match expr {
BinaryOperation { op, left, right } => {
Expand All @@ -58,25 +57,17 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option<Expr> {
GreaterThan | GreaterThanOrEqual => "maxValues",
Equal => {
let exprs = [
Expr::binary(LessThanOrEqual, Column(col.clone()), Literal(val.clone())),
Expr::binary(LessThanOrEqual, Literal(val.clone()), Column(col.clone())),
Expr::le(Column(col.clone()), Literal(val.clone())),
Expr::le(Literal(val.clone()), Column(col.clone())),
];
return as_data_skipping_predicate(&Expr::variadic(And, exprs));
return as_data_skipping_predicate(&Expr::and_from(exprs));
}
NotEqual => {
let exprs = [
Expr::binary(
GreaterThan,
Column(format!("minValues.{}", col)),
Literal(val.clone()),
),
Expr::binary(
LessThan,
Column(format!("maxValues.{}", col)),
Literal(val.clone()),
),
Expr::gt(Column(format!("minValues.{}", col)), Literal(val.clone())),
Expr::lt(Column(format!("maxValues.{}", col)), Literal(val.clone())),
];
return Some(Expr::variadic(Or, exprs));
return Some(Expr::or_from(exprs));
}
_ => return None, // unsupported operation
};
Expand Down Expand Up @@ -260,123 +251,63 @@ mod tests {
let cases = [
(
column.clone().lt(lit_int.clone()),
Expr::binary(BinaryOperator::LessThan, min_col.clone(), lit_int.clone()),
Expr::lt(min_col.clone(), lit_int.clone()),
),
(
lit_int.clone().lt(column.clone()),
Expr::binary(
BinaryOperator::GreaterThan,
max_col.clone(),
lit_int.clone(),
),
Expr::gt(max_col.clone(), lit_int.clone()),
),
(
column.clone().gt(lit_int.clone()),
Expr::binary(
BinaryOperator::GreaterThan,
max_col.clone(),
lit_int.clone(),
),
Expr::gt(max_col.clone(), lit_int.clone()),
),
(
lit_int.clone().gt(column.clone()),
Expr::binary(BinaryOperator::LessThan, min_col.clone(), lit_int.clone()),
Expr::lt(min_col.clone(), lit_int.clone()),
),
(
column.clone().lt_eq(lit_int.clone()),
Expr::binary(
BinaryOperator::LessThanOrEqual,
min_col.clone(),
lit_int.clone(),
),
Expr::le(min_col.clone(), lit_int.clone()),
),
(
lit_int.clone().lt_eq(column.clone()),
Expr::binary(
BinaryOperator::GreaterThanOrEqual,
max_col.clone(),
lit_int.clone(),
),
Expr::ge(max_col.clone(), lit_int.clone()),
),
(
column.clone().gt_eq(lit_int.clone()),
Expr::binary(
BinaryOperator::GreaterThanOrEqual,
max_col.clone(),
lit_int.clone(),
),
Expr::ge(max_col.clone(), lit_int.clone()),
),
(
lit_int.clone().gt_eq(column.clone()),
Expr::binary(
BinaryOperator::LessThanOrEqual,
min_col.clone(),
lit_int.clone(),
),
Expr::le(min_col.clone(), lit_int.clone()),
),
(
column.clone().eq(lit_int.clone()),
Expr::variadic(
VariadicOperator::And,
[
Expr::binary(
BinaryOperator::LessThanOrEqual,
min_col.clone(),
lit_int.clone(),
),
Expr::binary(
BinaryOperator::GreaterThanOrEqual,
max_col.clone(),
lit_int.clone(),
),
],
),
Expr::and_from([
Expr::le(min_col.clone(), lit_int.clone()),
Expr::ge(max_col.clone(), lit_int.clone()),
]),
),
(
lit_int.clone().eq(column.clone()),
Expr::variadic(
VariadicOperator::And,
[
Expr::binary(
BinaryOperator::LessThanOrEqual,
min_col.clone(),
lit_int.clone(),
),
Expr::binary(
BinaryOperator::GreaterThanOrEqual,
max_col.clone(),
lit_int.clone(),
),
],
),
Expr::and_from([
Expr::le(min_col.clone(), lit_int.clone()),
Expr::ge(max_col.clone(), lit_int.clone()),
]),
),
(
column.clone().ne(lit_int.clone()),
Expr::variadic(
VariadicOperator::Or,
[
Expr::binary(
BinaryOperator::GreaterThan,
min_col.clone(),
lit_int.clone(),
),
Expr::binary(BinaryOperator::LessThan, max_col.clone(), lit_int.clone()),
],
),
Expr::or_from([
Expr::gt(min_col.clone(), lit_int.clone()),
Expr::lt(max_col.clone(), lit_int.clone()),
]),
),
(
lit_int.clone().ne(column.clone()),
Expr::variadic(
VariadicOperator::Or,
[
Expr::binary(
BinaryOperator::GreaterThan,
min_col.clone(),
lit_int.clone(),
),
Expr::binary(BinaryOperator::LessThan, max_col.clone(), lit_int.clone()),
],
),
Expr::or_from([
Expr::gt(min_col.clone(), lit_int.clone()),
Expr::lt(max_col.clone(), lit_int.clone()),
]),
),
];

Expand Down

0 comments on commit d09f814

Please sign in to comment.