diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs
index d4d50ac4eae4..3e3aba46e17f 100644
--- a/datafusion/expr/src/logical_plan/dml.rs
+++ b/datafusion/expr/src/logical_plan/dml.rs
@@ -241,3 +241,15 @@ fn make_count_schema() -> DFSchemaRef {
         .unwrap(),
     )
 }
+
+/// Plan node payload for a `MERGE` statement (presumably; nothing in this patch
+/// constructs it yet — confirm intended use).
+///
+/// NOTE(review): the type arguments of the `Arc` fields below are missing from the
+/// patch as received (`Arc` alone does not compile); restore them before merging.
+#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)]
+pub struct Merge {
+    pub input: Arc,
+    pub join: Arc,
+    pub case: Arc,
+}
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index 1f1c235fee6f..052ce0f0c0c8 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -39,25 +39,25 @@ use datafusion_common::{
     ToDFSchema,
 };
 use datafusion_expr::dml::{CopyTo, InsertOp};
 use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check;
 use datafusion_expr::logical_plan::builder::project;
 use datafusion_expr::logical_plan::DdlStatement;
 use datafusion_expr::utils::expr_to_columns;
 use datafusion_expr::{
-    cast, col, Analyze, CreateCatalog, CreateCatalogSchema,
+    cast, col, lit, Analyze, Case, CreateCatalog, CreateCatalogSchema,
     CreateExternalTable as PlanCreateExternalTable, CreateFunction, CreateFunctionBody,
     CreateIndex as PlanCreateIndex, CreateMemoryTable, CreateView, Deallocate,
     DescribeTable, DmlStatement, DropCatalogSchema, DropFunction, DropTable, DropView,
-    EmptyRelation, Execute, Explain, ExplainFormat, Expr, ExprSchemable, Filter,
+    EmptyRelation, Execute, Explain, ExplainFormat, Expr, ExprSchemable, Filter, JoinType,
     LogicalPlan, LogicalPlanBuilder, OperateFunctionArg, PlanType, Prepare, SetVariable,
     SortExpr, Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode,
     TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart,
     Volatility, WriteOp,
 };
 use sqlparser::ast::{
-    self, BeginTransactionKind, NullsDistinctOption, ShowStatementIn,
-    ShowStatementOptions, SqliteOnConflict, TableObject, UpdateTableFromKind,
-    ValueWithSpan,
+    self, BeginTransactionKind, MergeAction, MergeClause, MergeClauseKind,
+    MergeInsertKind, NullsDistinctOption, ShowStatementIn, ShowStatementOptions,
+    SqliteOnConflict, TableObject, UpdateTableFromKind, ValueWithSpan, Values,
 };
 use sqlparser::ast::{
     Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable,
@@ -999,7 +999,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 let table_name = self.get_delete_target(from)?;
                 self.delete_to_plan(table_name, selection)
             }
