Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 43 additions & 51 deletions datafusion/functions/src/math/factorial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use std::sync::Arc;
use arrow::datatypes::DataType::Int64;
use arrow::datatypes::{DataType, Int64Type};

use crate::utils::make_scalar_function;
use datafusion_common::{Result, exec_err};
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
Expand Down Expand Up @@ -81,7 +82,39 @@ impl ScalarUDFImpl for FactorialFunc {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(factorial, vec![])(&args.args)
let [arg] = take_function_args(self.name(), args.args)?;

match arg {
ColumnarValue::Scalar(scalar) => {
if scalar.is_null() {
return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None)));
}

match scalar {
ScalarValue::Int64(Some(v)) => {
let result = compute_factorial(v)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
}
_ => {
internal_err!(
"Unexpected data type {:?} for function factorial",
scalar.data_type()
)
}
}
}
ColumnarValue::Array(array) => match array.data_type() {
Int64 => {
let result: Int64Array = array
.as_primitive::<Int64Type>()
.try_unary(compute_factorial)?;
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
}
other => {
internal_err!("Unexpected data type {other:?} for function factorial")
}
},
}
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down Expand Up @@ -113,53 +146,12 @@ const FACTORIALS: [i64; 21] = [
2432902008176640000,
]; // if return type changes, this constant needs to be updated accordingly
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to keep this comment


/// Factorial SQL function
fn factorial(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
Int64 => {
let result: Int64Array =
args[0].as_primitive::<Int64Type>().try_unary(|a| {
if a < 0 {
Ok(1)
} else if a < FACTORIALS.len() as i64 {
Ok(FACTORIALS[a as usize])
} else {
exec_err!("Overflow happened on FACTORIAL({a})")
}
})?;
Ok(Arc::new(result) as ArrayRef)
}
other => exec_err!("Unsupported data type {other:?} for function factorial."),
}
}

#[cfg(test)]
mod test {
use super::*;
use datafusion_common::cast::as_int64_array;

#[test]
fn test_factorial_i64() {
let args: Vec<ArrayRef> = vec![
Arc::new(Int64Array::from(vec![0, 1, 2, 4, 20, -1])), // input
];

let result = factorial(&args).expect("failed to initialize function factorial");
let ints =
as_int64_array(&result).expect("failed to initialize function factorial");

let expected = Int64Array::from(vec![1, 1, 2, 24, 2432902008176640000, 1]);

assert_eq!(ints, &expected);
}

#[test]
fn test_overflow() {
let args: Vec<ArrayRef> = vec![
Arc::new(Int64Array::from(vec![21])), // input
];

let result = factorial(&args);
assert!(result.is_err());
fn compute_factorial(n: i64) -> Result<i64> {
if n < 0 {
Ok(1)
} else if n < FACTORIALS.len() as i64 {
Ok(FACTORIALS[n as usize])
} else {
exec_err!("Overflow happened on FACTORIAL({n})")
}
}