Skip to content

Commit a3e618e

Browse files
nanderstabelTao Wu
and
Tao Wu
authored
feat(expr): add array access support (#2883)
* Implement array access support for arrays * Add comment * Add array_access test * Add value_at method * Delete ArrayAccessExpression * Add array access * Add array_access * Expand value_at and add test * Make array_access a BinaryNullableExpression * Add test with negative index * Refactor array_access * Rewrite test for array access as nullable binary expression * Fix return_type * Clean * Clean * Add nested test and cargo check * Add nl * Fix index out of bounds error * Add test * Add ArrayAccess * Update src/expr/src/expr/expr_binary_nullable.rs returning DataType::List is not implemented yet. Co-authored-by: Tao Wu <[email protected]> * Add comma * Fix test * Delete unnecessary 'vec' * Fix cargo.lock file Co-authored-by: Tao Wu <[email protected]>
1 parent cab50f1 commit a3e618e

File tree

10 files changed

+320
-5
lines changed

10 files changed

+320
-5
lines changed

e2e_test/batch/array_access.slt

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
statement ok
2+
SET RW_IMPLICIT_FLUSH TO true;
3+
4+
query T
5+
select (ARRAY['foo', 'bar'])[-1];
6+
----
7+
NULL
8+
9+
query T
10+
select (ARRAY['foo', 'bar'])[0];
11+
----
12+
NULL
13+
14+
query T
15+
select (ARRAY['foo', 'bar'])[1];
16+
----
17+
foo
18+
19+
query T
20+
select (ARRAY['foo', 'bar'])[3];
21+
----
22+
NULL
23+
24+
statement error
25+
select (ARRAY['foo', 'bar'])[];

src/common/src/array/list_array.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ impl ListArray {
239239
Ok(arr.into())
240240
}
241241

242-
#[cfg(test)]
242+
// Used for testing purposes
243243
pub fn from_slices(
244244
null_bitmap: &[bool],
245245
values: Vec<Option<ArrayImpl>>,
@@ -352,6 +352,25 @@ impl<'a> ListRef<'a> {
352352
ListRef::ValueRef { val } => val.values.iter().map(to_datum_ref).collect(),
353353
}
354354
}
355+
356+
pub fn value_at(&self, index: usize) -> Result<DatumRef<'a>> {
357+
match self {
358+
ListRef::Indexed { arr, .. } => {
359+
if index <= arr.value.len() {
360+
Ok(arr.value.value_at(index - 1))
361+
} else {
362+
Ok(None)
363+
}
364+
}
365+
ListRef::ValueRef { val } => {
366+
if let Some(datum) = val.values().iter().nth(index - 1) {
367+
Ok(to_datum_ref(datum))
368+
} else {
369+
Ok(None)
370+
}
371+
}
372+
}
373+
}
355374
}
356375

357376
impl Hash for ListRef<'_> {

src/expr/src/expr/build_expr_from_prost.rs

