Skip to content

Commit bbdb49d

Browse files
authored
fix(optimizer): decorrelate SimpleAgg with array_agg/jsonb_agg/jsonb_object_agg (#15590) (#15616)
1 parent aabd9f6 commit bbdb49d

File tree

4 files changed

+111
-31
lines changed

4 files changed

+111
-31
lines changed

e2e_test/batch/subquery/subquery.slt.part

+7
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ select a, (select count(*) from t1 where t1.a <> t.b) from t1 as t order by 1;
132132
2 2
133133
NULL 0
134134

135+
query II
136+
select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1;
137+
----
138+
1 NULL
139+
2 {2}
140+
NULL NULL
141+
135142
statement ok
136143
drop table t1;
137144

src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml

+7
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@
4646
expected_outputs:
4747
- optimized_logical_plan_for_batch
4848
- logical_plan
49+
- name: 'Like `count(*)`, SimpleAgg also need to rewrite `array_agg` for the extra null row due to outer join #14735'
50+
sql: |
51+
create table t1(a int, b int);
52+
select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1;
53+
expected_outputs:
54+
- logical_plan
55+
- optimized_logical_plan_for_batch
4956
- sql: |
5057
create table t1(x int, y int);
5158
create table t2(x int, y int);

src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml

+54-28
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,31 @@
154154
│ └─LogicalScan { table: t1, columns: [t1.y] }
155155
└─LogicalProject { exprs: [t2.y, 1:Int32] }
156156
└─LogicalScan { table: t2, columns: [t2.y], predicate: IsNotNull(t2.y) }
157+
- name: 'Like `count(*)`, SimpleAgg also need to rewrite `array_agg` for the extra null row due to outer join #14735'
158+
sql: |
159+
create table t1(a int, b int);
160+
select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1;
161+
logical_plan: |-
162+
LogicalProject { exprs: [t1.a, array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] }
163+
└─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true }
164+
├─LogicalScan { table: t1, columns: [t1.a, t1.b, t1._row_id] }
165+
└─LogicalProject { exprs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] }
166+
└─LogicalAgg { aggs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] }
167+
└─LogicalProject { exprs: [t1.a] }
168+
└─LogicalFilter { predicate: (t1.a <> CorrelatedInputRef { index: 1, correlated_id: 1 }) }
169+
└─LogicalScan { table: t1, columns: [t1.a, t1.b, t1._row_id] }
170+
optimized_logical_plan_for_batch: |-
171+
LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(t1.b, t1.b), output: [t1.a, array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32) AND IsNotNull(1:Int32))] }
172+
├─LogicalScan { table: t1, columns: [t1.a, t1.b] }
173+
└─LogicalAgg { group_key: [t1.b], aggs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32) AND IsNotNull(1:Int32))] }
174+
└─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(t1.b, t1.b), output: [t1.b, t1.a, 1:Int32] }
175+
├─LogicalAgg { group_key: [t1.b], aggs: [] }
176+
│ └─LogicalScan { table: t1, columns: [t1.b] }
177+
└─LogicalProject { exprs: [t1.b, t1.a, 1:Int32] }
178+
└─LogicalJoin { type: Inner, on: (t1.a <> t1.b), output: all }
179+
├─LogicalAgg { group_key: [t1.b], aggs: [] }
180+
│ └─LogicalScan { table: t1, columns: [t1.b] }
181+
└─LogicalScan { table: t1, columns: [t1.a] }
157182
- sql: |
158183
create table t1(x int, y int);
159184
create table t2(x int, y int);
@@ -981,14 +1006,14 @@
9811006
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [$expr1] }
9821007
├─BatchExchange { order: [], dist: HashShard(t1.b) }
9831008
│ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard }
984-
└─BatchProject { exprs: [t1.b, Coalesce(array_agg(t2.c), ARRAY[]:List(Int32)) as $expr1] }
985-
└─BatchHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c)] }
986-
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c] }
1009+
└─BatchProject { exprs: [t1.b, Coalesce(array_agg(t2.c) filter(IsNotNull(1:Int32)), ARRAY[]:List(Int32)) as $expr1] }
1010+
└─BatchHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c) filter(IsNotNull(1:Int32))] }
1011+
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, 1:Int32] }
9871012
├─BatchHashAgg { group_key: [t1.b], aggs: [] }
9881013
│ └─BatchExchange { order: [], dist: HashShard(t1.b) }
9891014
│ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard }
9901015
└─BatchExchange { order: [], dist: HashShard(t2.d) }
991-
└─BatchProject { exprs: [t2.d, t2.c] }
1016+
└─BatchProject { exprs: [t2.d, t2.c, 1:Int32] }
9921017
└─BatchFilter { predicate: IsNotNull(t2.d) }
9931018
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
9941019
stream_plan: |-
@@ -997,15 +1022,15 @@
9971022
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [$expr1, t1._row_id, t1.b, t1.b] }
9981023
├─StreamExchange { dist: HashShard(t1.b) }
9991024
│ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
1000-
└─StreamProject { exprs: [t1.b, Coalesce(array_agg(t2.c), ARRAY[]:List(Int32)) as $expr1] }
1001-
└─StreamHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c), count] }
1002-
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, t2._row_id] }
1025+
└─StreamProject { exprs: [t1.b, Coalesce(array_agg(t2.c) filter(IsNotNull(1:Int32)), ARRAY[]:List(Int32)) as $expr1] }
1026+
└─StreamHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c) filter(IsNotNull(1:Int32)), count] }
1027+
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, 1:Int32, t2._row_id] }
10031028
├─StreamProject { exprs: [t1.b] }
10041029
│ └─StreamHashAgg { group_key: [t1.b], aggs: [count] }
10051030
│ └─StreamExchange { dist: HashShard(t1.b) }
10061031
│ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) }
10071032
└─StreamExchange { dist: HashShard(t2.d) }
1008-
└─StreamProject { exprs: [t2.d, t2.c, t2._row_id] }
1033+
└─StreamProject { exprs: [t2.d, t2.c, 1:Int32, t2._row_id] }
10091034
└─StreamFilter { predicate: IsNotNull(t2.d) }
10101035
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) }
10111036
- name: correlated array subquery \du
@@ -1030,28 +1055,29 @@
10301055
├─BatchExchange { order: [], dist: HashShard(rw_users.id) }
10311056
│ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
10321057
│ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name, rw_users.is_super, rw_users.create_db, rw_users.create_user, rw_users.can_login], distribution: Single }
1033-
└─BatchProject { exprs: [rw_users.id, Coalesce(array_agg(rw_users.name), ARRAY[]:List(Varchar)) as $expr1] }
1034-
└─BatchHashAgg { group_key: [rw_users.id], aggs: [array_agg(rw_users.name)] }
1035-
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: [rw_users.id, rw_users.name] }
1058+
└─BatchProject { exprs: [rw_users.id, Coalesce(array_agg(rw_users.name) filter(IsNotNull(1:Int32)), ARRAY[]:List(Varchar)) as $expr1] }
1059+
└─BatchHashAgg { group_key: [rw_users.id], aggs: [array_agg(rw_users.name) filter(IsNotNull(1:Int32))] }
1060+
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: [rw_users.id, rw_users.name, 1:Int32] }
10361061
├─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
10371062
│ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
10381063
│ └─BatchProject { exprs: [rw_users.id] }
10391064
│ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
10401065
│ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
10411066
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
1042-
└─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.id, rw_users.name] }
1043-
├─BatchExchange { order: [], dist: HashShard(null:Int32) }
1044-
│ └─BatchProject { exprs: [rw_users.id, null:Int32] }
1045-
│ └─BatchNestedLoopJoin { type: Inner, predicate: true, output: all }
1046-
│ ├─BatchExchange { order: [], dist: Single }
1047-
│ │ └─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
1048-
│ │ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
1049-
│ │ └─BatchProject { exprs: [rw_users.id] }
1050-
│ │ └─BatchFilter { predicate: (null:Int32 = rw_users.id) AND Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
1051-
│ │ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
1052-
│ └─BatchValues { rows: [] }
1053-
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
1054-
└─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
1067+
└─BatchProject { exprs: [rw_users.id, rw_users.name, 1:Int32] }
1068+
└─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.id, rw_users.name] }
1069+
├─BatchExchange { order: [], dist: HashShard(null:Int32) }
1070+
│ └─BatchProject { exprs: [rw_users.id, null:Int32] }
1071+
│ └─BatchNestedLoopJoin { type: Inner, predicate: true, output: all }
1072+
│ ├─BatchExchange { order: [], dist: Single }
1073+
│ │ └─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
1074+
│ │ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
1075+
│ │ └─BatchProject { exprs: [rw_users.id] }
1076+
│ │ └─BatchFilter { predicate: (null:Int32 = rw_users.id) AND Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
1077+
│ │ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
1078+
│ └─BatchValues { rows: [] }
1079+
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
1080+
└─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
10551081
- name: correlated array subquery (issue 14423)
10561082
sql: |
10571083
CREATE TABLE array_types ( x BIGINT[] );
@@ -1066,14 +1092,14 @@
10661092
└─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [$expr1] }
10671093
├─BatchExchange { order: [], dist: HashShard(array_types.x) }
10681094
│ └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard }
1069-
└─BatchProject { exprs: [array_types.x, Coalesce(array_agg(array_types.x), ARRAY[]:List(List(Int64))) as $expr1] }
1070-
└─BatchHashAgg { group_key: [array_types.x], aggs: [array_agg(array_types.x)] }
1071-
└─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [array_types.x, array_types.x] }
1095+
└─BatchProject { exprs: [array_types.x, Coalesce(array_agg(array_types.x) filter(IsNotNull(1:Int32)), ARRAY[]:List(List(Int64))) as $expr1] }
1096+
└─BatchHashAgg { group_key: [array_types.x], aggs: [array_agg(array_types.x) filter(IsNotNull(1:Int32))] }
1097+
└─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [array_types.x, array_types.x, 1:Int32] }
10721098
├─BatchHashAgg { group_key: [array_types.x], aggs: [] }
10731099
│ └─BatchExchange { order: [], dist: HashShard(array_types.x) }
10741100
│ └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard }
10751101
└─BatchExchange { order: [], dist: HashShard(array_types.x) }
1076-
└─BatchProject { exprs: [array_types.x, array_types.x] }
1102+
└─BatchProject { exprs: [array_types.x, array_types.x, 1:Int32] }
10771103
└─BatchHashAgg { group_key: [array_types.x], aggs: [] }
10781104
└─BatchExchange { order: [], dist: HashShard(array_types.x) }
10791105
└─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard }

