Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 148 additions & 42 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,19 @@ use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, Float32Type, Float64Type, Int32Type,
};
use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature, TypeSignatureClass, Volatility,
Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

#[user_doc(
doc_section(label = "Math Functions"),
Expand Down Expand Up @@ -117,15 +119,74 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(match arg_types[0].clone() {
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let input_field = &args.arg_fields[0];
let input_type = input_field.data_type();

// Get decimal_places from scalar_arguments
// If dp is not a constant scalar, we must keep the original scale because
// we can't determine a single output scale for varying per-row dp values.
let (decimal_places, dp_is_scalar): (i32, bool) =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here decimal_places is initialized as i32 but below it is casted to i8 with as i8 and this may lead to problems. A validation is needed that it is not bigger than i8::MAX

if args.scalar_arguments.len() > 1 {
match args.scalar_arguments[1] {
Some(ScalarValue::Int32(Some(v))) => (*v, true),
Some(ScalarValue::Int64(Some(v))) => (*v as i32, true),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better use a safe cast and return an Err if the number is bigger than i32::MAX (or even i8::MAX)

_ => (0, false), // dp is a column or null - can't determine scale
}
} else {
(0, true) // No dp argument means default to 0
};

// Calculate return type based on input type
// For decimals: reduce scale to decimal_places (reclaims precision for integer part)
// This matches Spark/DuckDB behavior where ROUND adjusts the scale
// BUT only if dp is a constant - otherwise keep original scale
let return_type = match input_type {
Float32 => Float32,
dt @ Decimal128(_, _)
| dt @ Decimal256(_, _)
| dt @ Decimal32(_, _)
| dt @ Decimal64(_, _) => dt,
Decimal32(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal32(*precision, new_scale)
} else {
Decimal32(*precision, *scale)
}
}
Decimal64(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal64(*precision, new_scale)
} else {
Decimal64(*precision, *scale)
}
}
Decimal128(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal128(*precision, new_scale)
} else {
Decimal128(*precision, *scale)
}
}
Decimal256(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal256(*precision, new_scale)
} else {
Decimal256(*precision, *scale)
}
}
_ => Float64,
})
};

