Skip to content

Commit

Permalink
Merge pull request #270 from hntd187/in-not-in
Browse files Browse the repository at this point in the history
First pass on IN/Not In
  • Loading branch information
hntd187 authored Aug 19, 2024
2 parents e4c07dd + a69a95b commit 1459b30
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ version = "0.3.0"
arrow = { version = "^52.0" }
arrow-arith = { version = "^52.0" }
arrow-array = { version = "^52.0" }
arrow-buffer = { version = "^52.0" }
arrow-cast = { version = "^52.0" }
arrow-data = { version = "^52.0" }
arrow-ord = { version = "^52.0" }
Expand Down
2 changes: 2 additions & 0 deletions ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ pub enum KernelError {
InvalidDecimalError,
InvalidStructDataError,
InternalError,
InvalidExpression,
}

impl From<Error> for KernelError {
Expand Down Expand Up @@ -372,6 +373,7 @@ impl From<Error> for KernelError {
source,
backtrace: _,
} => Self::from(*source),
Error::InvalidExpressionEvaluation(_) => KernelError::InvalidExpression,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ delta_kernel_derive = { path = "../derive-macros", version = "0.3.0" }
visibility = "0.1.0"

# Used in default engine
arrow-buffer = { workspace = true, optional = true }
arrow-array = { workspace = true, optional = true, features = ["chrono-tz"] }
arrow-select = { workspace = true, optional = true }
arrow-arith = { workspace = true, optional = true }
Expand Down Expand Up @@ -78,6 +79,7 @@ default-engine = [
"arrow-conversion",
"arrow-expression",
"arrow-array",
"arrow-buffer",
"arrow-cast",
"arrow-json",
"arrow-schema",
Expand Down
253 changes: 248 additions & 5 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,26 @@ use std::sync::Arc;
use arrow_arith::boolean::{and_kleene, is_null, not, or_kleene};
use arrow_arith::numeric::{add, div, mul, sub};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Datum, Decimal128Array, Float32Array,
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch,
StringArray, StructArray, TimestampMicrosecondArray,
};
use arrow_buffer::OffsetBuffer;
use arrow_ord::cmp::{distinct, eq, gt, gt_eq, lt, lt_eq, neq};
use arrow_ord::comparison::in_list_utf8;
use arrow_schema::{
ArrowError, DataType as ArrowDataType, Field as ArrowField, Fields, Schema as ArrowSchema,
ArrowError, DataType as ArrowDataType, Field as ArrowField, Fields, IntervalUnit,
Schema as ArrowSchema, TimeUnit,
};
use arrow_select::concat::concat;
use itertools::Itertools;

use super::arrow_conversion::LIST_ARRAY_ROOT;
use crate::engine::arrow_data::ArrowEngineData;
use crate::engine::arrow_utils::ensure_data_types;
use crate::engine::arrow_utils::prim_array_cmp;
use crate::error::{DeltaResult, Error};
use crate::expressions::{BinaryOperator, Expression, Scalar, UnaryOperator, VariadicOperator};
use crate::schema::{DataType, PrimitiveType, SchemaRef};
Expand Down Expand Up @@ -67,6 +73,21 @@ impl Scalar {
.try_collect()?;
Arc::new(StructArray::try_new(fields, arrays, None)?)
}
Array(data) => {
#[allow(deprecated)]
let values = data.array_elements();
let vecs: Vec<_> = values.iter().map(|v| v.to_array(num_rows)).try_collect()?;
let values: Vec<_> = vecs.iter().map(|x| x.as_ref()).collect();
let offsets: Vec<_> = vecs.iter().map(|v| v.len()).collect();
let offset_buffer = OffsetBuffer::from_lengths(offsets);
let field = ArrowField::try_from(data.array_type())?;
Arc::new(ListArray::new(
Arc::new(field),
offset_buffer,
concat(values.as_slice())?,
None,
))
}
Null(data_type) => match data_type {
DataType::Primitive(primitive) => match primitive {
PrimitiveType::Byte => Arc::new(Int8Array::new_null(num_rows)),
Expand Down Expand Up @@ -168,7 +189,6 @@ fn evaluate_expression(
) -> DeltaResult<ArrayRef> {
use BinaryOperator::*;
use Expression::*;

match (expression, result_type) {
(Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?),
(Column(name), _) => {
Expand Down Expand Up @@ -216,6 +236,88 @@ fn evaluate_expression(
UnaryOperator::IsNull => Arc::new(is_null(&arr)?),
})
}
(
BinaryOperation {
op: In,
left,
right,
},
_,
) => match (left.as_ref(), right.as_ref()) {
(Literal(_), Column(c)) => {
let list_type = batch.column_by_name(c).map(|c| c.data_type());
if !matches!(
list_type,
Some(ArrowDataType::List(_)) | Some(ArrowDataType::FixedSizeList(_, _))
) {
return Err(Error::InvalidExpressionEvaluation(format!(
"Right side column: {c} is not a list or a fixed size list"
)));
}
let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
if let Some(string_arr) = left_arr.as_string_opt::<i32>() {
return in_list_utf8(string_arr, right_arr.as_list::<i32>())
.map(wrap_comparison_result)
.map_err(Error::generic_err);
}
prim_array_cmp! {
left_arr, right_arr,
(ArrowDataType::Int8, Int8Type),
(ArrowDataType::Int16, Int16Type),
(ArrowDataType::Int32, Int32Type),
(ArrowDataType::Int64, Int64Type),
(ArrowDataType::UInt8, UInt8Type),
(ArrowDataType::UInt16, UInt16Type),
(ArrowDataType::UInt32, UInt32Type),
(ArrowDataType::UInt64, UInt64Type),
(ArrowDataType::Float16, Float16Type),
(ArrowDataType::Float32, Float32Type),
(ArrowDataType::Float64, Float64Type),
(ArrowDataType::Timestamp(TimeUnit::Second, _), TimestampSecondType),
(ArrowDataType::Timestamp(TimeUnit::Millisecond, _), TimestampMillisecondType),
(ArrowDataType::Timestamp(TimeUnit::Microsecond, _), TimestampMicrosecondType),
(ArrowDataType::Timestamp(TimeUnit::Nanosecond, _), TimestampNanosecondType),
(ArrowDataType::Date32, Date32Type),
(ArrowDataType::Date64, Date64Type),
(ArrowDataType::Time32(TimeUnit::Second), Time32SecondType),
(ArrowDataType::Time32(TimeUnit::Millisecond), Time32MillisecondType),
(ArrowDataType::Time64(TimeUnit::Microsecond), Time64MicrosecondType),
(ArrowDataType::Time64(TimeUnit::Nanosecond), Time64NanosecondType),
(ArrowDataType::Duration(TimeUnit::Second), DurationSecondType),
(ArrowDataType::Duration(TimeUnit::Millisecond), DurationMillisecondType),
(ArrowDataType::Duration(TimeUnit::Microsecond), DurationMicrosecondType),
(ArrowDataType::Duration(TimeUnit::Nanosecond), DurationNanosecondType),
(ArrowDataType::Interval(IntervalUnit::DayTime), IntervalDayTimeType),
(ArrowDataType::Interval(IntervalUnit::YearMonth), IntervalYearMonthType),
(ArrowDataType::Interval(IntervalUnit::MonthDayNano), IntervalMonthDayNanoType),
(ArrowDataType::Decimal128(_, _), Decimal128Type),
(ArrowDataType::Decimal256(_, _), Decimal256Type)
}
}
(Literal(lit), Literal(Scalar::Array(ad))) => {
#[allow(deprecated)]
let exists = ad.array_elements().contains(lit);
Ok(Arc::new(BooleanArray::from(vec![exists])))
}
(l, r) => Err(Error::invalid_expression(format!(
"Invalid right value for (NOT) IN comparison, left is: {l} right is: {r}"
))),
},
(
BinaryOperation {
op: NotIn,
left,
right,
},
_,
) => {
let reverse_op = Expression::binary(In, *left.clone(), *right.clone());
let reverse_expr = evaluate_expression(&reverse_op, batch, None)?;
not(reverse_expr.as_boolean())
.map(wrap_comparison_result)
.map_err(Error::generic_err)
}
(BinaryOperation { op, left, right }, _) => {
let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
Expand All @@ -233,6 +335,7 @@ fn evaluate_expression(
Equal => |l, r| eq(l, r).map(wrap_comparison_result),
NotEqual => |l, r| neq(l, r).map(wrap_comparison_result),
Distinct => |l, r| distinct(l, r).map(wrap_comparison_result),
_ => return Err(Error::generic("Invalid expression given")),
};

eval(&left_arr, &right_arr).map_err(Error::generic_err)
Expand Down Expand Up @@ -321,11 +424,151 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {

#[cfg(test)]
mod tests {
use std::ops::{Add, Div, Mul, Sub};

use super::*;
use arrow_array::Int32Array;
use arrow_array::{GenericStringArray, Int32Array};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};
use std::ops::{Add, Div, Mul, Sub};

use super::*;
use crate::expressions::*;
use crate::schema::ArrayType;
use crate::DataType as DeltaDataTypes;

/// Evaluates both `IN` and `NOT IN` of an integer literal against a list column and checks
/// that the two results are row-wise complements of each other.
#[test]
fn test_array_column() {
    // The offsets split the nine values into three rows: [0, 1, 2], [3, 4, 5], [6, 7, 8].
    let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]);
    let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9]));
    let field = Arc::new(Field::new("item", DataType::Int32, true));
    let arr_field = Arc::new(Field::new("item", DataType::List(field.clone()), true));

    let schema = Schema::new(vec![arr_field]);

    let array = ListArray::new(field, offsets, Arc::new(values), None);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap();

    let not_op = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal(5),
        Expression::column("item"),
    );

