diff --git a/kernel/src/engine/arrow_expression/evaluate_expression.rs b/kernel/src/engine/arrow_expression/evaluate_expression.rs index 2902bea79c..7bbb0f2c0e 100644 --- a/kernel/src/engine/arrow_expression/evaluate_expression.rs +++ b/kernel/src/engine/arrow_expression/evaluate_expression.rs @@ -1185,13 +1185,10 @@ mod tests { // Create coalesce expression with column that has no nulls, followed by // a reference to a non-existent column. If short-circuit works, the // non-existent column is never evaluated and no error occurs. - let expr = Expression::variadic( - VariadicExpressionOp::Coalesce, - vec![ - Expression::column(["a"]), - Expression::column(["nonexistent"]), // Would fail if evaluated - ], - ); + let expr = Expression::coalesce([ + Expression::column(["a"]), + Expression::column(["nonexistent"]), // Would fail if evaluated + ]); // Should return column "a" directly (short-circuit skips evaluating "nonexistent") let result = evaluate_expression(&expr, &batch, Some(&DataType::INTEGER)).unwrap(); @@ -1216,14 +1213,11 @@ mod tests { // Create coalesce expression: a has nulls, b has none, c doesn't exist. // Short-circuit should stop after evaluating b. - let expr = Expression::variadic( - VariadicExpressionOp::Coalesce, - vec![ - Expression::column(["a"]), - Expression::column(["b"]), - Expression::column(["nonexistent"]), // Would fail if evaluated - ], - ); + let expr = Expression::coalesce([ + Expression::column(["a"]), + Expression::column(["b"]), + Expression::column(["nonexistent"]), // Would fail if evaluated + ]); // Should coalesce a and b, never evaluate "nonexistent" let result = evaluate_expression(&expr, &batch, Some(&DataType::INTEGER)).unwrap(); @@ -1242,10 +1236,7 @@ mod tests { let a_values = Int32Array::from(vec![1, 2, 3]); // No nulls - would short-circuit let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a_values)]).unwrap(); - let expr = Expression::variadic( - VariadicExpressionOp::Coalesce, - vec![Expression::column(["a"])], - ); + let expr = Expression::coalesce([Expression::column(["a"])]); // Request STRING type but array is INT32 - should fail even with short-circuit let result = evaluate_expression(&expr, &batch, Some(&DataType::STRING)); diff --git a/kernel/src/expressions/mod.rs b/kernel/src/expressions/mod.rs index f0ffa7c755..e8e486f90e 100644 --- a/kernel/src/expressions/mod.rs +++ b/kernel/src/expressions/mod.rs @@ -701,6 +701,14 @@ impl Expression { Self::Variadic(VariadicExpression::new(op, exprs)) } + /// Creates a new COALESCE expression that returns the first non-null value. + /// + /// COALESCE evaluates expressions in order and returns the first non-null result. + /// If all expressions evaluate to null, the result is null. + pub fn coalesce(exprs: impl IntoIterator>) -> Self { + Self::variadic(VariadicExpressionOp::Coalesce, exprs) + } + /// Creates a new opaque expression pub fn opaque( op: impl OpaqueExpressionOp, @@ -1158,7 +1166,7 @@ mod tests { use crate::expressions::scalars::{ArrayData, DecimalData, MapData, StructData}; use crate::expressions::{ column_expr, column_name, BinaryExpressionOp, BinaryPredicateOp, ColumnName, - Expression, Predicate, Scalar, Transform, UnaryExpressionOp, VariadicExpressionOp, + Expression, Predicate, Scalar, Transform, UnaryExpressionOp, }; use crate::schema::{ArrayType, DataType, DecimalType, MapType, StructField}; use crate::utils::test_utils::assert_result_error_with_message; @@ -1297,14 +1305,11 @@ mod tests { #[test] fn test_variadic_expression_roundtrip() { - let expr = Expression::variadic( - VariadicExpressionOp::Coalesce, - [ - column_expr!("a"), - column_expr!("b"), - Expression::literal("default"), - ], - ); + let expr = Expression::coalesce([ + column_expr!("a"), + column_expr!("b"), + Expression::literal("default"), + ]); assert_roundtrip(&expr); } @@ -1505,10 +1510,7 @@ mod tests { column_expr!("c"), column_expr!("d"), ); - let coalesce = Expression::variadic( - VariadicExpressionOp::Coalesce, - [add, mul, Expression::literal(0)], - ); + let coalesce = Expression::coalesce([add, mul, Expression::literal(0)]); let pred = Predicate::gt(coalesce, Expression::literal(100)); assert_roundtrip(&pred); diff --git a/kernel/src/scan/log_replay.rs b/kernel/src/scan/log_replay.rs index d0b864c9bd..23ffb1d7f0 100644 --- a/kernel/src/scan/log_replay.rs +++ b/kernel/src/scan/log_replay.rs @@ -560,7 +560,7 @@ mod tests { use crate::actions::get_commit_schema; use crate::engine::sync::SyncEngine; - use crate::expressions::{BinaryExpressionOp, Scalar, VariadicExpressionOp}; + use crate::expressions::{BinaryExpressionOp, Scalar}; use crate::log_replay::ActionsBatch; use crate::scan::state::ScanFile; use crate::scan::state_info::tests::{ @@ -761,17 +761,14 @@ mod tests { assert!(row_id_transform.is_replace); assert_eq!(row_id_transform.exprs.len(), 1); let expr = &row_id_transform.exprs[0]; - let expeceted_expr = Arc::new(Expr::variadic( - VariadicExpressionOp::Coalesce, - vec![ - Expr::column(["row_id_col"]), - Expr::binary( - BinaryExpressionOp::Plus, - Expr::literal(42i64), - Expr::column(["row_indexes_for_row_id_0"]), - ), - ], - )); + let expeceted_expr = Arc::new(Expr::coalesce([ + Expr::column(["row_id_col"]), + Expr::binary( + BinaryExpressionOp::Plus, + Expr::literal(42i64), + Expr::column(["row_indexes_for_row_id_0"]), + ), + ])); assert_eq!(expr, &expeceted_expr); } else { panic!("Should have been a transform expression"); diff --git a/kernel/src/transforms.rs b/kernel/src/transforms.rs index 6be7af871e..564b1a47cb 100644 --- a/kernel/src/transforms.rs +++ b/kernel/src/transforms.rs @@ -10,9 +10,7 @@ use std::sync::Arc; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::expressions::{ - BinaryExpressionOp, Expression, ExpressionRef, Scalar, Transform, VariadicExpressionOp, -}; +use crate::expressions::{BinaryExpressionOp, Expression, ExpressionRef, Scalar, Transform}; use crate::schema::{DataType, SchemaRef, StructType}; use crate::table_features::ColumnMappingMode; use crate::{DeltaResult, Error}; @@ -137,17 +135,14 @@ pub(crate) fn get_transform_expr( let base_row_id = base_row_id.ok_or_else(|| { Error::generic("Asked to generate RowIds, but no baseRowId found.") })?; - let expr = Arc::new(Expression::variadic( - VariadicExpressionOp::Coalesce, - vec![ - Expression::column([field_name]), - Expression::binary( - BinaryExpressionOp::Plus, - Expression::literal(base_row_id), - Expression::column([row_index_field_name]), - ), - ], - )); + let expr = Arc::new(Expression::coalesce([ + Expression::column([field_name]), + Expression::binary( + BinaryExpressionOp::Plus, + Expression::literal(base_row_id), + Expression::column([row_index_field_name]), + ), + ])); transform.with_replaced_field(field_name.clone(), expr) } MetadataDerivedColumn { @@ -592,17 +587,14 @@ mod tests { .expect("Should have row_id_col transform"); assert!(row_id_transform.is_replace); - let expeceted_expr = Arc::new(Expression::variadic( - VariadicExpressionOp::Coalesce, - vec![ - Expression::column(["row_id_col"]), - Expression::binary( - BinaryExpressionOp::Plus, - Expression::literal(4i64), - Expression::column(["row_index_col"]), - ), - ], - )); + let expeceted_expr = Arc::new(Expression::coalesce([ + Expression::column(["row_id_col"]), + Expression::binary( + BinaryExpressionOp::Plus, + Expression::literal(4i64), + Expression::column(["row_index_col"]), + ), + ])); assert_eq!(row_id_transform.exprs.len(), 1); let expr = &row_id_transform.exprs[0]; assert_eq!(expr, &expeceted_expr);