src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs

+43-3
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,49 @@ impl Rule for ApplyAggTransposeRule {
140140
// convert count(*) to count(1).
141141
let pos_of_constant_column = node.schema().len() - 1;
142142
agg_calls.iter_mut().for_each(|agg_call| {
143-
if agg_call.agg_kind == AggKind::Count && agg_call.inputs.is_empty() {
144-
let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
145-
agg_call.inputs.push(input_ref);
143+
match agg_call.agg_kind {
144+
AggKind::Count if agg_call.inputs.is_empty() => {
145+
let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
146+
agg_call.inputs.push(input_ref);
147+
}
148+
AggKind::ArrayAgg
149+
| AggKind::JsonbAgg
150+
| AggKind::JsonbObjectAgg => {
151+
let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
152+
let cond = FunctionCall::new(ExprType::IsNotNull, vec![input_ref.into()]).unwrap();
153+
agg_call.filter.conjunctions.push(cond.into());
154+
}
155+
AggKind::Count
156+
| AggKind::Sum
157+
| AggKind::Sum0
158+
| AggKind::Avg
159+
| AggKind::Min
160+
| AggKind::Max
161+
| AggKind::BitAnd
162+
| AggKind::BitOr
163+
| AggKind::BitXor
164+
| AggKind::BoolAnd
165+
| AggKind::BoolOr
166+
| AggKind::StringAgg
167+
// not in PostgreSQL
168+
| AggKind::ApproxCountDistinct
169+
| AggKind::FirstValue
170+
| AggKind::LastValue
171+
| AggKind::InternalLastSeenValue
172+
// All statistical aggregates only consider non-null inputs.
173+
| AggKind::VarPop
174+
| AggKind::VarSamp
175+
| AggKind::StddevPop
176+
| AggKind::StddevSamp
177+
// All ordered-set aggregates ignore null values in their aggregated input.
178+
| AggKind::PercentileCont
179+
| AggKind::PercentileDisc
180+
| AggKind::Mode
181+
// `grouping` has no *aggregate* input and is unreachable when `is_scalar_agg`.
182+
| AggKind::Grouping
183+
=> {
184+
// no-op when `agg(0 rows) == agg(1 row of nulls)`
185+
}
146186
}
147187
});
148188
}

0 commit comments

Comments
 (0)