Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass on IN/Not In #270

Merged
merged 29 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
64f7a52
First pass on IN/Not In
hntd187 Jun 29, 2024
75a6a3a
In/Not In support for String Arrays and more tests
hntd187 Jun 29, 2024
7cb6c45
chore: fmt
hntd187 Jun 29, 2024
5be298c
Update kernel/src/engine/arrow_expression.rs
hntd187 Jul 1, 2024
0ddcbdf
Merge branch 'main' into in-not-in
hntd187 Jul 1, 2024
4a715b2
Address PR feedback
hntd187 Jul 1, 2024
bbcbc46
Address PR feedback
hntd187 Jul 1, 2024
ae17592
Address PR feedback
hntd187 Jul 5, 2024
9de1edc
Merge branch 'main' into in-not-in
hntd187 Jul 20, 2024
d7ca9d8
Address PR feedback
hntd187 Jul 20, 2024
a03b4fe
Address PR feedback, use dangling pointer for init/empty array instea…
hntd187 Jul 25, 2024
36d7b9a
Merge branch 'main' into in-not-in
hntd187 Aug 2, 2024
9abd4bc
Merge branch 'main' into in-not-in
hntd187 Aug 3, 2024
ff28e09
Merge remote-tracking branch 'mine/in-not-in' into in-not-in
hntd187 Aug 4, 2024
0309334
Address PR feedback, as well as resolve some lints and nightly build …
hntd187 Aug 4, 2024
ad9ca08
Fix failing test in ffi
hntd187 Aug 4, 2024
e2eda49
Updated a test to remove arrow deps
hntd187 Aug 5, 2024
52bfc0b
Merge branch 'main' into in-not-in
hntd187 Aug 5, 2024
5c968d9
Merge remote-tracking branch 'origin/main' into in-not-in
hntd187 Aug 12, 2024
e1fd8f6
Merge branch 'main' into in-not-in
hntd187 Aug 12, 2024
b12f8d6
Merge remote-tracking branch 'mine/in-not-in' into in-not-in
hntd187 Aug 12, 2024
1c079d9
Added a guard and a test for when the right side column comparison fo…
hntd187 Aug 12, 2024
0866f35
Added a test for literal array comparisons
hntd187 Aug 12, 2024
112fdf7
chore: fmt
hntd187 Aug 12, 2024
ec07f07
added a test for invalid cases and tighten a match arm to not match o…
hntd187 Aug 12, 2024
e752ae3
added a test for invalid cases and tighten a match arm to not match o…
hntd187 Aug 12, 2024
677581a
Merge remote-tracking branch 'mine/in-not-in' into in-not-in
hntd187 Aug 12, 2024
444b1dc
chore:fmt
hntd187 Aug 12, 2024
a69a95b
Merge branch 'main' into in-not-in
hntd187 Aug 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ version = "0.3.0"
arrow = { version = "^52.0" }
arrow-arith = { version = "^52.0" }
arrow-array = { version = "^52.0" }
arrow-buffer = { version = "^52.0" }
arrow-cast = { version = "^52.0" }
arrow-data = { version = "^52.0" }
arrow-ord = { version = "^52.0" }
Expand Down
2 changes: 2 additions & 0 deletions ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ pub enum KernelError {
InvalidDecimalError,
InvalidStructDataError,
InternalError,
InvalidExpression,
}