    // This must be `In`: building it with `NotIn` would only repeat the check above and
    // leave the IN path untested.
    let in_op = Expression::binary(
        BinaryOperator::In,
        Expression::literal(5),
        Expression::column("item"),
    );

    // 5 appears only in the second row, so NOT IN is false only there.
    let result = evaluate_expression(&not_op, &batch, None).unwrap();
    let expected = BooleanArray::from(vec![true, false, true]);
    assert_eq!(result.as_ref(), &expected);

    // ...and IN is true only there.
    let in_result = evaluate_expression(&in_op, &batch, None).unwrap();
    let in_expected = BooleanArray::from(vec![false, true, false]);
    assert_eq!(in_result.as_ref(), &in_expected);
}

/// A `NOT IN` whose right-hand column is a plain `Int32` column (not a list) must be rejected
/// with an expression-evaluation error naming the offending column.
#[test]
fn test_bad_right_type_array() {
    let int_field = Arc::new(Field::new("item", DataType::Int32, true));
    let column = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]);
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![int_field])),
        vec![Arc::new(column)],
    )
    .unwrap();

    let not_in_expr = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal(5),
        Expression::column("item"),
    );

    let outcome = evaluate_expression(&not_in_expr, &batch, None);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "Invalid expression evaluation: Right side column: item is not a list or a fixed size list"
    );
}

