Skip to content

Commit 9f27e93

Browse files
authored
Coerce expressions to udtf (#19915)
## Which issue does this PR close? Closes #19914 The changes are fairly simple. The bug only occurs to udtf, so I added a test case for this.
1 parent bfe7d18 commit 9f27e93

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

datafusion/core/src/execution/session_state.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1843,9 +1843,14 @@ impl ContextProvider for SessionContextProvider<'_> {
18431843
self.state.execution_props().query_execution_start_time,
18441844
);
18451845
let simplifier = ExprSimplifier::new(simplify_context);
1846+
let schema = DFSchema::empty();
18461847
let args = args
18471848
.into_iter()
1848-
.map(|arg| simplifier.simplify(arg))
1849+
.map(|arg| {
1850+
simplifier
1851+
.coerce(arg, &schema)
1852+
.and_then(|e| simplifier.simplify(e))
1853+
})
18491854
.collect::<datafusion_common::Result<Vec<_>>>()?;
18501855
let provider = tbl_func.create_table_provider(&args)?;
18511856

datafusion/core/tests/user_defined/user_defined_table_functions.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,31 @@ impl TableFunctionImpl for SimpleCsvTableFunc {
221221
}
222222
}
223223

224+
/// Test that expressions passed to UDTFs are properly type-coerced
225+
/// This is a regression test for https://github.com/apache/datafusion/issues/19914
226+
#[tokio::test]
227+
async fn test_udtf_type_coercion() -> Result<()> {
228+
use datafusion::datasource::MemTable;
229+
230+
#[derive(Debug)]
231+
struct NoOpTableFunc;
232+
233+
impl TableFunctionImpl for NoOpTableFunc {
234+
fn call(&self, _: &[Expr]) -> Result<Arc<dyn TableProvider>> {
235+
let schema = Arc::new(arrow::datatypes::Schema::empty());
236+
Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?))
237+
}
238+
}
239+
240+
let ctx = SessionContext::new();
241+
ctx.register_udtf("f", Arc::new(NoOpTableFunc));
242+
243+
// This should not panic - the array elements should be coerced to Float64
244+
let _ = ctx.sql("SELECT * FROM f(ARRAY[0.1, 1, 2])").await?;
245+
246+
Ok(())
247+
}
248+
224249
fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
225250
let mut file = File::open(csv_path)?;
226251
let (schema, _) = Format::default()

0 commit comments

Comments
 (0)