impl From<Error> for KernelError {
Expand Down Expand Up @@ -372,6 +373,7 @@ impl From<Error> for KernelError {
source,
backtrace: _,
} => Self::from(*source),
Error::InvalidExpressionEvaluation(_) => KernelError::InvalidExpression,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ delta_kernel_derive = { path = "../derive-macros", version = "0.3.0" }
visibility = "0.1.0"

# Used in default engine
arrow-buffer = { workspace = true, optional = true }
arrow-array = { workspace = true, optional = true, features = ["chrono-tz"] }
arrow-select = { workspace = true, optional = true }
arrow-arith = { workspace = true, optional = true }
Expand Down Expand Up @@ -78,6 +79,7 @@ default-engine = [
"arrow-conversion",
"arrow-expression",
"arrow-array",
"arrow-buffer",
"arrow-cast",
"arrow-json",
"arrow-schema",
Expand Down
253 changes: 248 additions & 5 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,26 @@ use std::sync::Arc;
use arrow_arith::boolean::{and_kleene, is_null, not, or_kleene};
use arrow_arith::numeric::{add, div, mul, sub};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Datum, Decimal128Array, Float32Array,
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch,
StringArray, StructArray, TimestampMicrosecondArray,
};
use arrow_buffer::OffsetBuffer;
use arrow_ord::cmp::{distinct, eq, gt, gt_eq, lt, lt_eq, neq};
use arrow_ord::comparison::in_list_utf8;
use arrow_schema::{
ArrowError, DataType as ArrowDataType, Field as ArrowField, Fields, Schema as ArrowSchema,
ArrowError, DataType as ArrowDataType, Field as ArrowField, Fields, IntervalUnit,
Schema as ArrowSchema, TimeUnit,
};
use arrow_select::concat::concat;
use itertools::Itertools;

use super::arrow_conversion::LIST_ARRAY_ROOT;
use crate::engine::arrow_data::ArrowEngineData;
use crate::engine::arrow_utils::ensure_data_types;
use crate::engine::arrow_utils::prim_array_cmp;
use crate::error::{DeltaResult, Error};
use crate::expressions::{BinaryOperator, Expression, Scalar, UnaryOperator, VariadicOperator};
use crate::schema::{DataType, PrimitiveType, SchemaRef};
Expand Down Expand Up @@ -67,6 +73,21 @@ impl Scalar {
.try_collect()?;
Arc::new(StructArray::try_new(fields, arrays, None)?)
}
Array(data) => {
#[allow(deprecated)]
let values = data.array_elements();
let vecs: Vec<_> = values.iter().map(|v| v.to_array(num_rows)).try_collect()?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% confident I understand what this code is doing but:

It looks like we need to replicate an "array scalar" value into an arrow array suitable for comparing against a batch of primitive values, right? So e.g. if I had x in (1, 2, 3) then values would be [1, 2, 3] and we need to explode that into [[1, 2, 3], [1, 2, 3], ..., [1, 2, 3]] so that the eventual evaluate_expression call can invoke e.g. in_list? The documentation for that function is sorely incomplete, but I guess it's marching element by element through two arrays, producing a true output element each row whose list array element contains the corresponding primitive array element? That would be general enough to handle a correlated subquery, but would be quite space-inefficient for the common case (literal in-list or uncorrelated subquery) where we compare against the same array in every row -- especially in case said array is large.

Does arrow rust provide a "scalar" version of in_list that takes two primitive arrays instead? So that e.g. scalar_in_list([1, 2, 3, 4, 5, 6], [1, 3, 5]) returns [true, false, true, false, true, false]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are correct, go through and build the offsets. This approach was actually heavily inspired from how datafusion does this, so while I agree a very large list likely doesn't have great performance there isn't (to my current knowledge) a more idiomatic way to accomplish this. I generally dislike working with the list types in arrow, but to my knowledge arrow does not provide a scalar version of it. You would have to replicate the static value N number of times for each row. Primitive array, Generic string array on the left and list array on the right.

Copy link
Collaborator

@scovich scovich Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean we should really build+probe a hash table for all but the smallest in-lists, to avoid paying O(n**2) work. It would likely pay for itself for any query that has more than about 100 rows (and smaller queries would run so fast who cares).

If arrow doesn't give a way to do that maybe we need to do it ourselves?

Copy link
Collaborator

@scovich scovich Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: the hash table thing gets pretty clearly into engine territory, so if we can figure out how to offload the optimization to engine [data], that's probably better.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: the hash table thing gets pretty clearly into engine territory, so if we can figure out how to offload the optimization to engine [data], that's probably better.

We're already in engine territory here, since this is the expression evaluator. That said, I think we could make this better. See my comment below where we evaluate IN

let values: Vec<_> = vecs.iter().map(|x| x.as_ref()).collect();
let offsets: Vec<_> = vecs.iter().map(|v| v.len()).collect();
let offset_buffer = OffsetBuffer::from_lengths(offsets);
let field = ArrowField::try_from(data.array_type())?;
Arc::new(ListArray::new(
Arc::new(field),
offset_buffer,
concat(values.as_slice())?,
None,
))
}
Null(data_type) => match data_type {
DataType::Primitive(primitive) => match primitive {
PrimitiveType::Byte => Arc::new(Int8Array::new_null(num_rows)),
Expand Down Expand Up @@ -168,7 +189,6 @@ fn evaluate_expression(
) -> DeltaResult<ArrayRef> {
use BinaryOperator::*;
use Expression::*;

match (expression, result_type) {
(Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?),
(Column(name), _) => {
Expand Down Expand Up @@ -216,6 +236,88 @@ fn evaluate_expression(
UnaryOperator::IsNull => Arc::new(is_null(&arr)?),
})
}
(
BinaryOperation {
op: In,
left,
right,
},
_,
) => match (left.as_ref(), right.as_ref()) {
(Literal(_), Column(c)) => {
let list_type = batch.column_by_name(c).map(|c| c.data_type());
if !matches!(
list_type,
Some(ArrowDataType::List(_)) | Some(ArrowDataType::FixedSizeList(_, _))
) {
return Err(Error::InvalidExpressionEvaluation(format!(
"Right side column: {c} is not a list or a fixed size list"
)));
}
let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
if let Some(string_arr) = left_arr.as_string_opt::<i32>() {
return in_list_utf8(string_arr, right_arr.as_list::<i32>())
scovich marked this conversation as resolved.
Show resolved Hide resolved
.map(wrap_comparison_result)
.map_err(Error::generic_err);
}
prim_array_cmp! {
left_arr, right_arr,
(ArrowDataType::Int8, Int8Type),
(ArrowDataType::Int16, Int16Type),
(ArrowDataType::Int32, Int32Type),
(ArrowDataType::Int64, Int64Type),
(ArrowDataType::UInt8, UInt8Type),
(ArrowDataType::UInt16, UInt16Type),
(ArrowDataType::UInt32, UInt32Type),
(ArrowDataType::UInt64, UInt64Type),
(ArrowDataType::Float16, Float16Type),
(ArrowDataType::Float32, Float32Type),
(ArrowDataType::Float64, Float64Type),
(ArrowDataType::Timestamp(TimeUnit::Second, _), TimestampSecondType),
(ArrowDataType::Timestamp(TimeUnit::Millisecond, _), TimestampMillisecondType),
(ArrowDataType::Timestamp(TimeUnit::Microsecond, _), TimestampMicrosecondType),
(ArrowDataType::Timestamp(TimeUnit::Nanosecond, _), TimestampNanosecondType),
(ArrowDataType::Date32, Date32Type),
(ArrowDataType::Date64, Date64Type),
(ArrowDataType::Time32(TimeUnit::Second), Time32SecondType),
(ArrowDataType::Time32(TimeUnit::Millisecond), Time32MillisecondType),
(ArrowDataType::Time64(TimeUnit::Microsecond), Time64MicrosecondType),
(ArrowDataType::Time64(TimeUnit::Nanosecond), Time64NanosecondType),
(ArrowDataType::Duration(TimeUnit::Second), DurationSecondType),
(ArrowDataType::Duration(TimeUnit::Millisecond), DurationMillisecondType),
(ArrowDataType::Duration(TimeUnit::Microsecond), DurationMicrosecondType),
(ArrowDataType::Duration(TimeUnit::Nanosecond), DurationNanosecondType),
(ArrowDataType::Interval(IntervalUnit::DayTime), IntervalDayTimeType),
(ArrowDataType::Interval(IntervalUnit::YearMonth), IntervalYearMonthType),
(ArrowDataType::Interval(IntervalUnit::MonthDayNano), IntervalMonthDayNanoType),
(ArrowDataType::Decimal128(_, _), Decimal128Type),
(ArrowDataType::Decimal256(_, _), Decimal256Type)
}
}
(Literal(lit), Literal(Scalar::Array(ad))) => {
#[allow(deprecated)]
let exists = ad.array_elements().contains(lit);
Ok(Arc::new(BooleanArray::from(vec![exists])))
}
(l, r) => Err(Error::invalid_expression(format!(
"Invalid right value for (NOT) IN comparison, left is: {l} right is: {r}"
))),
},
(
BinaryOperation {
op: NotIn,
left,
right,
},
_,
) => {
let reverse_op = Expression::binary(In, *left.clone(), *right.clone());
let reverse_expr = evaluate_expression(&reverse_op, batch, None)?;
not(reverse_expr.as_boolean())
.map(wrap_comparison_result)
.map_err(Error::generic_err)
}
(BinaryOperation { op, left, right }, _) => {
let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
Expand All @@ -233,6 +335,7 @@ fn evaluate_expression(
Equal => |l, r| eq(l, r).map(wrap_comparison_result),
NotEqual => |l, r| neq(l, r).map(wrap_comparison_result),
Distinct => |l, r| distinct(l, r).map(wrap_comparison_result),
_ => return Err(Error::generic("Invalid expression given")),
};

eval(&left_arr, &right_arr).map_err(Error::generic_err)
Expand Down Expand Up @@ -321,11 +424,151 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {

#[cfg(test)]
mod tests {
use std::ops::{Add, Div, Mul, Sub};

use super::*;
use arrow_array::Int32Array;
use arrow_array::{GenericStringArray, Int32Array};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};
use std::ops::{Add, Div, Mul, Sub};

use super::*;
use crate::expressions::*;
use crate::schema::ArrayType;
use crate::DataType as DeltaDataTypes;

#[test]
fn test_array_column() {
    // Three list rows: [0,1,2], [3,4,5], [6,7,8]; we probe membership of the
    // literal 5, which appears only in the middle row.
    let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]);
    let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9]));
    let field = Arc::new(Field::new("item", DataType::Int32, true));
    let arr_field = Arc::new(Field::new("item", DataType::List(field.clone()), true));

    let schema = Schema::new(vec![arr_field.clone()]);

    let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

    let not_op = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal(5),
        Expression::column("item"),
    );

