Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions datafusion/functions-nested/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,14 @@
//! [`ScalarUDFImpl`] definitions for array_sort function.

use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait, new_null_array,
};
use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_null_array};
use arrow::buffer::OffsetBuffer;
use arrow::compute::SortColumn;
use arrow::datatypes::{DataType, FieldRef};
use arrow::{compute, compute::SortOptions};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, plan_err};
use datafusion_common::{Result, exec_err};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarUDFImpl, Signature, TypeSignature, Volatility,
Expand Down Expand Up @@ -134,18 +132,7 @@ impl ScalarUDFImpl for ArraySort {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match &arg_types[0] {
DataType::Null => Ok(DataType::Null),
DataType::List(field) => {
Ok(DataType::new_list(field.data_type().clone(), true))
}
DataType::LargeList(field) => {
Ok(DataType::new_large_list(field.data_type().clone(), true))
}
arg_type => {
plan_err!("{} does not support type {arg_type}", self.name())
}
}
Ok(arg_types[0].clone())
}

fn invoke_with_args(
Expand Down Expand Up @@ -206,11 +193,11 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
}
DataType::List(field) => {
let array = as_list_array(&args[0])?;
array_sort_generic(array, field, sort_options)
array_sort_generic(array, Arc::clone(field), sort_options)
}
DataType::LargeList(field) => {
let array = as_large_list_array(&args[0])?;
array_sort_generic(array, field, sort_options)
array_sort_generic(array, Arc::clone(field), sort_options)
}
// Signature should prevent this arm ever occurring
_ => exec_err!("array_sort expects list for first argument"),
Expand All @@ -219,18 +206,16 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
field: &FieldRef,
field: FieldRef,
sort_options: Option<SortOptions>,
) -> Result<ArrayRef> {
let row_count = list_array.len();

let mut array_lengths = vec![];
let mut arrays = vec![];
let mut valid = NullBufferBuilder::new(row_count);
for i in 0..row_count {
if list_array.is_null(i) {
array_lengths.push(0);
valid.append_null();
} else {
let arr_ref = list_array.value(i);

Expand All @@ -253,25 +238,22 @@ fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
};
array_lengths.push(sorted_array.len());
arrays.push(sorted_array);
valid.append_non_null();
}
}

let buffer = valid.finish();

let elements = arrays
.iter()
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = if elements.is_empty() {
GenericListArray::<OffsetSize>::new_null(Arc::clone(field), row_count)
GenericListArray::<OffsetSize>::new_null(field, row_count)
} else {
GenericListArray::<OffsetSize>::new(
Arc::clone(field),
field,
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
buffer,
list_array.nulls().cloned(),
)
};
Ok(Arc::new(list_arr))
Expand Down
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2577,6 +2577,31 @@ NULL NULL
NULL NULL
NULL NULL

# maintains inner nullability
query ?T
select array_sort(column1), arrow_typeof(array_sort(column1))
from values
(arrow_cast([], 'List(non-null Int32)')),
(arrow_cast(NULL, 'List(non-null Int32)')),
(arrow_cast([1, 3, 5, -5], 'List(non-null Int32)'))
;
----
[] List(non-null Int32)
NULL List(non-null Int32)
[-5, 1, 3, 5] List(non-null Int32)

query ?T
select column1, arrow_typeof(column1)
from values (array_sort(arrow_cast([1, 3, 5, -5], 'LargeList(non-null Int32)')));
----
[-5, 1, 3, 5] LargeList(non-null Int32)

query ?T
select column1, arrow_typeof(column1)
from values (array_sort(arrow_cast([1, 3, 5, -5], 'FixedSizeList(4 x non-null Int32)')));
----
[-5, 1, 3, 5] List(non-null Int32)

query ?
select array_sort([struct('foo', 3), struct('foo', 1), struct('bar', 1)])
----
Expand Down