+71
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,84 @@ pub fn build_to_char_expr(prost: &ExprNode) -> Result<BoxedExpression> {
239239
mod tests {
240240
use std::vec;
241241

242+
use risingwave_common::array::{ArrayImpl, Utf8Array};
242243
use risingwave_pb::data::data_type::TypeName;
243244
use risingwave_pb::data::DataType as ProstDataType;
244245
use risingwave_pb::expr::expr_node::{RexNode, Type};
245246
use risingwave_pb::expr::{ConstantValue, ExprNode, FunctionCall, InputRefExpr};
246247

247248
use super::*;
248249

250+
#[test]
251+
fn test_array_access_expr() {
252+
let values = FunctionCall {
253+
children: vec![
254+
ExprNode {
255+
expr_type: Type::ConstantValue as i32,
256+
return_type: Some(ProstDataType {
257+
type_name: TypeName::Varchar as i32,
258+
..Default::default()
259+
}),
260+
rex_node: Some(RexNode::Constant(ConstantValue {
261+
body: "foo".as_bytes().to_vec(),
262+
})),
263+
},
264+
ExprNode {
265+
expr_type: Type::ConstantValue as i32,
266+
return_type: Some(ProstDataType {
267+
type_name: TypeName::Varchar as i32,
268+
..Default::default()
269+
}),
270+
rex_node: Some(RexNode::Constant(ConstantValue {
271+
body: "bar".as_bytes().to_vec(),
272+
})),
273+
},
274+
],
275+
};
276+
let array_index = FunctionCall {
277+
children: vec![
278+
ExprNode {
279+
expr_type: Type::Array as i32,
280+
return_type: Some(ProstDataType {
281+
type_name: TypeName::List as i32,
282+
field_type: vec![ProstDataType {
283+
type_name: TypeName::Varchar as i32,
284+
..Default::default()
285+
}],
286+
..Default::default()
287+
}),
288+
rex_node: Some(RexNode::FuncCall(values)),
289+
},
290+
ExprNode {
291+
expr_type: Type::ConstantValue as i32,
292+
return_type: Some(ProstDataType {
293+
type_name: TypeName::Int32 as i32,
294+
..Default::default()
295+
}),
296+
rex_node: Some(RexNode::Constant(ConstantValue {
297+
body: vec![0, 0, 0, 1],
298+
})),
299+
},
300+
],
301+
};
302+
let access = ExprNode {
303+
expr_type: Type::ArrayAccess as i32,
304+
return_type: Some(ProstDataType {
305+
type_name: TypeName::Varchar as i32,
306+
..Default::default()
307+
}),
308+
rex_node: Some(RexNode::FuncCall(array_index)),
309+
};
310+
let expr = build_nullable_binary_expr_prost(&access);
311+
assert!(expr.is_ok());
312+
313+
let res = expr.unwrap().eval(&DataChunk::new_dummy(1)).unwrap();
314+
assert_eq!(
315+
*res,
316+
ArrayImpl::Utf8(Utf8Array::from_slice(&[Some("foo")]).unwrap())
317+
);
318+
}
319+
249320
#[test]
250321
fn test_build_in_expr() {
251322
let input_ref = InputRefExpr { column_idx: 0 };

src/expr/src/expr/expr_binary_nullable.rs

+40-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
//! For expression that only accept two nullable arguments as input.
1616
17-
use risingwave_common::array::{Array, BoolArray, Utf8Array};
17+
use risingwave_common::array::*;
1818
use risingwave_common::types::DataType;
1919
use risingwave_pb::expr::expr_node::Type;
2020

2121
use super::BoxedExpression;
2222
use crate::expr::template::BinaryNullableExpression;
2323
use crate::for_all_cmp_variants;
24+
use crate::vector_op::array_access::array_access;
2425
use crate::vector_op::cmp::{general_is_distinct_from, str_is_distinct_from};
2526
use crate::vector_op::conjunction::{and, or};
2627

@@ -62,6 +63,7 @@ pub fn new_nullable_binary_expr(
6263
r: BoxedExpression,
6364
) -> BoxedExpression {
6465
match expr_type {
66+
Type::ArrayAccess => build_array_access_expr(ret, l, r),
6567
Type::And => Box::new(
6668
BinaryNullableExpression::<BoolArray, BoolArray, BoolArray, _>::new(l, r, ret, and),
6769
),
@@ -78,6 +80,43 @@ pub fn new_nullable_binary_expr(
7880
}
7981
}
8082

83+
fn build_array_access_expr(
84+
ret: DataType,
85+
l: BoxedExpression,
86+
r: BoxedExpression,
87+
) -> BoxedExpression {
88+
macro_rules! array_access_expression {
89+
($array:ty) => {
90+
Box::new(
91+
BinaryNullableExpression::<ListArray, I32Array, $array, _>::new(
92+
l,
93+
r,
94+
ret,
95+
array_access,
96+
),
97+
)
98+
};
99+
}
100+
101+
match ret {
102+
DataType::Boolean => array_access_expression!(BoolArray),
103+
DataType::Int16 => array_access_expression!(I16Array),
104+
DataType::Int32 => array_access_expression!(I32Array),
105+
DataType::Int64 => array_access_expression!(I64Array),
106+
DataType::Float32 => array_access_expression!(F32Array),
107+
DataType::Float64 => array_access_expression!(F64Array),
108+
DataType::Decimal => array_access_expression!(DecimalArray),
109+
DataType::Date => array_access_expression!(NaiveDateArray),
110+
DataType::Varchar => array_access_expression!(Utf8Array),
111+
DataType::Time => array_access_expression!(NaiveTimeArray),
112+
DataType::Timestamp => array_access_expression!(NaiveDateTimeArray),
113+
DataType::Timestampz => array_access_expression!(PrimitiveArray::<i64>),
114+
DataType::Interval => array_access_expression!(IntervalArray),
115+
DataType::Struct { .. } => array_access_expression!(StructArray),
116+
DataType::List { .. } => array_access_expression!(ListArray),
117+
}
118+
}
119+
81120
pub fn new_distinct_from_expr(
82121
l: BoxedExpression,
83122
r: BoxedExpression,

src/expr/src/expr/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
8484
Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual | Add
8585
| Subtract | Multiply | Divide | Modulus | Extract | RoundDigit | TumbleStart
8686
| Position => build_binary_expr_prost(prost),
87-
And | Or | IsDistinctFrom => build_nullable_binary_expr_prost(prost),
87+
And | Or | IsDistinctFrom | ArrayAccess => build_nullable_binary_expr_prost(prost),
8888
ToChar => build_to_char_expr(prost),
8989
Coalesce => CoalesceExpression::try_from(prost).map(Expression::boxed),
9090
Substr => build_substr_expr(prost),
+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Copyright 2022 Singularity Data
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use risingwave_common::array::ListRef;
16+
use risingwave_common::error::Result;
17+
use risingwave_common::types::{Scalar, ToOwnedDatum};
18+
19+
#[inline(always)]
20+
pub fn array_access<T: Scalar>(l: Option<ListRef>, r: Option<i32>) -> Result<Option<T>> {
21+
match (l, r) {
22+
// index must be greater than 0 following a one-based numbering convention for arrays
23+
(Some(list), Some(index)) if index > 0 => {
24+
let datumref = list.value_at(index as usize)?;
25+
if let Some(scalar) = datumref.to_owned_datum() {
26+
Ok(Some(scalar.try_into()?))
27+
} else {
28+
Ok(None)
29+
}
30+
}
31+
_ => Ok(None),
32+
}
33+
}
34+
35+
#[cfg(test)]
36+
mod tests {
37+
38+
use risingwave_common::array::ListValue;
39+
use risingwave_common::types::ScalarImpl;
40+
41+
use super::*;
42+
43+
#[test]
44+
fn test_int32_array_access() {
45+
let v1 = ListValue::new(vec![
46+
Some(ScalarImpl::Int32(1)),
47+
Some(ScalarImpl::Int32(2)),
48+
Some(ScalarImpl::Int32(3)),
49+
]);
50+
let l1 = ListRef::ValueRef { val: &v1 };
51+
52+
assert_eq!(array_access::<i32>(Some(l1), Some(1)), Ok(Some(1)));
53+
assert_eq!(array_access::<i32>(Some(l1), Some(-1)), Ok(None));
54+
assert_eq!(array_access::<i32>(Some(l1), Some(0)), Ok(None));
55+
assert_eq!(array_access::<i32>(Some(l1), Some(4)), Ok(None));
56+
}
57+
58+
#[test]
59+
fn test_utf8_array_access() {
60+
let v1 = ListValue::new(vec![
61+
Some(ScalarImpl::Utf8("来自".into())),
62+
Some(ScalarImpl::Utf8("foo".into())),
63+
Some(ScalarImpl::Utf8("bar".into())),
64+
]);
65+
let v2 = ListValue::new(vec![
66+
Some(ScalarImpl::Utf8("fizz".into())),
67+
Some(ScalarImpl::Utf8("荷兰".into())),
68+
Some(ScalarImpl::Utf8("buzz".into())),
69+
]);
70+
let v3 = ListValue::new(vec![None, None, Some(ScalarImpl::Utf8("的爱".into()))]);
71+
72+
let l1 = ListRef::ValueRef { val: &v1 };
73+
let l2 = ListRef::ValueRef { val: &v2 };
74+
let l3 = ListRef::ValueRef { val: &v3 };
75+
76+
assert_eq!(
77+
array_access::<String>(Some(l1), Some(1)),
78+
Ok(Some("来自".into()))
79+
);
80+
assert_eq!(
81+
array_access::<String>(Some(l2), Some(2)),
82+
Ok(Some("荷兰".into()))
83+
);
84+
assert_eq!(
85+
array_access::<String>(Some(l3), Some(3)),
86+
Ok(Some("的爱".into()))
87+
);
88+
}
89+
90+
#[test]
91+
fn test_nested_array_access() {
92+
let v = ListValue::new(vec![
93+
Some(ScalarImpl::List(ListValue::new(vec![
94+
Some(ScalarImpl::Utf8("foo".into())),
95+
Some(ScalarImpl::Utf8("bar".into())),
96+
]))),
97+
Some(ScalarImpl::List(ListValue::new(vec![
98+
Some(ScalarImpl::Utf8("fizz".into())),
99+
Some(ScalarImpl::Utf8("buzz".into())),
100+
]))),
101+
]);
102+
let l = ListRef::ValueRef { val: &v };
103+
assert_eq!(
104+
array_access::<ListValue>(Some(l), Some(1)),
105+
Ok(Some(ListValue::new(vec![
106+
Some(ScalarImpl::Utf8("foo".into())),
107+
Some(ScalarImpl::Utf8("bar".into())),
108+
])))
109+
);
110+
}
111+
}

src/expr/src/vector_op/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
pub mod agg;
1616
pub mod arithmetic_op;
17+
pub mod array_access;
1718
pub mod ascii;
1819
pub mod cast;
1920
pub mod cmp;

src/frontend/src/binder/expr/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ impl Binder {
5151
Expr::BinaryOp { left, op, right } => self.bind_binary_op(*left, op, *right),
5252
Expr::Nested(expr) => self.bind_expr(*expr),
5353
Expr::Array(exprs) => self.bind_array(exprs),
54+
Expr::ArrayIndex { obj, indexs } => self.bind_array_index(*obj, indexs),
5455
Expr::Function(f) => self.bind_function(f),
5556
// subquery
5657
Expr::Subquery(q) => self.bind_subquery_expr(*q, SubqueryKind::Scalar),

0 commit comments

Comments
 (0)