/// `NOT IN` between an integer literal and a literal array that does not contain it evaluates
/// to a single `true`.
#[test]
fn test_literal_type_array() {
    let schema = Schema::new(vec![Arc::new(Field::new("item", DataType::Int32, true))]);
    let batch = RecordBatch::new_empty(Arc::new(schema));

    // Literal array [1, 2]; the probe value 5 is not among its elements.
    let haystack = Scalar::Array(ArrayData::new(
        ArrayType::new(DeltaDataTypes::Primitive(PrimitiveType::Integer), false),
        vec![Scalar::Integer(1), Scalar::Integer(2)],
    ));
    let not_in_expr = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal(5),
        Expression::literal(haystack),
    );

    let actual = evaluate_expression(&not_in_expr, &batch, None).unwrap();
    assert_eq!(actual.as_ref(), &BooleanArray::from(vec![true]));
}

/// `NOT IN` with a column on both sides is not a supported operand combination and must fail
/// with an error that reports both operands.
#[test]
fn test_invalid_array_sides() {
    let item_field = Arc::new(Field::new("item", DataType::Int32, true));
    let list_field = Arc::new(Field::new("item", DataType::List(item_field.clone()), true));

    // Three rows of three integers each.
    let list_column = ListArray::new(
        item_field,
        OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9])),
        Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8])),
        None,
    );
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![list_field])),
        vec![Arc::new(list_column)],
    )
    .unwrap();

    let column_vs_column = Expression::binary(
        BinaryOperator::NotIn,
        Expression::column("item"),
        Expression::column("item"),
    );

    let outcome = evaluate_expression(&column_vs_column, &batch, None);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "Invalid expression evaluation: Invalid right value for (NOT) IN comparison, left is: Column(item) right is: Column(item)"
    );
}

