-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix struct casts to align fields by name (prevent positional mis-casts) #19674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 59 commits
42fe863
6b7ce25
edffe39
77da244
9dc6f77
0544307
4075920
9f04a4e
67d2659
129c9f7
e337ef7
b0ed1ab
cad6eac
9b61e2f
b22a742
8682073
cc926b3
0fe71b3
d0f1cc0
c39e9eb
5ef5a12
49ee3a6
de57ca9
32065fa
c76f1a6
f0d43c4
3bc5444
96b7f5f
e4ae1bd
7eb379a
2f15474
0897435
f801f19
aa04ded
6558d69
bdf5f20
682d28a
63e80f6
0d5126c
81e56fb
f304499
5f08346
c74bddd
2734da2
b9d79fc
77a7f94
914356f
af8e034
20d1c38
e1eff12
ec26c5e
f557de2
11e8745
da60ed9
7652dbb
ea748d3
f5642db
9f1d56e
5a4ecb1
2d5457b
9f496dd
352066c
e5d5d7a
8c0fe98
5e526b4
20d0248
fba0610
1d6a52c
d3af52c
8bdaa84
3d3c99f
68819ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,9 +19,9 @@ use crate::error::{_plan_err, Result}; | |
| use arrow::{ | ||
| array::{Array, ArrayRef, StructArray, new_null_array}, | ||
| compute::{CastOptions, cast_with_options}, | ||
| datatypes::{DataType::Struct, Field, FieldRef}, | ||
| datatypes::{DataType, DataType::Struct, Field, FieldRef}, | ||
| }; | ||
| use std::sync::Arc; | ||
| use std::{collections::HashSet, sync::Arc}; | ||
|
|
||
| /// Cast a struct column to match target struct fields, handling nested structs recursively. | ||
| /// | ||
|
|
@@ -31,6 +31,7 @@ use std::sync::Arc; | |
| /// | ||
| /// ## Field Matching Strategy | ||
| /// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) | ||
| /// - **By Position**: When there is no name overlap and the field counts match, fields are cast by index | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's worth clarifying we're going to remove this behavior if we indeed plan on doing so. Basically why not deprecate it as best we can in this PR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will remove positional casting after this PR. |
||
| /// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type | ||
| /// - **Missing Fields**: Target fields not present in the source are filled with null values | ||
| /// - **Extra Fields**: Source fields not present in the target are ignored | ||
|
|
@@ -54,16 +55,38 @@ fn cast_struct_column( | |
| target_fields: &[Arc<Field>], | ||
| cast_options: &CastOptions, | ||
| ) -> Result<ArrayRef> { | ||
| if source_col.data_type() == &DataType::Null | ||
| || (!source_col.is_empty() && source_col.null_count() == source_col.len()) | ||
| { | ||
| return Ok(new_null_array( | ||
| &Struct(target_fields.to_vec().into()), | ||
| source_col.len(), | ||
| )); | ||
| } | ||
|
Comment on lines
+58
to
+65
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems unnecessary if called from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. amended |
||
|
|
||
| if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() { | ||
| validate_struct_compatibility(source_struct.fields(), target_fields)?; | ||
| let source_fields = source_struct.fields(); | ||
| let has_overlap = fields_have_name_overlap(source_fields, target_fields); | ||
| validate_struct_compatibility(source_fields, target_fields)?; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this function need to be adjusted at all to take mapping of fields by name into account?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No change needed. I added more tests, though - 9f496dd |
||
|
|
||
| let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len()); | ||
| let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len()); | ||
| let num_rows = source_col.len(); | ||
|
|
||
| for target_child_field in target_fields { | ||
| // Iterate target fields and pick source child either by name (when fields overlap) | ||
| // or by position (when there is no name overlap). | ||
| for (index, target_child_field) in target_fields.iter().enumerate() { | ||
| fields.push(Arc::clone(target_child_field)); | ||
| match source_struct.column_by_name(target_child_field.name()) { | ||
|
|
||
| // Determine the source child column: by name when overlapping names exist, | ||
| // otherwise by position. | ||
| let source_child_opt: Option<&ArrayRef> = if has_overlap { | ||
| source_struct.column_by_name(target_child_field.name()) | ||
| } else { | ||
| Some(source_struct.column(index)) | ||
| }; | ||
|
|
||
| match source_child_opt { | ||
| Some(source_child_col) => { | ||
| let adapted_child = | ||
| cast_column(source_child_col, target_child_field, cast_options) | ||
|
|
@@ -204,51 +227,114 @@ pub fn validate_struct_compatibility( | |
| source_fields: &[FieldRef], | ||
| target_fields: &[FieldRef], | ||
| ) -> Result<()> { | ||
| let has_overlap = fields_have_name_overlap(source_fields, target_fields); | ||
| if !has_overlap { | ||
| if source_fields.len() != target_fields.len() { | ||
| return _plan_err!( | ||
| "Cannot cast struct with {} fields to {} fields without name overlap; positional mapping is ambiguous", | ||
| source_fields.len(), | ||
| target_fields.len() | ||
| ); | ||
| } | ||
|
|
||
| for (source_field, target_field) in source_fields.iter().zip(target_fields.iter()) | ||
| { | ||
| validate_field_compatibility(source_field, target_field)?; | ||
| } | ||
|
|
||
| return Ok(()); | ||
| } | ||
|
|
||
| // Check compatibility for each target field | ||
| for target_field in target_fields { | ||
| // Look for matching field in source by name | ||
| if let Some(source_field) = source_fields | ||
| .iter() | ||
| .find(|f| f.name() == target_field.name()) | ||
| { | ||
| // Ensure nullability is compatible. It is invalid to cast a nullable | ||
| // source field to a non-nullable target field as this may discard | ||
| // null values. | ||
| if source_field.is_nullable() && !target_field.is_nullable() { | ||
| validate_field_compatibility(source_field, target_field)?; | ||
| } else { | ||
| // Target field is missing from source | ||
| // If it's non-nullable, we cannot fill it with NULL | ||
| if !target_field.is_nullable() { | ||
| return _plan_err!( | ||
| "Cannot cast nullable struct field '{}' to non-nullable field", | ||
| "Cannot cast struct: target field '{}' is non-nullable but missing from source. \ | ||
| Cannot fill with NULL.", | ||
| target_field.name() | ||
| ); | ||
| } | ||
| // Check if the matching field types are compatible | ||
| match (source_field.data_type(), target_field.data_type()) { | ||
| // Recursively validate nested structs | ||
| (Struct(source_nested), Struct(target_nested)) => { | ||
| validate_struct_compatibility(source_nested, target_nested)?; | ||
| } | ||
| // For non-struct types, use the existing castability check | ||
| _ => { | ||
| if !arrow::compute::can_cast_types( | ||
| source_field.data_type(), | ||
| target_field.data_type(), | ||
| ) { | ||
| return _plan_err!( | ||
| "Cannot cast struct field '{}' from type {} to type {}", | ||
| target_field.name(), | ||
| source_field.data_type(), | ||
| target_field.data_type() | ||
| ); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| // Missing fields in source are OK - they'll be filled with nulls | ||
| } | ||
|
|
||
| // Extra fields in source are OK - they'll be ignored | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn validate_field_compatibility( | ||
| source_field: &Field, | ||
| target_field: &Field, | ||
| ) -> Result<()> { | ||
| if source_field.data_type() == &DataType::Null { | ||
| // Validate that target allows nulls before returning early. | ||
| // It is invalid to cast a NULL source field to a non-nullable target field. | ||
| if !target_field.is_nullable() { | ||
| return _plan_err!( | ||
| "Cannot cast NULL struct field '{}' to non-nullable field '{}'", | ||
| source_field.name(), | ||
| target_field.name() | ||
| ); | ||
| } | ||
| return Ok(()); | ||
| } | ||
|
|
||
| // Ensure nullability is compatible. It is invalid to cast a nullable | ||
| // source field to a non-nullable target field as this may discard | ||
| // null values. | ||
| if source_field.is_nullable() && !target_field.is_nullable() { | ||
| return _plan_err!( | ||
| "Cannot cast nullable struct field '{}' to non-nullable field", | ||
| target_field.name() | ||
| ); | ||
| } | ||
|
|
||
| // Check if the matching field types are compatible | ||
| match (source_field.data_type(), target_field.data_type()) { | ||
| // Recursively validate nested structs | ||
| (Struct(source_nested), Struct(target_nested)) => { | ||
| validate_struct_compatibility(source_nested, target_nested)?; | ||
| } | ||
| // For non-struct types, use the existing castability check | ||
| _ => { | ||
| if !arrow::compute::can_cast_types( | ||
| source_field.data_type(), | ||
| target_field.data_type(), | ||
| ) { | ||
| return _plan_err!( | ||
| "Cannot cast struct field '{}' from type {} to type {}", | ||
| target_field.name(), | ||
| source_field.data_type(), | ||
| target_field.data_type() | ||
| ); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn fields_have_name_overlap( | ||
|
||
| source_fields: &[FieldRef], | ||
| target_fields: &[FieldRef], | ||
| ) -> bool { | ||
| let source_names: HashSet<&str> = source_fields | ||
| .iter() | ||
| .map(|field| field.name().as_str()) | ||
| .collect(); | ||
| target_fields | ||
| .iter() | ||
| .any(|field| source_names.contains(field.name().as_str())) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
|
|
||
|
|
@@ -257,7 +343,7 @@ mod tests { | |
| use arrow::{ | ||
| array::{ | ||
| BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, | ||
| MapBuilder, StringArray, StringBuilder, | ||
| MapBuilder, NullArray, StringArray, StringBuilder, | ||
| }, | ||
| buffer::NullBuffer, | ||
| datatypes::{DataType, Field, FieldRef, Int32Type}, | ||
|
|
@@ -428,11 +514,14 @@ mod tests { | |
|
|
||
| #[test] | ||
| fn test_validate_struct_compatibility_missing_field_in_source() { | ||
| // Source struct: {field2: String} (missing field1) | ||
| let source_fields = vec![arc_field("field2", DataType::Utf8)]; | ||
| // Source struct: {field1: Int32} (missing field2) | ||
| let source_fields = vec![arc_field("field1", DataType::Int32)]; | ||
|
|
||
| // Target struct: {field1: Int32} | ||
| let target_fields = vec![arc_field("field1", DataType::Int32)]; | ||
| // Target struct: {field1: Int32, field2: Utf8} | ||
| let target_fields = vec![ | ||
| arc_field("field1", DataType::Int32), | ||
| arc_field("field2", DataType::Utf8), | ||
| ]; | ||
|
|
||
| // Should be OK - missing fields will be filled with nulls | ||
| let result = validate_struct_compatibility(&source_fields, &target_fields); | ||
|
|
@@ -455,6 +544,20 @@ mod tests { | |
| assert!(result.is_ok()); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_validate_struct_compatibility_positional_no_overlap_mismatch_len() { | ||
| let source_fields = vec![ | ||
| arc_field("left", DataType::Int32), | ||
| arc_field("right", DataType::Int32), | ||
| ]; | ||
| let target_fields = vec![arc_field("alpha", DataType::Int32)]; | ||
|
|
||
| let result = validate_struct_compatibility(&source_fields, &target_fields); | ||
| assert!(result.is_err()); | ||
| let error_msg = result.unwrap_err().to_string(); | ||
| assert!(error_msg.contains("positional mapping is ambiguous")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_struct_parent_nulls_retained() { | ||
| let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; | ||
|
|
@@ -585,6 +688,33 @@ mod tests { | |
| assert!(missing.is_null(1)); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_null_struct_field_to_nested_struct() { | ||
| let null_inner = Arc::new(NullArray::new(2)) as ArrayRef; | ||
| let source_struct = StructArray::from(vec![( | ||
| arc_field("inner", DataType::Null), | ||
| Arc::clone(&null_inner), | ||
| )]); | ||
| let source_col = Arc::new(source_struct) as ArrayRef; | ||
|
|
||
| let target_field = struct_field( | ||
| "outer", | ||
| vec![struct_field("inner", vec![field("a", DataType::Int32)])], | ||
| ); | ||
|
|
||
| let result = | ||
| cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); | ||
| let outer = result.as_any().downcast_ref::<StructArray>().unwrap(); | ||
| let inner = get_column_as!(&outer, "inner", StructArray); | ||
| assert_eq!(inner.len(), 2); | ||
| assert!(inner.is_null(0)); | ||
| assert!(inner.is_null(1)); | ||
|
|
||
| let inner_a = get_column_as!(inner, "a", Int32Array); | ||
| assert!(inner_a.is_null(0)); | ||
| assert!(inner_a.is_null(1)); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_struct_with_array_and_map_fields() { | ||
| // Array field with second row null | ||
|
|
@@ -704,4 +834,88 @@ mod tests { | |
| assert_eq!(a_col.value(0), 1); | ||
| assert_eq!(a_col.value(1), 2); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_struct_positional_when_no_overlap() { | ||
| let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef; | ||
| let second = | ||
| Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef; | ||
|
|
||
| let source_struct = StructArray::from(vec![ | ||
| (arc_field("left", DataType::Int32), first), | ||
| (arc_field("right", DataType::Utf8), second), | ||
| ]); | ||
| let source_col = Arc::new(source_struct) as ArrayRef; | ||
|
|
||
| let target_field = struct_field( | ||
| "s", | ||
| vec![field("a", DataType::Int64), field("b", DataType::Utf8)], | ||
| ); | ||
|
|
||
| let result = | ||
| cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); | ||
| let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap(); | ||
|
|
||
| let a_col = get_column_as!(&struct_array, "a", Int64Array); | ||
| assert_eq!(a_col.value(0), 10); | ||
| assert_eq!(a_col.value(1), 20); | ||
|
|
||
| let b_col = get_column_as!(&struct_array, "b", StringArray); | ||
| assert_eq!(b_col.value(0), "alpha"); | ||
| assert_eq!(b_col.value(1), "beta"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_struct_missing_non_nullable_field_fails() { | ||
| // Source has only field 'a' | ||
| let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; | ||
| let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); | ||
| let source_col = Arc::new(source_struct) as ArrayRef; | ||
|
|
||
| // Target has fields 'a' (nullable) and 'b' (non-nullable) | ||
| let target_field = struct_field( | ||
| "s", | ||
| vec![ | ||
| field("a", DataType::Int32), | ||
| non_null_field("b", DataType::Int32), | ||
| ], | ||
| ); | ||
|
|
||
| // Should fail because 'b' is non-nullable but missing from source | ||
| let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); | ||
| assert!(result.is_err()); | ||
| let err = result.unwrap_err(); | ||
| assert!( | ||
| err.to_string() | ||
| .contains("target field 'b' is non-nullable but missing from source"), | ||
| "Unexpected error: {err}" | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_struct_missing_nullable_field_succeeds() { | ||
| // Source has only field 'a' | ||
| let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; | ||
| let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); | ||
| let source_col = Arc::new(source_struct) as ArrayRef; | ||
|
|
||
| // Target has fields 'a' and 'b' (both nullable) | ||
| let target_field = struct_field( | ||
| "s", | ||
| vec![field("a", DataType::Int32), field("b", DataType::Int32)], | ||
| ); | ||
|
|
||
| // Should succeed - 'b' is nullable so can be filled with NULL | ||
| let result = | ||
| cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); | ||
| let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap(); | ||
|
|
||
| let a_col = get_column_as!(&struct_array, "a", Int32Array); | ||
| assert_eq!(a_col.value(0), 1); | ||
| assert_eq!(a_col.value(1), 2); | ||
|
|
||
| let b_col = get_column_as!(&struct_array, "b", Int32Array); | ||
| assert!(b_col.is_null(0)); | ||
| assert!(b_col.is_null(1)); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this PR.
Fix to enable CI to pass.