Skip to content

Commit 12da00c

Browse files
authored
feat(optimizer): simplify filter predicate before converting row_number + filter to topn (#22295)
Signed-off-by: Richard Chien <[email protected]>
1 parent 04b350b commit 12da00c

File tree

3 files changed

+305
-1
lines changed

3 files changed

+305
-1
lines changed

src/frontend/planner_test/tests/testdata/input/over_window_function.yaml

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,68 @@
393393
- stream_plan
394394
- batch_plan
395395

396+
# TopN with arithmetic on row_number/rank function
397+
- name: test ibis window function optimization
398+
sql: |
399+
CREATE TABLE t (a INT, b INT, c INT);
400+
SELECT a, b, c
401+
FROM (
402+
SELECT
403+
a, b, c,
404+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) - 1 AS rn
405+
FROM t
406+
)
407+
WHERE rn < 10;
408+
expected_outputs:
409+
- logical_plan
410+
- optimized_logical_plan_for_batch
411+
- optimized_logical_plan_for_stream
412+
- name: test arithmetic with addition
413+
sql: |
414+
CREATE TABLE t (a INT, b INT, c INT);
415+
SELECT a, b, c
416+
FROM (
417+
SELECT
418+
a, b, c,
419+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) + 5 AS rn_plus_five
420+
FROM t
421+
)
422+
WHERE rn_plus_five = 6;
423+
expected_outputs:
424+
- logical_plan
425+
- optimized_logical_plan_for_batch
426+
- optimized_logical_plan_for_stream
427+
- name: test optimization when rn column is kept
428+
sql: |
429+
CREATE TABLE t (a INT, b INT, c INT);
430+
SELECT a, b, c, rn_plus_one
431+
FROM (
432+
SELECT
433+
a, b, c,
434+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) + 1 AS rn_plus_one
435+
FROM t
436+
)
437+
WHERE rn_plus_one = 2;
438+
expected_outputs:
439+
- logical_plan
440+
- optimized_logical_plan_for_batch
441+
- optimized_logical_plan_for_stream
442+
- name: test complex arithmetic not optimized
443+
sql: |
444+
CREATE TABLE t (a INT, b INT, c INT);
445+
SELECT a, b, c, rn
446+
FROM (
447+
SELECT
448+
a, b, c,
449+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) * 2 + 1 AS rn
450+
FROM t
451+
)
452+
WHERE rn = 3;
453+
expected_outputs:
454+
- logical_plan
455+
- optimized_logical_plan_for_batch
456+
- optimized_logical_plan_for_stream
457+
396458
# TopN on nexmark schema
397459
- id: create_bid
398460
sql: |