Ok(Arc::new(Field::new(
self.name(),
return_type,
input_field.is_nullable(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also take into account the decimal_places arg.

Postgres:

postgres=# SELECT pg_typeof(round(999.9::DECIMAL(4,1))), round(999.9::DECIMAL(4,1), NULL);
 pg_typeof | round
-----------+-------
 numeric   |
(1 row)

Apache Spark:

spark-sql (default)> SELECT typeof(round(999.9::DECIMAL(4,1))), round(999.9::DECIMAL(4,1), NULL);
decimal(4,0)    NULL
Time taken: 0.055 seconds, Fetched 1 row(s)

DuckDB:

D SELECT typeof(round(999.9::DECIMAL(4,1))), round(999.9::DECIMAL(4,1), NULL);
┌────────────────────────────────────────────┬──────────────────────────────────────────┐
│ typeof(round(CAST(999.9 AS DECIMAL(4,1)))) │ round(CAST(999.9 AS DECIMAL(4,1)), NULL) │
│                  varchar                   │                  int32                   │
├────────────────────────────────────────────┼──────────────────────────────────────────┤
│ DECIMAL(4,0)                               │                   NULL                   │
└────────────────────────────────────────────┴──────────────────────────────────────────┘

)))
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("use return_field_from_args instead")
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand All @@ -141,7 +202,6 @@ impl ScalarUDFImpl for RoundFunc {
&default_decimal_places
};

// Scalar fast path for float and decimal types - avoid array conversion overhead
if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
(&args.args[0], decimal_places)
{
Expand Down Expand Up @@ -169,27 +229,32 @@ impl ScalarUDFImpl for RoundFunc {
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
// Reduce scale to reclaim integer precision
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal128(Some(rounded), *precision, *scale);
ScalarValue::Decimal128(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal256(Some(rounded), *precision, *scale);
ScalarValue::Decimal256(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal64(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal64(Some(rounded), *precision, *scale);
ScalarValue::Decimal64(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal32(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal32(Some(rounded), *precision, *scale);
ScalarValue::Decimal32(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
_ => {
Expand All @@ -200,7 +265,12 @@ impl ScalarUDFImpl for RoundFunc {
}
}
} else {
round_columnar(&args.args[0], decimal_places, args.number_rows)
round_columnar(
&args.args[0],
decimal_places,
args.number_rows,
args.return_type(),
)
}
}

Expand Down Expand Up @@ -228,29 +298,31 @@ fn round_columnar(
value: &ColumnarValue,
decimal_places: &ColumnarValue,
number_rows: usize,
return_type: &DataType,
) -> Result<ColumnarValue> {
let value_array = value.to_array(number_rows)?;
let both_scalars = matches!(value, ColumnarValue::Scalar(_))
&& matches!(decimal_places, ColumnarValue::Scalar(_));

let arr: ArrayRef = match value_array.data_type() {
Float64 => {
let arr: ArrayRef = match (value_array.data_type(), return_type) {
(Float64, _) => {
let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
value_array.as_ref(),
decimal_places,
round_float::<f64>,
)?;
result as _
}
Float32 => {
(Float32, _) => {
let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
value_array.as_ref(),
decimal_places,
round_float::<f32>,
)?;
result as _
}
Decimal32(precision, scale) => {
(Decimal32(_, scale), Decimal32(precision, new_scale)) => {
// reduce scale to reclaim integer precision
let result = calculate_binary_decimal_math::<
Decimal32Type,
Int32Type,
Expand All @@ -259,13 +331,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
Decimal64(precision, scale) => {
(Decimal64(_, scale), Decimal64(precision, new_scale)) => {
let result = calculate_binary_decimal_math::<
Decimal64Type,
Int32Type,
Expand All @@ -274,13 +346,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
Decimal128(precision, scale) => {
(Decimal128(_, scale), Decimal128(precision, new_scale)) => {
let result = calculate_binary_decimal_math::<
Decimal128Type,
Int32Type,
Expand All @@ -289,13 +361,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
Decimal256(precision, scale) => {
(Decimal256(_, scale), Decimal256(precision, new_scale)) => {
let result = calculate_binary_decimal_math::<
Decimal256Type,
Int32Type,
Expand All @@ -304,13 +376,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
other => exec_err!("Unsupported data type {other:?} for function round")?,
(other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
};

if both_scalars {
Expand All @@ -334,10 +406,11 @@ where

fn round_decimal<V: ArrowNativeTypeOp>(
value: V,
scale: i8,
input_scale: i8,
output_scale: i8,
decimal_places: i32,
) -> Result<V, ArrowError> {
let diff = i64::from(scale) - i64::from(decimal_places);
let diff = i64::from(input_scale) - i64::from(decimal_places);
if diff <= 0 {
return Ok(value);
}
Expand All @@ -358,7 +431,7 @@ fn round_decimal<V: ArrowNativeTypeOp>(

let factor = ten.pow_checked(diff).map_err(|_| {
ArrowError::ComputeError(format!(
"Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}"
"Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
))
})?;

Expand All @@ -377,9 +450,40 @@ fn round_decimal<V: ArrowNativeTypeOp>(
})?;
}

quotient
.mul_checked(factor)
.map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
// Determine how to scale the result based on output_scale vs computed scale
// computed_scale = max(0, min(input_scale, decimal_places))
let computed_scale = if decimal_places >= 0 {
(input_scale as i32).min(decimal_places).max(0) as i8
} else {
0
};

if output_scale == computed_scale {
// scale reduction, return quotient directly (or shifted for negative dp)
if decimal_places >= 0 {
Ok(quotient)
} else {
// For negative decimal_places, multiply by 10^(-decimal_places) to shift left
let neg_dp: u32 = (-decimal_places).try_into().map_err(|_| {
ArrowError::ComputeError(format!(
"Invalid negative decimal places: {decimal_places}"
))
})?;
let shift_factor = ten.pow_checked(neg_dp).map_err(|_| {
ArrowError::ComputeError(format!(
"Overflow computing shift factor for decimal places {decimal_places}"
))
})?;
quotient.mul_checked(shift_factor).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding decimal".into())
})
}
} else {
// Keep original scale behavior: multiply back by factor
quotient.mul_checked(factor).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding decimal".into())
})
}
}

#[cfg(test)]
Expand All @@ -397,12 +501,14 @@ mod test {
decimal_places: Option<ArrayRef>,
) -> Result<ArrayRef, DataFusionError> {
let number_rows = value.len();
let return_type = value.data_type().clone();
Copy link
Member

@martin-g martin-g Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be wrong for decimals which scale is reduced.
No need to fix/improve it. Just a comment acknowledging that is enough.

let value = ColumnarValue::Array(value);
let decimal_places = decimal_places
.map(ColumnarValue::Array)
.unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0))));

let result = super::round_columnar(&value, &decimal_places, number_rows)?;
let result =
super::round_columnar(&value, &decimal_places, number_rows, &return_type)?;
match result {
ColumnarValue::Array(array) => Ok(array),
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ query TR
select arrow_typeof(round(173975140545.855, 2)),
round(173975140545.855, 2);
----
Decimal128(15, 3) 173975140545.86
Decimal128(15, 2) 173975140545.86

# smoke test for decimal parsing
query RT
Expand Down
Loading