    // Bug fix: this previously built a second NotIn expression, so the IN code
    // path was never exercised; use In here and expect the negated result.
    let in_op = Expression::binary(
        BinaryOperator::In,
        Expression::literal(5),
        Expression::column("item"),
    );

    // NOT IN: 5 is absent from rows 0 and 2.
    let result = evaluate_expression(&not_op, &batch, None).unwrap();
    let expected = BooleanArray::from(vec![true, false, true]);
    assert_eq!(result.as_ref(), &expected);

    // IN: 5 is present only in row 1.
    let in_result = evaluate_expression(&in_op, &batch, None).unwrap();
    let in_expected = BooleanArray::from(vec![false, true, false]);
    assert_eq!(in_result.as_ref(), &in_expected);
}

#[test]
fn test_bad_right_type_array() {
    // A plain (non-list) Int32 column on the right side of (NOT) IN must be
    // rejected with a descriptive error.
    let column = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]);
    let item_field = Arc::new(Field::new("item", DataType::Int32, true));
    let schema = Arc::new(Schema::new(vec![item_field]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(column)]).unwrap();

    let expr = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal(5),
        Expression::column("item"),
    );

    let outcome = evaluate_expression(&expr, &batch, None);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "Invalid expression evaluation: Right side column: item is not a list or a fixed size list"
    )
}