/// Exercises the string path of `IN` / `NOT IN` against a list-of-strings column in which
/// every row contains both "hi" and "bye".
#[test]
fn test_str_arrays() {
    let item_field = Arc::new(Field::new("item", DataType::Utf8, true));
    let list_field = Arc::new(Field::new("item", DataType::List(item_field.clone()), true));

    // Rows: ["hi", "bye", "hi"], ["hi", "bye", "bye"], ["hi", "bye", "hi"].
    let strings = GenericStringArray::<i32>::from(vec![
        "hi", "bye", "hi", "hi", "bye", "bye", "hi", "bye", "hi",
    ]);
    let row_offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9]));
    let list_column = ListArray::new(item_field, row_offsets, Arc::new(strings), None);
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![list_field])),
        vec![Arc::new(list_column)],
    )
    .unwrap();

    let bye_not_in = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal("bye"),
        Expression::column("item"),
    );
    let hi_in = Expression::binary(
        BinaryOperator::In,
        Expression::literal("hi"),
        Expression::column("item"),
    );

    // "hi" is present in every row.
    let in_actual = evaluate_expression(&hi_in, &batch, None).unwrap();
    assert_eq!(
        in_actual.as_ref(),
        &BooleanArray::from(vec![true, true, true])
    );

    // "bye" is present in every row too, so NOT IN is false everywhere.
    let not_in_actual = evaluate_expression(&bye_not_in, &batch, None).unwrap();
    assert_eq!(
        not_in_actual.as_ref(),
        &BooleanArray::from(vec![false, false, false])
    );
}

#[test]
fn test_extract_column() {
Expand Down
34 changes: 34 additions & 0 deletions kernel/src/engine/arrow_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,39 @@ use itertools::Itertools;
use parquet::{arrow::ProjectionMask, schema::types::SchemaDescriptor};
use tracing::debug;

/// Expands to a `return` statement that evaluates "is each value of `$left_arr` contained in
/// the corresponding list of `$right_arr`" for primitive element types.
///
/// Arguments:
/// * `$left_arr`, `$right_arr` - identifiers of the already-evaluated left and right arrays.
/// * a comma-separated list of `(pattern, primitive type)` pairs. The data type of
///   `$left_arr` is matched against each pattern in order; on the first match the left array
///   is downcast to a primitive array of the paired type, the right array is downcast to a
///   list array with `i32` offsets, and both are passed to `arrow_ord::comparison::in_list`.
///
/// Behavior of the expansion:
/// * It always returns from the enclosing function, so no statement placed after the macro
///   invocation is reachable.
/// * A failed downcast returns early with `Error::invalid_expression`.
/// * A left data type matching none of the patterns produces an `ArrowError::CastError`,
///   which, like any error from `in_list`, is converted through `Error::generic_err`.
///
/// The expansion refers to `Error`, `ArrowError` and `wrap_comparison_result` by their bare
/// names and calls `as_primitive_opt` / `as_list_opt` on the arrays, so all of these must be
/// in scope at the call site.
macro_rules! prim_array_cmp {
    ( $left_arr: ident, $right_arr: ident, $(($data_ty: pat, $prim_ty: ty)),+ ) => {

        return match $left_arr.data_type() {
        $(
            $data_ty => {
                // `ok_or_else` keeps the error message from being formatted (and allocated)
                // on the success path.
                let prim_array = $left_arr.as_primitive_opt::<$prim_ty>()
                    .ok_or_else(|| Error::invalid_expression(
                        format!("Cannot cast to primitive array: {}", $left_arr.data_type()))
                    )?;
                let list_array = $right_arr.as_list_opt::<i32>()
                    .ok_or_else(|| Error::invalid_expression(
                        format!("Cannot cast to list array: {}", $right_arr.data_type()))
                    )?;
                arrow_ord::comparison::in_list(prim_array, list_array).map(wrap_comparison_result)
            }
        )+
            _ => Err(ArrowError::CastError(
                        format!("Bad Comparison between: {:?} and {:?}",
                            $left_arr.data_type(),
                            $right_arr.data_type())
                        )
                )
        }.map_err(Error::generic_err);
    };
}

pub(crate) use prim_array_cmp;

/// Wraps the message `s` in an [`arrow_schema::ArrowError::InvalidArgumentError`] and returns
/// it as an [`Error::Arrow`].
fn make_arrow_error(s: String) -> Error {
    Error::Arrow(arrow_schema::ArrowError::InvalidArgumentError(s))
}
Expand Down Expand Up @@ -498,6 +531,7 @@ fn get_indices(
/// Get the indices in `parquet_schema` of the specified columns in `requested_schema`. This returns
/// a tuple of (mask_indices: Vec<parquet_schema_index>, reorder_indices:
/// Vec<requested_index>). `mask_indices` is used for generating the mask for reading from the
/// parquet file, and simply contains an entry for each index we wish to select from the parquet
/// file set to the index of the requested column in the parquet. `reorder_indices` is used for
/// re-ordering. See the documentation for [`ReorderIndex`] to understand what each element in the
Expand Down
Loading

0 comments on commit 1459b30

Please sign in to comment.