-
Couldn't load subscription status.
- Fork 118
First pass on IN/Not In #270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
64f7a52
75a6a3a
7cb6c45
5be298c
0ddcbdf
4a715b2
bbcbc46
ae17592
9de1edc
d7ca9d8
a03b4fe
36d7b9a
9abd4bc
ff28e09
0309334
ad9ca08
e2eda49
52bfc0b
5c968d9
e1fd8f6
b12f8d6
1c079d9
0866f35
112fdf7
ec07f07
e752ae3
677581a
444b1dc
a69a95b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,25 +4,32 @@ use std::sync::Arc; | |
| use arrow_arith::boolean::{and_kleene, is_null, not, or_kleene}; | ||
| use arrow_arith::numeric::{add, div, mul, sub}; | ||
| use arrow_array::cast::AsArray; | ||
| use arrow_array::types::*; | ||
| use arrow_array::{ | ||
| Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Datum, Decimal128Array, Float32Array, | ||
| Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, | ||
| StringArray, StructArray, TimestampMicrosecondArray, | ||
| }; | ||
| use arrow_buffer::OffsetBuffer; | ||
| use arrow_ord::cmp::{distinct, eq, gt, gt_eq, lt, lt_eq, neq}; | ||
| use arrow_ord::comparison::{in_list, in_list_utf8}; | ||
| use arrow_schema::{ | ||
| ArrowError, DataType as ArrowDataType, Field as ArrowField, Fields, Schema as ArrowSchema, | ||
| ArrowError, DataType as ArrowDataType, Field as ArrowField, Fields, IntervalUnit, | ||
| Schema as ArrowSchema, TimeUnit, | ||
| }; | ||
| use arrow_select::concat::concat; | ||
| use itertools::Itertools; | ||
|
|
||
| use super::arrow_conversion::LIST_ARRAY_ROOT; | ||
| use crate::engine::arrow_data::ArrowEngineData; | ||
| use crate::engine::arrow_utils::prim_array_cmp; | ||
| use crate::engine::arrow_utils::ensure_data_types; | ||
| use crate::error::{DeltaResult, Error}; | ||
| use crate::expressions::{BinaryOperator, Expression, Scalar, UnaryOperator, VariadicOperator}; | ||
| use crate::schema::{DataType, PrimitiveType, SchemaRef}; | ||
| use crate::{EngineData, ExpressionEvaluator, ExpressionHandler}; | ||
|
|
||
| use super::arrow_conversion::LIST_ARRAY_ROOT; | ||
|
|
||
| // TODO leverage scalars / Datum | ||
|
|
||
| fn downcast_to_bool(arr: &dyn Array) -> DeltaResult<&BooleanArray> { | ||
|
|
@@ -67,6 +74,20 @@ impl Scalar { | |
| .try_collect()?; | ||
| Arc::new(StructArray::try_new(fields, arrays, None)?) | ||
| } | ||
| Array(data) => { | ||
| let values = data.array_elements(); | ||
| let vecs: Vec<_> = values.iter().map(|v| v.to_array(num_rows)).try_collect()?; | ||
| let values: Vec<_> = vecs.iter().map(|x| x.as_ref()).collect(); | ||
| let offsets: Vec<_> = vecs.iter().map(|v| v.len()).collect(); | ||
| let offset_buffer = OffsetBuffer::from_lengths(offsets); | ||
| let field = ArrowField::try_from(data.array_type())?; | ||
| Arc::new(ListArray::new( | ||
| Arc::new(field), | ||
| offset_buffer, | ||
| concat(values.as_slice())?, | ||
| None, | ||
| )) | ||
| } | ||
| Null(data_type) => match data_type { | ||
| DataType::Primitive(primitive) => match primitive { | ||
| PrimitiveType::Byte => Arc::new(Int8Array::new_null(num_rows)), | ||
|
|
@@ -168,7 +189,6 @@ fn evaluate_expression( | |
| ) -> DeltaResult<ArrayRef> { | ||
| use BinaryOperator::*; | ||
| use Expression::*; | ||
|
|
||
| match (expression, result_type) { | ||
| (Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?), | ||
| (Column(name), _) => { | ||
|
|
@@ -216,6 +236,69 @@ fn evaluate_expression( | |
| UnaryOperator::IsNull => Arc::new(is_null(&arr)?), | ||
| }) | ||
| } | ||
| ( | ||
| BinaryOperation { | ||
| op: In, | ||
| left, | ||
| right, | ||
| }, | ||
| _, | ||
| ) => { | ||
| let left_arr = evaluate_expression(left.as_ref(), batch, None)?; | ||
|
||
| let right_arr = evaluate_expression(right.as_ref(), batch, None)?; | ||
|
||
| if let Some(string_arr) = left_arr.as_string_opt::<i32>() { | ||
| return in_list_utf8(string_arr, right_arr.as_list::<i32>()) | ||
| .map(wrap_comparison_result) | ||
| .map_err(Error::generic_err); | ||
| } | ||
| prim_array_cmp! { | ||
| left_arr, right_arr, | ||
| (ArrowDataType::Int8, Int8Type), | ||
|
||
| (ArrowDataType::Int16, Int16Type), | ||
| (ArrowDataType::Int32, Int32Type), | ||
| (ArrowDataType::Int64, Int64Type), | ||
| (ArrowDataType::UInt8, UInt8Type), | ||
| (ArrowDataType::UInt16, UInt16Type), | ||
| (ArrowDataType::UInt32, UInt32Type), | ||
| (ArrowDataType::UInt64, UInt64Type), | ||
| (ArrowDataType::Float16, Float16Type), | ||
| (ArrowDataType::Float32, Float32Type), | ||
| (ArrowDataType::Float64, Float64Type), | ||
| (ArrowDataType::Timestamp(TimeUnit::Second, _), TimestampSecondType), | ||
| (ArrowDataType::Timestamp(TimeUnit::Millisecond, _), TimestampMillisecondType), | ||
| (ArrowDataType::Timestamp(TimeUnit::Microsecond, _), TimestampMicrosecondType), | ||
| (ArrowDataType::Timestamp(TimeUnit::Nanosecond, _), TimestampNanosecondType), | ||
| (ArrowDataType::Date32, Date32Type), | ||
| (ArrowDataType::Date64, Date64Type), | ||
| (ArrowDataType::Time32(TimeUnit::Second), Time32SecondType), | ||
| (ArrowDataType::Time32(TimeUnit::Millisecond), Time32MillisecondType), | ||
| (ArrowDataType::Time64(TimeUnit::Microsecond), Time64MicrosecondType), | ||
| (ArrowDataType::Time64(TimeUnit::Nanosecond), Time64NanosecondType), | ||
| (ArrowDataType::Duration(TimeUnit::Second), DurationSecondType), | ||
| (ArrowDataType::Duration(TimeUnit::Millisecond), DurationMillisecondType), | ||
| (ArrowDataType::Duration(TimeUnit::Microsecond), DurationMicrosecondType), | ||
| (ArrowDataType::Duration(TimeUnit::Nanosecond), DurationNanosecondType), | ||
| (ArrowDataType::Interval(IntervalUnit::DayTime), IntervalDayTimeType), | ||
| (ArrowDataType::Interval(IntervalUnit::YearMonth), IntervalYearMonthType), | ||
| (ArrowDataType::Interval(IntervalUnit::MonthDayNano), IntervalMonthDayNanoType), | ||
| (ArrowDataType::Decimal128(_, _), Decimal128Type), | ||
| (ArrowDataType::Decimal256(_, _), Decimal256Type) | ||
| } | ||
| } | ||
| ( | ||
| BinaryOperation { | ||
| op: NotIn, | ||
| left, | ||
| right, | ||
| }, | ||
| _, | ||
| ) => { | ||
| let reverse_op = Expression::binary(In, *left.clone(), *right.clone()); | ||
| let reverse_expr = evaluate_expression(&reverse_op, batch, None)?; | ||
| not(reverse_expr.as_boolean()) | ||
| .map(wrap_comparison_result) | ||
| .map_err(Error::generic_err) | ||
| } | ||
| (BinaryOperation { op, left, right }, _) => { | ||
| let left_arr = evaluate_expression(left.as_ref(), batch, None)?; | ||
| let right_arr = evaluate_expression(right.as_ref(), batch, None)?; | ||
|
|
@@ -233,6 +316,7 @@ fn evaluate_expression( | |
| Equal => |l, r| eq(l, r).map(wrap_comparison_result), | ||
| NotEqual => |l, r| neq(l, r).map(wrap_comparison_result), | ||
| Distinct => |l, r| distinct(l, r).map(wrap_comparison_result), | ||
| _ => return Err(Error::generic("Invalid expression given")), | ||
| }; | ||
|
|
||
| eval(&left_arr, &right_arr).map_err(Error::generic_err) | ||
|
|
@@ -321,11 +405,81 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator { | |
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::ops::{Add, Div, Mul, Sub}; | ||
|
|
||
| use super::*; | ||
| use arrow_array::Int32Array; | ||
| use arrow_array::{GenericStringArray, Int32Array}; | ||
| use arrow_buffer::ScalarBuffer; | ||
| use arrow_schema::{DataType, Field, Fields, Schema}; | ||
| use std::ops::{Add, Div, Mul, Sub}; | ||
|
|
||
| use crate::expressions::*; | ||
|
|
||
| use super::*; | ||
|
|
||
| #[test] | ||
| fn test_array_column() { | ||
| let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]); | ||
| let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9])); | ||
| let field = Arc::new(Field::new("item", DataType::Int32, true)); | ||
| let arr_field = Arc::new(Field::new("item", DataType::List(field.clone()), true)); | ||
|
|
||
| let schema = Schema::new(vec![arr_field.clone()]); | ||
|
|
||
| let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); | ||
| let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); | ||
|
|
||
| let not_op = Expression::binary( | ||
| BinaryOperator::NotIn, | ||
| Expression::literal(5), | ||
| Expression::column("item"), | ||
| ); | ||
|
|
||
| let in_op = Expression::binary( | ||
| BinaryOperator::NotIn, | ||
| Expression::literal(5), | ||
| Expression::column("item"), | ||
| ); | ||
|
|
||
| let result = evaluate_expression(&not_op, &batch, None).unwrap(); | ||
| let expected = BooleanArray::from(vec![true, false, true]); | ||
| assert_eq!(result.as_ref(), &expected); | ||
|
|
||
| let in_result = evaluate_expression(&in_op, &batch, None).unwrap(); | ||
| let in_expected = BooleanArray::from(vec![true, false, true]); | ||
| assert_eq!(in_result.as_ref(), &in_expected); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_str_arrays() { | ||
| let values = GenericStringArray::<i32>::from(vec![ | ||
| "hi", "bye", "hi", "hi", "bye", "bye", "hi", "bye", "hi", | ||
| ]); | ||
| let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 9])); | ||
| let field = Arc::new(Field::new("item", DataType::Utf8, true)); | ||
| let arr_field = Arc::new(Field::new("item", DataType::List(field.clone()), true)); | ||
| let schema = Schema::new(vec![arr_field.clone()]); | ||
| let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); | ||
| let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); | ||
|
|
||
| let str_not_op = Expression::binary( | ||
| BinaryOperator::NotIn, | ||
| Expression::literal("bye"), | ||
| Expression::column("item"), | ||
| ); | ||
|
|
||
| let str_in_op = Expression::binary( | ||
| BinaryOperator::In, | ||
| Expression::literal("hi"), | ||
| Expression::column("item"), | ||
| ); | ||
|
|
||
| let result = evaluate_expression(&str_in_op, &batch, None).unwrap(); | ||
| let expected = BooleanArray::from(vec![true, true, true]); | ||
| assert_eq!(result.as_ref(), &expected); | ||
|
|
||
| let in_result = evaluate_expression(&str_not_op, &batch, None).unwrap(); | ||
| let in_expected = BooleanArray::from(vec![false, false, false]); | ||
| assert_eq!(in_result.as_ref(), &in_expected); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_extract_column() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not 100% confident I understand what this code is doing but:
It looks like we need to replicate an "array scalar" value into an arrow array suitable for comparing against a batch of primitive values, right? So e.g. if I had
`x in (1, 2, 3)` then `values` would be `[1, 2, 3]` and we need to explode that into `[[1, 2, 3], [1, 2, 3], ..., [1, 2, 3]]` so that the eventual `evaluate_expression` call can invoke e.g. `in_list`? The documentation for that function is sorely incomplete, but I guess it's marching element by element through two arrays, producing a true output element for each row whose list array element contains the corresponding primitive array element? That would be general enough to handle a correlated subquery, but would be quite space-inefficient for the common case (literal in-list or uncorrelated subquery) where we compare against the same array in every row -- especially in case said array is large.
Does arrow rust provide a "scalar" version of `in_list` that takes two primitive arrays instead? So that e.g. `scalar_in_list([1, 2, 3, 4, 5, 6], [1, 3, 5])` returns `[true, false, true, false, true, false]`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are correct: we go through and build the offsets. This approach was actually heavily inspired by how datafusion does this, so while I agree a very large list likely doesn't have great performance, there isn't (to my current knowledge) a more idiomatic way to accomplish this. I generally dislike working with the list types in arrow, but to my knowledge arrow does not provide a scalar version of it. You would have to replicate the static value N times, once for each row: a primitive array or generic string array on the left and a list array on the right.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean we should really build+probe a hash table for all but the smallest in-lists, to avoid paying `O(n**2)` work. It would likely pay for itself for any query that has more than about 100 rows (and smaller queries would run so fast who cares).
If arrow doesn't give a way to do that maybe we need to do it ourselves?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update: the hash table thing gets pretty clearly into engine territory, so if we can figure out how to offload the optimization to engine [data], that's probably better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're already in engine territory here, since this is the expression evaluator. That said, I think we could make this better. See my comment below where we evaluate
IN