-
+            Statement::Merge {
+                into: _,
+                table,
+                source,
+                on,
+                clauses,
+            } => self.merge_to_plan(table, source, on, clauses),
             Statement::StartTransaction {
                 modes,
                 begin: false,
@@ -1996,7 +2002,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
 
         // infer types for Values clause... other types should be resolvable the regular way
         let mut prepare_param_data_types = BTreeMap::new();
-        if let SetExpr::Values(ast::Values { rows, .. }) = (*source.body).clone() {
+        if let SetExpr::Values(Values { rows, .. }) = (*source.body).clone() {
             for row in rows.iter() {
                 for (idx, val) in row.iter().enumerate() {
                     if let SQLExpr::Value(ValueWithSpan {
@@ -2074,6 +2080,183 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         Ok(plan)
     }
 
+    /// Plans a `MERGE` statement as a full outer join of the target and source tables.
+    ///
+    /// Each side is scanned with an extra constant-`true` marker column
+    /// (`target_exists` / `source_exists`). After the full join on `on`, a marker is
+    /// NULL exactly when that side had no matching row, which is how the
+    /// `WHEN [NOT] MATCHED` clause kinds are told apart.
+    ///
+    /// `UPDATE` and `INSERT` actions each contribute `WHEN ... THEN ...` branches to a
+    /// single `CASE` projection; `DELETE` actions are collected into a filter that
+    /// removes the rows they select.
+    ///
+    /// # Errors
+    ///
+    /// Returns a plan error if either side is not a plain table reference, if an
+    /// `INSERT ... VALUES` action has no row, or if any expression fails to plan.
+    fn merge_to_plan(
+        &self,
+        target_table: TableFactor,
+        source_table: TableFactor,
+        on: Box<SQLExpr>,
+        clauses: Vec<MergeClause>,
+    ) -> Result<LogicalPlan> {
+        // A single context is shared by every expression of the statement.
+        let mut planner_context = PlannerContext::new();
+
+        let target_scan = self.merge_side_scan(
+            target_table,
+            "target_exists",
+            "Target table can only be a table for MERGE.",
+        )?;
+        let source_scan = self.merge_side_scan(
+            source_table,
+            "source_exists",
+            "Source table can currently only be a table for MERGE.",
+        )?;
+
+        // The ON condition may reference columns from either side.
+        let on_schema = target_scan.schema().join(source_scan.schema())?;
+        let on_expr = self.sql_to_expr(*on, &on_schema, &mut planner_context)?;
+
+        // Let the builder derive the join schema: a full join makes both sides
+        // nullable, which the NULL checks on the marker columns below rely on.
+        let join_plan = LogicalPlanBuilder::from(target_scan)
+            .join_on(source_scan, JoinType::Full, [on_expr])?
+            .build()?;
+        let joined_schema = Arc::clone(join_plan.schema());
+
+        // Marker checks identifying which side(s) of the join produced the row.
+        let both_sides = col("target_exists")
+            .is_not_null()
+            .and(col("source_exists").is_not_null());
+        let only_source = col("target_exists")
+            .is_null()
+            .and(col("source_exists").is_not_null());
+        let only_target = col("target_exists")
+            .is_not_null()
+            .and(col("source_exists").is_null());
+
+        let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = Vec::new();
+        let mut delete_conditions: Vec<Expr> = Vec::new();
+
+        for clause in clauses {
+            let side_check = match clause.clause_kind {
+                MergeClauseKind::Matched => both_sides.clone(),
+                MergeClauseKind::NotMatchedByTarget | MergeClauseKind::NotMatched => {
+                    only_source.clone()
+                }
+                MergeClauseKind::NotMatchedBySource => only_target.clone(),
+            };
+
+            // Combine the clause's optional predicate with the side check.
+            let when_expr = match clause.predicate {
+                Some(predicate) => side_check.and(self.sql_to_expr(
+                    predicate,
+                    &joined_schema,
+                    &mut planner_context,
+                )?),
+                None => side_check,
+            };
+
+            match clause.action {
+                MergeAction::Update { assignments } => {
+                    // Each assignment (col = expr) becomes its own `when -> then`.
+                    for assignment in assignments {
+                        let value = self.sql_to_expr(
+                            assignment.value,
+                            &joined_schema,
+                            &mut planner_context,
+                        )?;
+                        when_then.push((Box::new(when_expr.clone()), Box::new(value)));
+                    }
+                }
+                MergeAction::Insert(insert) => match insert.kind {
+                    MergeInsertKind::Values(Values { rows, .. }) => {
+                        // Only the first VALUES row is planned.
+                        let first_row = match rows.into_iter().next() {
+                            Some(row) => row,
+                            None => {
+                                return plan_err!(
+                                    "MERGE INSERT requires at least one row of values"
+                                )
+                            }
+                        };
+                        for (ident, value) in insert.columns.into_iter().zip(first_row) {
+                            let value = self.sql_to_expr(
+                                value,
+                                &joined_schema,
+                                &mut planner_context,
+                            )?;
+                            when_then.push((
+                                Box::new(when_expr.clone()),
+                                Box::new(value.alias(normalize_ident(ident))),
+                            ));
+                        }
+                    }
+                    MergeInsertKind::Row => {
+                        for ident in insert.columns {
+                            let name = normalize_ident(ident);
+                            let src_col =
+                                Expr::Column(datafusion_common::Column::from_name(name));
+                            when_then
+                                .push((Box::new(when_expr.clone()), Box::new(src_col)));
+                        }
+                    }
+                },
+                MergeAction::Delete => delete_conditions.push(when_expr),
+            }
+        }
+
+        let mut builder = LogicalPlanBuilder::from(join_plan);
+
+        // Drop rows selected by any DELETE clause. `IS NOT TRUE` (rather than `NOT`)
+        // keeps rows whose delete condition evaluates to NULL.
+        if let Some(delete_pred) = delete_conditions.into_iter().reduce(Expr::or) {
+            builder = builder.filter(delete_pred.is_not_true())?;
+        }
+
+        // A CASE needs at least one WHEN branch, so only project when there is one.
+        // Without an ELSE branch, rows matching no clause evaluate to NULL.
+        if !when_then.is_empty() {
+            builder = builder.project(vec![Expr::Case(Case {
+                expr: None,
+                when_then_expr: when_then,
+                else_expr: None,
+            })])?;
+        }
+
+        builder.build()
+    }
+
+    /// Scans one side of a `MERGE` and appends a constant-`true` marker column.
+    ///
+    /// `marker` names the appended column; `unsupported_msg` is the plan error
+    /// returned when `table` is anything other than a plain table reference.
+    /// Only the table name is used; an alias or other table options are ignored.
+    fn merge_side_scan(
+        &self,
+        table: TableFactor,
+        marker: &str,
+        unsupported_msg: &str,
+    ) -> Result<LogicalPlan> {
+        let name = match table {
+            TableFactor::Table { name, .. } => name,
+            _ => return plan_err!("{}", unsupported_msg),
+        };
+
+        let table_ref = self.object_name_to_table_reference(name)?;
+        let table_source = self.context_provider.get_table_source(table_ref.clone())?;
+        let scan = LogicalPlanBuilder::scan(table_ref, table_source, None)?;
+
+        // Keep every scanned column and add the marker used to detect join misses.
+        let mut exprs: Vec<Expr> =
+            scan.schema().columns().into_iter().map(Expr::Column).collect();
+        exprs.push(lit(true).alias(marker));
+        scan.project(exprs)?.build()
+    }
+
     fn show_columns_to_plan(
         &self,
         extended: bool,