Skip to content

Commit 7f11268

Browse files
committed
[WIP] Differentiate (boolean-valued) predicates from normal expressions
1 parent ec419a2 commit 7f11268

File tree

19 files changed

+831
-373
lines changed

19 files changed

+831
-373
lines changed

Diff for: ffi/src/expressions/engine.rs

+65-12
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@ use delta_kernel::expressions::{
1111
};
1212
use delta_kernel::DeltaResult;
1313

14+
pub enum ExpressionOrPredicate {
15+
Expression(Expression),
16+
Predicate(Predicate),
17+
}
18+
1419
#[derive(Default)]
1520
pub struct KernelExpressionVisitorState {
16-
inflight_ids: ReferenceSet<Expression>,
21+
inflight_ids: ReferenceSet<ExpressionOrPredicate>,
1722
}
1823

1924
/// A predicate that can be used to skip data when scanning.
@@ -36,27 +41,35 @@ pub struct EnginePredicate {
3641
}
3742

3843
fn wrap_expression(state: &mut KernelExpressionVisitorState, expr: impl Into<Expression>) -> usize {
39-
state.inflight_ids.insert(expr.into())
44+
state
45+
.inflight_ids
46+
.insert(ExpressionOrPredicate::Expression(expr.into()))
4047
}
4148