#[test]
fn test_literal_type_array() {
    // Literal-vs-literal-array comparison: 5 is absent from [1, 2], so
    // NOT IN evaluates to a single true value, independent of the (empty) batch.
    let item = Arc::new(Field::new("item", DataType::Int32, true));
    let batch = RecordBatch::new_empty(Arc::new(Schema::new(vec![item])));

    let elements = vec![Scalar::Integer(1), Scalar::Integer(2)];
    let array_type = ArrayType::new(DeltaDataTypes::Primitive(PrimitiveType::Integer), false);
    let rhs = Expression::literal(Scalar::Array(ArrayData::new(array_type, elements)));
    let expr = Expression::binary(BinaryOperator::NotIn, Expression::literal(5), rhs);

    let actual = evaluate_expression(&expr, &batch, None).unwrap();
    assert_eq!(actual.as_ref(), &BooleanArray::from(vec![true]));
}

#[test]
fn test_invalid_array_sides() {
    // Column-vs-column (NOT) IN is unsupported: both sides being columns must
    // fall through to the invalid-expression error arm.
    let values = Int32Array::from((0..9).collect::<Vec<i32>>());
    let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9]));
    let item = Arc::new(Field::new("item", DataType::Int32, true));
    let list_field = Arc::new(Field::new("item", DataType::List(item.clone()), true));
    let schema = Arc::new(Schema::new(vec![list_field]));
    let list = ListArray::new(item, offsets, Arc::new(values), None);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(list)]).unwrap();

    let expr = Expression::binary(
        BinaryOperator::NotIn,
        Expression::column("item"),
        Expression::column("item"),
    );

    let outcome = evaluate_expression(&expr, &batch, None);
    assert!(outcome.is_err());
    assert_eq!(
        outcome.unwrap_err().to_string(),
        "Invalid expression evaluation: Invalid right value for (NOT) IN comparison, left is: Column(item) right is: Column(item)"
    )
}

