Skip to content

Commit

Permalink
fix(query): fold constant subquery to build filter plan instead of jo…
Browse files Browse the repository at this point in the history
…in plan (#17448)

* fix(query): fold constant subquery to build filter plan instead of join plan

* check max_inlist_to_or

* fix

* fix
  • Loading branch information
b41sh authored Feb 15, 2025
1 parent 67d4a16 commit 61cdf5a
Show file tree
Hide file tree
Showing 7 changed files with 423 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ impl MutationExpression {
Arc::new(s_expr),
);

let mut rewriter = SubqueryRewriter::new(binder.metadata.clone(), None);
let mut rewriter =
SubqueryRewriter::new(binder.ctx.clone(), binder.metadata.clone(), None);
let s_expr = rewriter.rewrite(&s_expr)?;

Ok(MutationExpressionBindResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ impl Binder {
let mut is_lateral = false;
if !right_prop.outer_columns.is_empty() {
// If there are outer columns in right child, then the join is a correlated lateral join
let mut decorrelator = SubqueryRewriter::new(self.metadata.clone(), Some(self.clone()));
let mut decorrelator =
SubqueryRewriter::new(self.ctx.clone(), self.metadata.clone(), Some(self.clone()));
right_child = decorrelator.flatten_plan(
&right_child,
&right_prop.outer_columns,
Expand Down
303 changes: 301 additions & 2 deletions src/query/sql/src/planner/optimizer/decorrelate/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeSet;
use std::collections::HashSet;
use std::sync::Arc;

use databend_common_ast::Span;
use databend_common_catalog::table_context::TableContext;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::type_check::common_super_type;
use databend_common_expression::types::DataType;
use databend_common_expression::types::NumberScalar;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::Scalar;
use databend_common_expression::ScalarRef;
use databend_common_functions::BUILTIN_FUNCTIONS;

use crate::binder::ColumnBindingBuilder;
use crate::binder::JoinPredicate;
Expand All @@ -30,12 +39,16 @@ use crate::optimizer::ColumnSet;
use crate::optimizer::RelExpr;
use crate::optimizer::SExpr;
use crate::plans::BoundColumnRef;
use crate::plans::CastExpr;
use crate::plans::ComparisonOp;
use crate::plans::ConstantExpr;
use crate::plans::Filter;
use crate::plans::FunctionCall;
use crate::plans::Join;
use crate::plans::JoinEquiCondition;
use crate::plans::JoinType;
use crate::plans::RelOp;
use crate::plans::RelOperator;
use crate::plans::ScalarExpr;
use crate::plans::SubqueryExpr;
use crate::plans::SubqueryType;
Expand All @@ -53,8 +66,12 @@ use crate::MetadataRef;
/// Correlated exists subquery -> Marker join
///
/// More information can be found in the paper: Unnesting Arbitrary Queries
pub fn decorrelate_subquery(metadata: MetadataRef, s_expr: SExpr) -> Result<SExpr> {
let mut rewriter = SubqueryRewriter::new(metadata, None);
pub fn decorrelate_subquery(
ctx: Arc<dyn TableContext>,
metadata: MetadataRef,
s_expr: SExpr,
) -> Result<SExpr> {
let mut rewriter = SubqueryRewriter::new(ctx, metadata, None);
rewriter.rewrite(&s_expr)
}

Expand Down Expand Up @@ -517,4 +534,286 @@ impl SubqueryRewriter {
true
}))
}

// Try folding the subquery into a constant value expression,
// which turns the join plan into a filter plan, so that the bloom filter
// can be used to reduce the amount of data that needs to be read.
pub fn try_fold_constant_subquery(
&self,
subquery: &SubqueryExpr,
) -> Result<Option<ScalarExpr>> {
// (1) EvalScalar
// \
// DummyTableScan
//
// (2) EvalScalar
// \
// EvalScalar
// \
// ProjectSet
// \
// DummyTableScan
let matchers = vec![
Matcher::MatchOp {
op_type: RelOp::EvalScalar,
children: vec![Matcher::MatchOp {
op_type: RelOp::DummyTableScan,
children: vec![],
}],
},
Matcher::MatchOp {
op_type: RelOp::EvalScalar,
children: vec![Matcher::MatchOp {
op_type: RelOp::EvalScalar,
children: vec![Matcher::MatchOp {
op_type: RelOp::ProjectSet,
children: vec![Matcher::MatchOp {
op_type: RelOp::DummyTableScan,
children: vec![],
}],
}],
}],
},
];

let mut matched = false;
for matcher in matchers {
if matcher.matches(&subquery.subquery) {
matched = true;
break;
}
}
if !matched {
return Ok(None);
}

let child = subquery.subquery.child(0)?;
if let RelOperator::DummyTableScan(_) = child.plan() {
// subquery is a simple constant value.
// for example: `SELECT * FROM t WHERE id = (select 1);`
if let RelOperator::EvalScalar(eval) = subquery.subquery.plan() {
if eval.items.len() != 1 {
return Ok(None);
}
let Ok(const_scalar) = ConstantExpr::try_from(eval.items[0].scalar.clone()) else {
return Ok(None);
};
match (&subquery.child_expr, subquery.compare_op) {
(Some(child_expr), Some(compare_op)) => {
let func_name = compare_op.to_func_name().to_string();
let func = ScalarExpr::FunctionCall(FunctionCall {
span: subquery.span,
func_name,
params: vec![],
arguments: vec![*child_expr.clone(), const_scalar.into()],
});
return Ok(Some(func));
}
(None, None) => match subquery.typ {
SubqueryType::Scalar => {
return Ok(Some(const_scalar.into()));
}
SubqueryType::Exists => {
return Ok(Some(ScalarExpr::ConstantExpr(ConstantExpr {
span: subquery.span,
value: Scalar::Boolean(true),
})));
}
SubqueryType::NotExists => {
return Ok(Some(ScalarExpr::ConstantExpr(ConstantExpr {
span: subquery.span,
value: Scalar::Boolean(false),
})));
}
_ => {}
},
(_, _) => {}
}
}
} else {
// subquery is a set returning function return constant values.
// for example: `SELECT * FROM t WHERE id IN (SELECT * FROM UNNEST(SPLIT('1,2,3', ',')) AS t1);`
let mut output_column_index = None;
if let RelOperator::EvalScalar(eval) = subquery.subquery.plan() {
if eval.items.len() != 1 {
return Ok(None);
}
if let ScalarExpr::BoundColumnRef(bound_column) = &eval.items[0].scalar {
output_column_index = Some(bound_column.column.index);
}
}
if output_column_index.is_none() {
return Ok(None);
}
let output_column_index = output_column_index.unwrap();

let mut srf_column_index = None;
if let RelOperator::EvalScalar(eval) = child.plan() {
if eval.items.len() != 1 || eval.items[0].index != output_column_index {
return Ok(None);
}
if let ScalarExpr::FunctionCall(get_func) = &eval.items[0].scalar {
if get_func.func_name == "get"
&& get_func.arguments.len() == 1
&& get_func.params.len() == 1
&& get_func.params[0] == Scalar::Number(NumberScalar::Int64(1))
{
if let ScalarExpr::BoundColumnRef(bound_column) = &get_func.arguments[0] {
srf_column_index = Some(bound_column.column.index);
}
}
}
}
if srf_column_index.is_none() {
return Ok(None);
}
let srf_column_index = srf_column_index.unwrap();

let project_set_expr = child.child(0)?;
if let RelOperator::ProjectSet(project_set) = project_set_expr.plan() {
if project_set.srfs.len() != 1
|| project_set.srfs[0].index != srf_column_index
|| subquery.compare_op != Some(ComparisonOp::Equal)
|| subquery.typ != SubqueryType::Any
{
return Ok(None);
}
let Ok(srf) = FunctionCall::try_from(project_set.srfs[0].scalar.clone()) else {
return Ok(None);
};
if srf.arguments.len() != 1 {
return Ok(None);
}
let Ok(const_scalar) = ConstantExpr::try_from(srf.arguments[0].clone()) else {
return Ok(None);
};
let Some(child_expr) = &subquery.child_expr else {
return Ok(None);
};
match &const_scalar.value {
Scalar::EmptyArray => {
return Ok(Some(ScalarExpr::ConstantExpr(ConstantExpr {
span: subquery.span,
value: Scalar::Null,
})));
}
Scalar::Array(array_column) => {
let mut values = BTreeSet::new();
for scalar in array_column.iter() {
// Ignoring NULL values in equivalent filter
if scalar == ScalarRef::Null {
continue;
}
values.insert(scalar.to_owned());
}
// If there are no equivalent values, the filter condition does not match,
// return a NULL value.
if values.is_empty() {
return Ok(Some(ScalarExpr::ConstantExpr(ConstantExpr {
span: subquery.span,
value: Scalar::Null,
})));
}
// If the number of values more than `inlist_to_join_threshold`, need convert to join.
if values.len() >= self.ctx.get_settings().get_inlist_to_join_threshold()? {
return Ok(None);
}
// If the number of values more than `max_inlist_to_or`, use contains function instead of or.
if values.len() > self.ctx.get_settings().get_max_inlist_to_or()? as usize {
let value_type = values.first().unwrap().as_ref().infer_data_type();
let mut builder =
ColumnBuilder::with_capacity(&value_type, values.len());
for value in values.into_iter() {
builder.push(value.as_ref());
}
let array_value = ScalarExpr::ConstantExpr(ConstantExpr {
span: subquery.span,
value: Scalar::Array(builder.build()),
});

let expr_type = child_expr.data_type()?;
let common_type = common_super_type(
value_type.clone(),
expr_type.clone(),
&BUILTIN_FUNCTIONS.default_cast_rules,
)
.ok_or_else(|| {
ErrorCode::IllegalDataType(format!(
"Cannot find common type for inlist subquery value {:?} and expr {:?}",
&array_value, &child_expr
))
})?;

let mut arguments = Vec::with_capacity(2);
if value_type != common_type {
arguments.push(ScalarExpr::CastExpr(CastExpr {
span: subquery.span,
is_try: false,
argument: Box::new(array_value),
target_type: Box::new(DataType::Array(Box::new(
common_type.clone(),
))),
}));
} else {
arguments.push(array_value);
}
if expr_type != common_type {
arguments.push(ScalarExpr::CastExpr(CastExpr {
span: subquery.span,
is_try: false,
argument: Box::new(*child_expr.clone()),
target_type: Box::new(common_type.clone()),
}));
} else {
arguments.push(*child_expr.clone());
}
let func = ScalarExpr::FunctionCall(FunctionCall {
span: subquery.span,
func_name: "contains".to_string(),
params: vec![],
arguments,
});
return Ok(Some(func));
}

let mut funcs = Vec::with_capacity(values.len());
for value in values.into_iter() {
let scalar_value = ScalarExpr::ConstantExpr(ConstantExpr {
span: subquery.span,
value,
});
let func = ScalarExpr::FunctionCall(FunctionCall {
span: subquery.span,
func_name: "eq".to_string(),
params: vec![],
arguments: vec![*child_expr.clone(), scalar_value],
});
funcs.push(func);
}
let or_func = funcs
.into_iter()
.fold(None, |mut acc, func| {
match acc.as_mut() {
None => acc = Some(func),
Some(acc) => {
*acc = ScalarExpr::FunctionCall(FunctionCall {
span: subquery.span,
func_name: "or".to_string(),
params: vec![],
arguments: vec![acc.clone(), func],
});
}
}
acc
})
.unwrap();
return Ok(Some(or_func));
}
_ => {}
}
}
}

Ok(None)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::vec;

use databend_common_catalog::table_context::TableContext;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::types::DataType;
Expand Down Expand Up @@ -69,14 +70,16 @@ pub struct FlattenInfo {

/// Rewrite subquery into `Apply` operator
pub struct SubqueryRewriter {
pub(crate) ctx: Arc<dyn TableContext>,
pub(crate) metadata: MetadataRef,
pub(crate) derived_columns: HashMap<IndexType, IndexType>,
pub(crate) binder: Option<Binder>,
}

impl SubqueryRewriter {
pub fn new(metadata: MetadataRef, binder: Option<Binder>) -> Self {
pub fn new(ctx: Arc<dyn TableContext>, metadata: MetadataRef, binder: Option<Binder>) -> Self {
Self {
ctx,
metadata,
derived_columns: Default::default(),
binder,
Expand Down Expand Up @@ -254,6 +257,10 @@ impl SubqueryRewriter {
let mut subquery = subquery.clone();
subquery.subquery = Box::new(self.rewrite(&subquery.subquery)?);

if let Some(constant_subquery) = self.try_fold_constant_subquery(&subquery)? {
return Ok((constant_subquery, s_expr.clone()));
}

// Check if the subquery is a correlated subquery.
// If it is, we'll try to flatten it and rewrite to join.
// If it is not, we'll just rewrite it to join
Expand Down
Loading

0 comments on commit 61cdf5a

Please sign in to comment.