42-
fn wrap_predicate(state: &mut KernelExpressionVisitorState, expr: impl Into<Predicate>) -> usize {
43-
// TODO: Actually split this out
44-
wrap_expression(state, expr)
49+
fn wrap_predicate(state: &mut KernelExpressionVisitorState, pred: impl Into<Predicate>) -> usize {
50+
state
51+
.inflight_ids
52+
.insert(ExpressionOrPredicate::Predicate(pred.into()))
4553
}
4654

47-
pub fn unwrap_kernel_expression(
55+
pub(crate) fn unwrap_kernel_expression(
4856
state: &mut KernelExpressionVisitorState,
4957
exprid: usize,
5058
) -> Option<Expression> {
51-
state.inflight_ids.take(exprid)
59+
match state.inflight_ids.take(exprid)? {
60+
ExpressionOrPredicate::Expression(expr) => Some(expr),
61+
ExpressionOrPredicate::Predicate(pred) => Some(Expression::predicate(pred)),
62+
}
5263
}
5364

54-
pub fn unwrap_kernel_predicate(
65+
pub(crate) fn unwrap_kernel_predicate(
5566
state: &mut KernelExpressionVisitorState,
5667
exprid: usize,
5768
) -> Option<Predicate> {
58-
// TODO: Actually split this out
59-
unwrap_kernel_expression(state, exprid)
69+
match state.inflight_ids.take(exprid)? {
70+
ExpressionOrPredicate::Expression(expr) => Some(Predicate::expression(expr)),
71+
ExpressionOrPredicate::Predicate(pred) => Some(pred),
72+
}
6073
}
6174

6275
fn visit_expression_binary(
@@ -79,8 +92,12 @@ fn visit_predicate_binary(
7992
a: usize,
8093
b: usize,
8194
) -> usize {
82-
// TODO: Actually split this out
83-
visit_expression_binary(state, op, a, b)
95+
let left = unwrap_kernel_expression(state, a);
96+
let right = unwrap_kernel_expression(state, b);
97+
match left.zip(right) {
98+
Some((left, right)) => wrap_predicate(state, Predicate::binary(op, left, right)),
99+
None => 0, // invalid child => invalid node
100+
}
84101
}
85102

86103
fn visit_predicate_unary(
@@ -104,6 +121,42 @@ pub extern "C" fn visit_predicate_and(
104121
wrap_predicate(state, result)
105122
}
106123

124+
#[no_mangle]
125+
pub extern "C" fn visit_expression_plus(
126+
state: &mut KernelExpressionVisitorState,
127+
a: usize,
128+
b: usize,
129+
) -> usize {
130+
visit_expression_binary(state, BinaryExpressionOp::Plus, a, b)
131+
}
132+
133+
#[no_mangle]
134+
pub extern "C" fn visit_expression_minus(
135+
state: &mut KernelExpressionVisitorState,
136+
a: usize,
137+
b: usize,
138+
) -> usize {
139+
visit_expression_binary(state, BinaryExpressionOp::Minus, a, b)
140+
}
141+
142+
#[no_mangle]
143+
pub extern "C" fn visit_expression_multiply(
144+
state: &mut KernelExpressionVisitorState,
145+
a: usize,
146+
b: usize,
147+
) -> usize {
148+
visit_expression_binary(state, BinaryExpressionOp::Multiply, a, b)
149+
}
150+
151+
#[no_mangle]
152+
pub extern "C" fn visit_expression_divide(
153+
state: &mut KernelExpressionVisitorState,
154+
a: usize,
155+
b: usize,
156+
) -> usize {
157+
visit_expression_binary(state, BinaryExpressionOp::Divide, a, b)
158+
}
159+
107160
#[no_mangle]
108161
pub extern "C" fn visit_predicate_lt(
109162
state: &mut KernelExpressionVisitorState,

Diff for: ffi/src/expressions/kernel.rs

+36-22
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use std::ffi::c_void;
55

66
use crate::{handle::Handle, kernel_string_slice, KernelStringSlice};
77
use delta_kernel::expressions::{
8-
ArrayData, BinaryExpression, BinaryExpressionOp, BinaryPredicateOp, Expression,
9-
JunctionPredicate, JunctionPredicateOp, Predicate, Scalar, StructData, UnaryPredicate,
10-
UnaryPredicateOp,
8+
ArrayData, BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp,
9+
Expression, JunctionPredicate, JunctionPredicateOp, Predicate, Scalar, StructData,
10+
UnaryPredicate, UnaryPredicateOp,
1111
};
1212

1313
/// Free the memory the passed SharedExpression
@@ -392,6 +392,7 @@ fn visit_expression_impl(
392392
call!(visitor, visit_column, sibling_list_id, name);
393393
}
394394
Expression::Struct(exprs) => visit_expression_struct(visitor, exprs, sibling_list_id),
395+
Expression::Predicate(pred) => visit_predicate_impl(visitor, pred, sibling_list_id),
395396
Expression::Binary(BinaryExpression { op, left, right }) => {
396397
let child_list_id = call!(visitor, make_field_list, 2);
397398
visit_expression_impl(visitor, left, child_list_id);
@@ -401,6 +402,37 @@ fn visit_expression_impl(
401402
BinaryExpressionOp::Minus => visitor.visit_minus,
402403
BinaryExpressionOp::Multiply => visitor.visit_multiply,
403404
BinaryExpressionOp::Divide => visitor.visit_divide,
405+
};
406+
op(visitor.data, sibling_list_id, child_list_id);
407+
}
408+
}
409+
}
410+
411+
fn visit_predicate_impl(
412+
visitor: &mut EngineExpressionVisitor,
413+
predicate: &Predicate,
414+
sibling_list_id: usize,
415+
) {
416+
match predicate {
417+
Predicate::BooleanExpression(expr) => visit_expression_impl(visitor, expr, sibling_list_id),
418+
Predicate::Not(pred) => {
419+
let child_list_id = call!(visitor, make_field_list, 1);
420+
visit_predicate_impl(visitor, pred, child_list_id);
421+
call!(visitor, visit_not, sibling_list_id, child_list_id);
422+
}
423+
Predicate::Unary(UnaryPredicate { op, expr }) => {
424+
let child_list_id = call!(visitor, make_field_list, 1);
425+
visit_expression_impl(visitor, expr, child_list_id);
426+
let op = match op {
427+
UnaryPredicateOp::IsNull => visitor.visit_is_null,
428+
};
429+
op(visitor.data, sibling_list_id, child_list_id);
430+
}
431+
Predicate::Binary(BinaryPredicate { op, left, right }) => {
432+
let child_list_id = call!(visitor, make_field_list, 2);
433+
visit_expression_impl(visitor, left, child_list_id);
434+
visit_expression_impl(visitor, right, child_list_id);
435+
let op = match op {
404436
BinaryPredicateOp::LessThan => visitor.visit_lt,
405437
BinaryPredicateOp::LessThanOrEqual => visitor.visit_le,
406438
BinaryPredicateOp::GreaterThan => visitor.visit_gt,
@@ -413,31 +445,13 @@ fn visit_expression_impl(
413445
};
414446
op(visitor.data, sibling_list_id, child_list_id);
415447
}
416-
Predicate::Unary(UnaryPredicate { op, expr }) => {
417-
let child_id_list = call!(visitor, make_field_list, 1);
418-
visit_expression_impl(visitor, expr, child_id_list);
419-
let op = match op {
420-
UnaryPredicateOp::Not => visitor.visit_not,
421-
UnaryPredicateOp::IsNull => visitor.visit_is_null,
422-
};
423-
op(visitor.data, sibling_list_id, child_id_list);
424-
}
425448
Predicate::Junction(JunctionPredicate { op, preds }) => {
426449
visit_predicate_junction(visitor, op, preds, sibling_list_id)
427450
}
428451
}
429452
}
430453

431-
fn visit_predicate_impl(
432-
visitor: &mut EngineExpressionVisitor,
433-
predicate: &Predicate,
434-
sibling_list_id: usize,
435-
) {
436-
// TODO: Actually split this out
437-
visit_expression_impl(visitor, predicate, sibling_list_id)
438-
}
439-
440-
pub fn visit_expression_internal(
454+
fn visit_expression_internal(
441455
expression: &Expression,
442456
visitor: &mut EngineExpressionVisitor,
443457
) -> usize {

Diff for: ffi/src/test_ffi.rs

+18-7
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub unsafe extern "C" fn get_testing_kernel_expression() -> Handle<SharedExpress
4242
)
4343
.unwrap();
4444

45-
let mut sub_exprs = vec![
45+
let mut sub_exprs: Vec<Pred> = [
4646
Expr::literal(i8::MAX),
4747
Expr::literal(i8::MIN),
4848
Expr::literal(f32::MAX),
@@ -66,17 +66,28 @@ pub unsafe extern "C" fn get_testing_kernel_expression() -> Handle<SharedExpress
6666
Scalar::Struct(top_level_struct).into(),
6767
Scalar::Array(array_data).into(),
6868
Expr::struct_from(vec![Pred::or_from(vec![
69-
Scalar::Integer(5).into(),
70-
Scalar::Long(20).into(),
71-
])]),
72-
Pred::is_not_null(column_expr!("col")),
73-
];
69+
Pred::expression(Scalar::Integer(5)),
70+
Pred::expression(Scalar::Long(20)),
71+
])
72+
.into()]),
73+
Pred::is_not_null(column_expr!("col")).into(),
74+
]
75+
.into_iter()
76+
.map(Pred::from)
77+
.collect();
7478
sub_exprs.extend(
7579
[
7680
BinaryExpressionOp::Divide,
7781
BinaryExpressionOp::Multiply,
7882
BinaryExpressionOp::Plus,
7983
BinaryExpressionOp::Minus,
84+
]
85+
.into_iter()
86+
.map(|op| Expr::binary(op, Scalar::Integer(0), Scalar::Long(0)))
87+
.map(Pred::from),
88+
);
89+
sub_exprs.extend(
90+
[
8091
BinaryPredicateOp::In,
8192
BinaryPredicateOp::Equal,
8293
BinaryPredicateOp::NotEqual,
@@ -91,5 +102,5 @@ pub unsafe extern "C" fn get_testing_kernel_expression() -> Handle<SharedExpress
91102
.map(|op| Pred::binary(op, Scalar::Integer(0), Scalar::Long(0))),
92103
);
93104

94-
Arc::new(Pred::and_from(sub_exprs)).into()
105+
Arc::new(Expr::from(Pred::and_from(sub_exprs))).into()
95106
}

0 commit comments

Comments
 (0)