-
Notifications
You must be signed in to change notification settings - Fork 1.9k
fix: increase ROUND decimal precision to prevent overflow truncation #19926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,17 +27,19 @@ use arrow::datatypes::{ | |
| ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, | ||
| Decimal256Type, Float32Type, Float64Type, Int32Type, | ||
| }; | ||
| use arrow::datatypes::{Field, FieldRef}; | ||
| use arrow::error::ArrowError; | ||
| use datafusion_common::types::{ | ||
| NativeType, logical_float32, logical_float64, logical_int32, | ||
| }; | ||
| use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; | ||
| use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; | ||
| use datafusion_expr::{ | ||
| Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, | ||
| TypeSignature, TypeSignatureClass, Volatility, | ||
| Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, | ||
| ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, | ||
| }; | ||
| use datafusion_macros::user_doc; | ||
| use std::sync::Arc; | ||
|
|
||
| #[user_doc( | ||
| doc_section(label = "Math Functions"), | ||
|
|
@@ -117,15 +119,74 @@ impl ScalarUDFImpl for RoundFunc { | |
| &self.signature | ||
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| Ok(match arg_types[0].clone() { | ||
| fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { | ||
| let input_field = &args.arg_fields[0]; | ||
| let input_type = input_field.data_type(); | ||
|
|
||
| // Get decimal_places from scalar_arguments | ||
| // If dp is not a constant scalar, we must keep the original scale because | ||
| // we can't determine a single output scale for varying per-row dp values. | ||
| let (decimal_places, dp_is_scalar): (i32, bool) = | ||
| if args.scalar_arguments.len() > 1 { | ||
| match args.scalar_arguments[1] { | ||
| Some(ScalarValue::Int32(Some(v))) => (*v, true), | ||
| Some(ScalarValue::Int64(Some(v))) => (*v as i32, true), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Better to use a checked cast and return an Err if the number is larger than i32::MAX (or even i8::MAX). |
||
| _ => (0, false), // dp is a column or null - can't determine scale | ||
| } | ||
| } else { | ||
| (0, true) // No dp argument means default to 0 | ||
| }; | ||
|
|
||
| // Calculate return type based on input type | ||
| // For decimals: reduce scale to decimal_places (reclaims precision for integer part) | ||
| // This matches Spark/DuckDB behavior where ROUND adjusts the scale | ||
| // BUT only if dp is a constant - otherwise keep original scale | ||
| let return_type = match input_type { | ||
| Float32 => Float32, | ||
| dt @ Decimal128(_, _) | ||
| | dt @ Decimal256(_, _) | ||
| | dt @ Decimal32(_, _) | ||
| | dt @ Decimal64(_, _) => dt, | ||
| Decimal32(precision, scale) => { | ||
| if dp_is_scalar { | ||
| let new_scale = (*scale).min(decimal_places.max(0) as i8); | ||
| Decimal32(*precision, new_scale) | ||
| } else { | ||
| Decimal32(*precision, *scale) | ||
| } | ||
| } | ||
| Decimal64(precision, scale) => { | ||
| if dp_is_scalar { | ||
| let new_scale = (*scale).min(decimal_places.max(0) as i8); | ||
| Decimal64(*precision, new_scale) | ||
| } else { | ||
| Decimal64(*precision, *scale) | ||
| } | ||
| } | ||
| Decimal128(precision, scale) => { | ||
| if dp_is_scalar { | ||
| let new_scale = (*scale).min(decimal_places.max(0) as i8); | ||
| Decimal128(*precision, new_scale) | ||
| } else { | ||
| Decimal128(*precision, *scale) | ||
| } | ||
| } | ||
| Decimal256(precision, scale) => { | ||
| if dp_is_scalar { | ||
| let new_scale = (*scale).min(decimal_places.max(0) as i8); | ||
| Decimal256(*precision, new_scale) | ||
| } else { | ||
| Decimal256(*precision, *scale) | ||
| } | ||
| } | ||
| _ => Float64, | ||
| }) | ||
| }; | ||
|
|
||
| Ok(Arc::new(Field::new( | ||
| self.name(), | ||
| return_type, | ||
| input_field.is_nullable(), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should also take into account the Postgres: Apache Spark: DuckDB: |
||
| ))) | ||
| } | ||
|
|
||
| fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| internal_err!("use return_field_from_args instead") | ||
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
|
|
@@ -141,7 +202,6 @@ impl ScalarUDFImpl for RoundFunc { | |
| &default_decimal_places | ||
| }; | ||
|
|
||
| // Scalar fast path for float and decimal types - avoid array conversion overhead | ||
| if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) = | ||
| (&args.args[0], decimal_places) | ||
| { | ||
|
|
@@ -169,27 +229,32 @@ impl ScalarUDFImpl for RoundFunc { | |
| Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) | ||
| } | ||
| ScalarValue::Decimal128(Some(v), precision, scale) => { | ||
| let rounded = round_decimal(*v, *scale, dp)?; | ||
| // Reduce scale to reclaim integer precision | ||
| let new_scale = (*scale).min(dp.max(0) as i8); | ||
| let rounded = round_decimal(*v, *scale, new_scale, dp)?; | ||
| let scalar = | ||
| ScalarValue::Decimal128(Some(rounded), *precision, *scale); | ||
| ScalarValue::Decimal128(Some(rounded), *precision, new_scale); | ||
| Ok(ColumnarValue::Scalar(scalar)) | ||
| } | ||
| ScalarValue::Decimal256(Some(v), precision, scale) => { | ||
| let rounded = round_decimal(*v, *scale, dp)?; | ||
| let new_scale = (*scale).min(dp.max(0) as i8); | ||
| let rounded = round_decimal(*v, *scale, new_scale, dp)?; | ||
| let scalar = | ||
| ScalarValue::Decimal256(Some(rounded), *precision, *scale); | ||
| ScalarValue::Decimal256(Some(rounded), *precision, new_scale); | ||
| Ok(ColumnarValue::Scalar(scalar)) | ||
| } | ||
| ScalarValue::Decimal64(Some(v), precision, scale) => { | ||
| let rounded = round_decimal(*v, *scale, dp)?; | ||
| let new_scale = (*scale).min(dp.max(0) as i8); | ||
| let rounded = round_decimal(*v, *scale, new_scale, dp)?; | ||
| let scalar = | ||
| ScalarValue::Decimal64(Some(rounded), *precision, *scale); | ||
| ScalarValue::Decimal64(Some(rounded), *precision, new_scale); | ||
| Ok(ColumnarValue::Scalar(scalar)) | ||
| } | ||
| ScalarValue::Decimal32(Some(v), precision, scale) => { | ||
| let rounded = round_decimal(*v, *scale, dp)?; | ||
| let new_scale = (*scale).min(dp.max(0) as i8); | ||
| let rounded = round_decimal(*v, *scale, new_scale, dp)?; | ||
| let scalar = | ||
| ScalarValue::Decimal32(Some(rounded), *precision, *scale); | ||
| ScalarValue::Decimal32(Some(rounded), *precision, new_scale); | ||
| Ok(ColumnarValue::Scalar(scalar)) | ||
| } | ||
| _ => { | ||
|
|
@@ -200,7 +265,12 @@ impl ScalarUDFImpl for RoundFunc { | |
| } | ||
| } | ||
| } else { | ||
| round_columnar(&args.args[0], decimal_places, args.number_rows) | ||
| round_columnar( | ||
| &args.args[0], | ||
| decimal_places, | ||
| args.number_rows, | ||
| args.return_type(), | ||
| ) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -228,29 +298,31 @@ fn round_columnar( | |
| value: &ColumnarValue, | ||
| decimal_places: &ColumnarValue, | ||
| number_rows: usize, | ||
| return_type: &DataType, | ||
| ) -> Result<ColumnarValue> { | ||
| let value_array = value.to_array(number_rows)?; | ||
| let both_scalars = matches!(value, ColumnarValue::Scalar(_)) | ||
| && matches!(decimal_places, ColumnarValue::Scalar(_)); | ||
|
|
||
| let arr: ArrayRef = match value_array.data_type() { | ||
| Float64 => { | ||
| let arr: ArrayRef = match (value_array.data_type(), return_type) { | ||
| (Float64, _) => { | ||
| let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>( | ||
| value_array.as_ref(), | ||
| decimal_places, | ||
| round_float::<f64>, | ||
| )?; | ||
| result as _ | ||
| } | ||
| Float32 => { | ||
| (Float32, _) => { | ||
| let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>( | ||
| value_array.as_ref(), | ||
| decimal_places, | ||
| round_float::<f32>, | ||
| )?; | ||
| result as _ | ||
| } | ||
| Decimal32(precision, scale) => { | ||
| (Decimal32(_, scale), Decimal32(precision, new_scale)) => { | ||
| // reduce scale to reclaim integer precision | ||
| let result = calculate_binary_decimal_math::< | ||
| Decimal32Type, | ||
| Int32Type, | ||
|
|
@@ -259,13 +331,13 @@ fn round_columnar( | |
| >( | ||
| value_array.as_ref(), | ||
| decimal_places, | ||
| |v, dp| round_decimal(v, *scale, dp), | ||
| |v, dp| round_decimal(v, *scale, *new_scale, dp), | ||
| *precision, | ||
| *scale, | ||
| *new_scale, | ||
| )?; | ||
| result as _ | ||
| } | ||
| Decimal64(precision, scale) => { | ||
| (Decimal64(_, scale), Decimal64(precision, new_scale)) => { | ||
| let result = calculate_binary_decimal_math::< | ||
| Decimal64Type, | ||
| Int32Type, | ||
|
|
@@ -274,13 +346,13 @@ fn round_columnar( | |
| >( | ||
| value_array.as_ref(), | ||
| decimal_places, | ||
| |v, dp| round_decimal(v, *scale, dp), | ||
| |v, dp| round_decimal(v, *scale, *new_scale, dp), | ||
| *precision, | ||
| *scale, | ||
| *new_scale, | ||
| )?; | ||
| result as _ | ||
| } | ||
| Decimal128(precision, scale) => { | ||
| (Decimal128(_, scale), Decimal128(precision, new_scale)) => { | ||
| let result = calculate_binary_decimal_math::< | ||
| Decimal128Type, | ||
| Int32Type, | ||
|
|
@@ -289,13 +361,13 @@ fn round_columnar( | |
| >( | ||
| value_array.as_ref(), | ||
| decimal_places, | ||
| |v, dp| round_decimal(v, *scale, dp), | ||
| |v, dp| round_decimal(v, *scale, *new_scale, dp), | ||
| *precision, | ||
| *scale, | ||
| *new_scale, | ||
| )?; | ||
| result as _ | ||
| } | ||
| Decimal256(precision, scale) => { | ||
| (Decimal256(_, scale), Decimal256(precision, new_scale)) => { | ||
| let result = calculate_binary_decimal_math::< | ||
| Decimal256Type, | ||
| Int32Type, | ||
|
|
@@ -304,13 +376,13 @@ fn round_columnar( | |
| >( | ||
| value_array.as_ref(), | ||
| decimal_places, | ||
| |v, dp| round_decimal(v, *scale, dp), | ||
| |v, dp| round_decimal(v, *scale, *new_scale, dp), | ||
| *precision, | ||
| *scale, | ||
| *new_scale, | ||
| )?; | ||
| result as _ | ||
| } | ||
| other => exec_err!("Unsupported data type {other:?} for function round")?, | ||
| (other, _) => exec_err!("Unsupported data type {other:?} for function round")?, | ||
| }; | ||
|
|
||
| if both_scalars { | ||
|
|
@@ -334,10 +406,11 @@ where | |
|
|
||
| fn round_decimal<V: ArrowNativeTypeOp>( | ||
| value: V, | ||
| scale: i8, | ||
| input_scale: i8, | ||
| output_scale: i8, | ||
| decimal_places: i32, | ||
| ) -> Result<V, ArrowError> { | ||
| let diff = i64::from(scale) - i64::from(decimal_places); | ||
| let diff = i64::from(input_scale) - i64::from(decimal_places); | ||
| if diff <= 0 { | ||
| return Ok(value); | ||
| } | ||
|
|
@@ -358,7 +431,7 @@ fn round_decimal<V: ArrowNativeTypeOp>( | |
|
|
||
| let factor = ten.pow_checked(diff).map_err(|_| { | ||
| ArrowError::ComputeError(format!( | ||
| "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}" | ||
| "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}" | ||
| )) | ||
| })?; | ||
|
|
||
|
|
@@ -377,9 +450,40 @@ fn round_decimal<V: ArrowNativeTypeOp>( | |
| })?; | ||
| } | ||
|
|
||
| quotient | ||
| .mul_checked(factor) | ||
| .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into())) | ||
| // Determine how to scale the result based on output_scale vs computed scale | ||
| // computed_scale = max(0, min(input_scale, decimal_places)) | ||
| let computed_scale = if decimal_places >= 0 { | ||
| (input_scale as i32).min(decimal_places).max(0) as i8 | ||
| } else { | ||
| 0 | ||
| }; | ||
|
|
||
| if output_scale == computed_scale { | ||
| // scale reduction, return quotient directly (or shifted for negative dp) | ||
| if decimal_places >= 0 { | ||
| Ok(quotient) | ||
| } else { | ||
| // For negative decimal_places, multiply by 10^(-decimal_places) to shift left | ||
| let neg_dp: u32 = (-decimal_places).try_into().map_err(|_| { | ||
| ArrowError::ComputeError(format!( | ||
| "Invalid negative decimal places: {decimal_places}" | ||
| )) | ||
| })?; | ||
| let shift_factor = ten.pow_checked(neg_dp).map_err(|_| { | ||
| ArrowError::ComputeError(format!( | ||
| "Overflow computing shift factor for decimal places {decimal_places}" | ||
| )) | ||
| })?; | ||
| quotient.mul_checked(shift_factor).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding decimal".into()) | ||
| }) | ||
| } | ||
| } else { | ||
| // Keep original scale behavior: multiply back by factor | ||
| quotient.mul_checked(factor).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding decimal".into()) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
|
|
@@ -397,12 +501,14 @@ mod test { | |
| decimal_places: Option<ArrayRef>, | ||
| ) -> Result<ArrayRef, DataFusionError> { | ||
| let number_rows = value.len(); | ||
| let return_type = value.data_type().clone(); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This could be wrong for decimals whose scale is reduced. |
||
| let value = ColumnarValue::Array(value); | ||
| let decimal_places = decimal_places | ||
| .map(ColumnarValue::Array) | ||
| .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0)))); | ||
|
|
||
| let result = super::round_columnar(&value, &decimal_places, number_rows)?; | ||
| let result = | ||
| super::round_columnar(&value, &decimal_places, number_rows, &return_type)?; | ||
| match result { | ||
| ColumnarValue::Array(array) => Ok(array), | ||
| ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here `decimal_places` is initialized as `i32`, but below it is cast to `i8` with `as i8`, which may lead to problems. Validation is needed to ensure it is not bigger than `i8::MAX`.