src/frontend/planner_test/tests/testdata/output/over_window_function.yaml

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,110 @@
855855
└─StreamOverWindow { window_functions: [row_number() OVER(PARTITION BY t.x ORDER BY t.y ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
856856
└─StreamExchange { dist: HashShard(t.x) }
857857
└─StreamTableScan { table: t, columns: [t.x, t.y, t.z, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
858+
- name: test ibis window function optimization
859+
sql: |
860+
CREATE TABLE t (a INT, b INT, c INT);
861+
SELECT a, b, c
862+
FROM (
863+
SELECT
864+
a, b, c,
865+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) - 1 AS rn
866+
FROM t
867+
)
868+
WHERE rn < 10;
869+
logical_plan: |-
870+
LogicalProject { exprs: [t.a, t.b, t.c] }
871+
└─LogicalFilter { predicate: ($expr1 < 10:Int32) }
872+
└─LogicalProject { exprs: [t.a, t.b, t.c, (row_number - 1:Int32) as $expr1] }
873+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
874+
└─LogicalProject { exprs: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
875+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
876+
optimized_logical_plan_for_batch: |-
877+
LogicalTopN { order: [t.b ASC], limit: 10, offset: 0, group_key: [t.a] }
878+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
879+
optimized_logical_plan_for_stream: |-
880+
LogicalTopN { order: [t.b ASC], limit: 10, offset: 0, group_key: [t.a] }
881+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
882+
- name: test arithmetic with addition
883+
sql: |
884+
CREATE TABLE t (a INT, b INT, c INT);
885+
SELECT a, b, c
886+
FROM (
887+
SELECT
888+
a, b, c,
889+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) + 5 AS rn_plus_five
890+
FROM t
891+
)
892+
WHERE rn_plus_five = 6;
893+
logical_plan: |-
894+
LogicalProject { exprs: [t.a, t.b, t.c] }
895+
└─LogicalFilter { predicate: ($expr1 = 6:Int32) }
896+
└─LogicalProject { exprs: [t.a, t.b, t.c, (row_number + 5:Int32) as $expr1] }
897+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
898+
└─LogicalProject { exprs: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
899+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
900+
optimized_logical_plan_for_batch: |-
901+
LogicalTopN { order: [t.b ASC], limit: 1, offset: 0, group_key: [t.a] }
902+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
903+
optimized_logical_plan_for_stream: |-
904+
LogicalTopN { order: [t.b ASC], limit: 1, offset: 0, group_key: [t.a] }
905+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
906+
- name: test optimization when rn column is kept
907+
sql: |
908+
CREATE TABLE t (a INT, b INT, c INT);
909+
SELECT a, b, c, rn_plus_one
910+
FROM (
911+
SELECT
912+
a, b, c,
913+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) + 1 AS rn_plus_one
914+
FROM t
915+
)
916+
WHERE rn_plus_one = 2;
917+
logical_plan: |-
918+
LogicalProject { exprs: [t.a, t.b, t.c, $expr1] }
919+
└─LogicalFilter { predicate: ($expr1 = 2:Int32) }
920+
└─LogicalProject { exprs: [t.a, t.b, t.c, (row_number + 1:Int32) as $expr1] }
921+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
922+
└─LogicalProject { exprs: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
923+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
924+
optimized_logical_plan_for_batch: |-
925+
LogicalProject { exprs: [t.a, t.b, t.c, (row_number + 1:Int32) as $expr1] }
926+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
927+
└─LogicalTopN { order: [t.b ASC], limit: 1, offset: 0, group_key: [t.a] }
928+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
929+
optimized_logical_plan_for_stream: |-
930+
LogicalProject { exprs: [t.a, t.b, t.c, (row_number + 1:Int32) as $expr1] }
931+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
932+
└─LogicalTopN { order: [t.b ASC], limit: 1, offset: 0, group_key: [t.a] }
933+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
934+
- name: test complex arithmetic not optimized
935+
sql: |
936+
CREATE TABLE t (a INT, b INT, c INT);
937+
SELECT a, b, c, rn
938+
FROM (
939+
SELECT
940+
a, b, c,
941+
ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) * 2 + 1 AS rn
942+
FROM t
943+
)
944+
WHERE rn = 3;
945+
logical_plan: |-
946+
LogicalProject { exprs: [t.a, t.b, t.c, $expr1] }
947+
└─LogicalFilter { predicate: ($expr1 = 3:Int32) }
948+
└─LogicalProject { exprs: [t.a, t.b, t.c, ((row_number * 2:Int32) + 1:Int32) as $expr1] }
949+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
950+
└─LogicalProject { exprs: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
951+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c, t._row_id, t._rw_timestamp] }
952+
optimized_logical_plan_for_batch: |-
953+
LogicalProject { exprs: [t.a, t.b, t.c, ((row_number * 2:Int32) + 1:Int32) as $expr1] }
954+
└─LogicalFilter { predicate: (((row_number * 2:Int32) + 1:Int32) = 3:Int32) }
955+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
956+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
957+
optimized_logical_plan_for_stream: |-
958+
LogicalProject { exprs: [t.a, t.b, t.c, ((row_number * 2:Int32) + 1:Int32) as $expr1] }
959+
└─LogicalFilter { predicate: (((row_number * 2:Int32) + 1:Int32) = 3:Int32) }
960+
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
961+
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
858962
- id: create_bid
859963
sql: |
860964
/*

src/frontend/src/optimizer/rule/over_window_to_topn_rule.rs

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ use risingwave_expr::window_function::WindowFuncKind;
1818

1919
use super::{BoxedRule, Rule};
2020
use crate::PlanRef;
21-
use crate::expr::{ExprImpl, ExprType, collect_input_refs};
21+
use crate::expr::{
22+
Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, Literal, collect_input_refs,
23+
};
2224
use crate::optimizer::plan_node::generic::GenericPlanRef;
2325
use crate::optimizer::plan_node::{LogicalFilter, LogicalTopN, PlanTreeNodeUnary};
2426
use crate::optimizer::property::Order;
2527
use crate::planner::LIMIT_ALL_COUNT;
28+
use crate::utils::Condition;
2629

2730
/// Transforms the following pattern to group `TopN` (No Ranking Output).
2831
///
@@ -42,6 +45,9 @@ use crate::planner::LIMIT_ALL_COUNT;
4245
/// FROM ..
4346
/// WHERE rank [ < | <= | > | >= | = ] ..;
4447
/// ```
48+
///
49+
/// Also optimizes filter arithmetic expressions in the `Project <- Filter <- OverWindow` pattern,
50+
/// such as simplifying `(row_number - 1) = 0` to `row_number = 1`.
4551
pub struct OverWindowToTopNRule;
4652

4753
impl OverWindowToTopNRule {
@@ -65,6 +71,13 @@ impl Rule for OverWindowToTopNRule {
6571
// The filter is directly on top of the over window after predicate pushdown.
6672
let over_window = plan.as_logical_over_window()?;
6773

74+
// First try to simplify filter arithmetic expressions
75+
let filter = if let Some(simplified) = self.simplify_filter_arithmetic(filter) {
76+
simplified
77+
} else {
78+
filter.clone()
79+
};
80+
6881
if over_window.window_functions().len() != 1 {
6982
// Queries with multiple window function calls are not supported yet.
7083
return None;
@@ -137,6 +150,131 @@ impl Rule for OverWindowToTopNRule {
137150
}
138151
}
139152

153+
impl OverWindowToTopNRule {
154+
/// Simplify arithmetic expressions in filter conditions before TopN optimization
155+
/// For example: `(row_number - 1) = 0` -> `row_number = 1`
156+
fn simplify_filter_arithmetic(&self, filter: &LogicalFilter) -> Option<LogicalFilter> {
157+
let new_predicate = self.simplify_filter_arithmetic_condition(filter.predicate())?;
158+
Some(LogicalFilter::new(filter.input(), new_predicate))
159+
}
160+
161+
/// Simplify arithmetic expressions in the filter condition
162+
fn simplify_filter_arithmetic_condition(&self, predicate: &Condition) -> Option<Condition> {
163+
let expr = predicate.as_expr_unless_true()?;
164+
let mut rewriter = FilterArithmeticRewriter {};
165+
let new_expr = rewriter.rewrite_expr(expr.clone());
166+
167+
if new_expr != expr {
168+
Some(Condition::with_expr(new_expr))
169+
} else {
170+
None
171+
}
172+
}
173+
}
174+
175+
/// Filter arithmetic simplification rewriter: simplifies `(col op const) = const2` to `col = (const2 reverse_op const)`
176+
struct FilterArithmeticRewriter {}
177+
178+
impl ExprRewriter for FilterArithmeticRewriter {
179+
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
180+
use ExprType::{
181+
Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
182+
};
183+
184+
// Check if this is a comparison operation
185+
match func_call.func_type() {
186+
Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
187+
let inputs = func_call.inputs();
188+
if inputs.len() == 2 {
189+
// Check if left operand is an arithmetic expression and right operand is a constant
190+
if let ExprImpl::FunctionCall(left_func) = &inputs[0] {
191+
if inputs[1].is_const() {
192+
if let Some(simplified) = self.simplify_arithmetic_comparison(
193+
left_func,
194+
&inputs[1],
195+
func_call.func_type(),
196+
) {
197+
return simplified;
198+
}
199+
}
200+
}
201+
}
202+
}
203+
_ => {}
204+
}
205+
206+
// Recursively handle sub-expressions
207+
let (func_type, inputs, ret_type) = func_call.decompose();
208+
let new_inputs: Vec<_> = inputs
209+
.into_iter()
210+
.map(|input| self.rewrite_expr(input))
211+
.collect();
212+
213+
FunctionCall::new_unchecked(func_type, new_inputs, ret_type).into()
214+
}
215+
}
216+
217+
impl FilterArithmeticRewriter {
218+
/// Simplify arithmetic comparison: `(col op const1) comp const2` -> `col comp (const2 reverse_op const1)`
219+
fn simplify_arithmetic_comparison(
220+
&self,
221+
arithmetic_func: &FunctionCall,
222+
comparison_const: &ExprImpl,
223+
comparison_op: ExprType,
224+
) -> Option<ExprImpl> {
225+
use ExprType::{Add, Subtract};
226+
227+
// Check arithmetic operation
228+
match arithmetic_func.func_type() {
229+
Add | Subtract => {
230+
let inputs = arithmetic_func.inputs();
231+
if inputs.len() == 2 {
232+
// Find column reference and constant
233+
let (column_ref, arith_const, reverse_op) = if inputs[1].is_const() {
234+
// col op const
235+
let reverse_op = match arithmetic_func.func_type() {
236+
Add => Subtract,
237+
Subtract => Add,
238+
_ => unreachable!(),
239+
};
240+
(&inputs[0], &inputs[1], reverse_op)
241+
} else if inputs[0].is_const() && arithmetic_func.func_type() == Add {
242+
// const + col
243+
(&inputs[1], &inputs[0], Subtract)
244+
} else {
245+
return None;
246+
};
247+
248+
// Calculate new constant value
249+
if let Ok(new_const_func) = FunctionCall::new(
250+
reverse_op,
251+
vec![comparison_const.clone(), arith_const.clone()],
252+
) {
253+
let new_const_expr: ExprImpl = new_const_func.into();
254+
// Try constant folding
255+
if let Some(Ok(Some(folded_value))) = new_const_expr.try_fold_const() {
256+
let new_const =
257+
Literal::new(Some(folded_value), new_const_expr.return_type())
258+
.into();
259+
260+
// Construct new comparison expression
261+
if let Ok(new_comparison) = FunctionCall::new(
262+
comparison_op,
263+
vec![column_ref.clone(), new_const],
264+
) {
265+
return Some(new_comparison.into());
266+
}
267+
}
268+
}
269+
}
270+
}
271+
_ => {}
272+
}
273+
274+
None
275+
}
276+
}
277+
140278
/// Returns `None` if the conditions are too complex or invalid. `Some((limit, offset))` otherwise.
141279
fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
142280
if rank_preds.is_empty() {

0 commit comments

Comments
 (0)