#[test]
fn test_str_arrays() {
    // Three string-list rows; every row contains both "hi" and "bye", which
    // exercises the in_list_utf8 fast path for string literals.
    let strings = GenericStringArray::<i32>::from(vec![
        "hi", "bye", "hi", "hi", "bye", "bye", "hi", "bye", "hi",
    ]);
    let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9]));
    let item = Arc::new(Field::new("item", DataType::Utf8, true));
    let list_field = Arc::new(Field::new("item", DataType::List(item.clone()), true));
    let schema = Arc::new(Schema::new(vec![list_field]));
    let list = ListArray::new(item, offsets, Arc::new(strings), None);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(list)]).unwrap();

    let str_not_op = Expression::binary(
        BinaryOperator::NotIn,
        Expression::literal("bye"),
        Expression::column("item"),
    );

    let str_in_op = Expression::binary(
        BinaryOperator::In,
        Expression::literal("hi"),
        Expression::column("item"),
    );

    // "hi" appears in every row, so IN is true everywhere.
    let in_result = evaluate_expression(&str_in_op, &batch, None).unwrap();
    assert_eq!(
        in_result.as_ref(),
        &BooleanArray::from(vec![true, true, true])
    );

    // "bye" also appears in every row, so NOT IN is false everywhere.
    let not_result = evaluate_expression(&str_not_op, &batch, None).unwrap();
    assert_eq!(
        not_result.as_ref(),
        &BooleanArray::from(vec![false, false, false])
    );
}

#[test]
fn test_extract_column() {
Expand Down
34 changes: 34 additions & 0 deletions kernel/src/engine/arrow_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,39 @@ use itertools::Itertools;
use parquet::{arrow::ProjectionMask, schema::types::SchemaDescriptor};
use tracing::debug;

/// Dispatch a primitive-array IN-list comparison across a set of Arrow data types.
///
/// `$left_arr` is downcast to the primitive array type `$prim_ty` paired with its
/// matching `$data_ty` pattern, `$right_arr` is downcast to a `ListArray`, and
/// `arrow_ord::comparison::in_list` produces one boolean per row indicating whether
/// the left value is contained in the corresponding list element. Any left-hand
/// data type not covered by the supplied pairs yields a cast error.
///
/// NOTE: the macro expands to a `return` statement, so it must be invoked in tail
/// position of the enclosing function.
//
// Fix: removed GitHub review-UI residue ("hntd187 marked this conversation as
// resolved." / "Show resolved Hide resolved") that had been captured inside the
// macro body and made it invalid Rust.
macro_rules! prim_array_cmp {
    ( $left_arr: ident, $right_arr: ident, $(($data_ty: pat, $prim_ty: ty)),+ ) => {

        return match $left_arr.data_type() {
            $(
                $data_ty => {
                    let prim_array = $left_arr.as_primitive_opt::<$prim_ty>()
                        .ok_or(Error::invalid_expression(
                            format!("Cannot cast to primitive array: {}", $left_arr.data_type()))
                        )?;
                    let list_array = $right_arr.as_list_opt::<i32>()
                        .ok_or(Error::invalid_expression(
                            format!("Cannot cast to list array: {}", $right_arr.data_type()))
                        )?;
                    arrow_ord::comparison::in_list(prim_array, list_array).map(wrap_comparison_result)
                }
            )+
            _ => Err(ArrowError::CastError(
                    format!("Bad Comparison between: {:?} and {:?}",
                        $left_arr.data_type(),
                        $right_arr.data_type())
                )
            )
        }.map_err(Error::generic_err);
    };
}

pub(crate) use prim_array_cmp;

/// Get the indices in `parquet_schema` of the specified columns in `requested_schema`. This
/// returns a tuple of (mask_indices: Vec<parquet_schema_index>, reorder_indices:
/// Vec<requested_index>). `mask_indices` is used for generating the mask for reading from the

fn make_arrow_error(s: String) -> Error {
Error::Arrow(arrow_schema::ArrowError::InvalidArgumentError(s))
}
Expand Down Expand Up @@ -498,6 +531,7 @@ fn get_indices(
/// Get the indices in `parquet_schema` of the specified columns in `requested_schema`. This returns
/// a tuple of (mask_indices: Vec<parquet_schema_index>, reorder_indices:
/// Vec<requested_index>). `mask_indices` is used for generating the mask for reading from the

/// parquet file, and simply contains an entry for each index we wish to select from the parquet
/// file set to the index of the requested column in the parquet. `reorder_indices` is used for
/// re-ordering. See the documentation for [`ReorderIndex`] to understand what each element in the
Expand Down
Loading
Loading