diff --git a/crates/iceberg/src/arrow/mod.rs b/crates/iceberg/src/arrow/mod.rs
index 28116a4b5..c091c4517 100644
--- a/crates/iceberg/src/arrow/mod.rs
+++ b/crates/iceberg/src/arrow/mod.rs
@@ -35,4 +35,9 @@ mod value;
 pub use reader::*;
 pub use value::*;
 
-pub(crate) mod record_batch_partition_splitter;
+/// Partition value calculator for computing partition values
+pub mod partition_value_calculator;
+pub use partition_value_calculator::*;
+/// Record batch partition splitter for partitioned tables
+pub mod record_batch_partition_splitter;
+pub use record_batch_partition_splitter::*;
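
Reviewer note: with both modules now public and glob re-exported, downstream code can pull the new types straight from `iceberg::arrow`. A minimal sketch of the imports this enables (downstream usage is an assumption, not part of the patch):

```rust
// Hypothetical downstream import; relies only on the re-exports added above.
use iceberg::arrow::{
    PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator, RecordBatchPartitionSplitter,
};
```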
diff --git a/crates/iceberg/src/arrow/partition_value_calculator.rs b/crates/iceberg/src/arrow/partition_value_calculator.rs
new file mode 100644
index 000000000..140950345
--- /dev/null
+++ b/crates/iceberg/src/arrow/partition_value_calculator.rs
@@ -0,0 +1,254 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Partition value calculation for Iceberg tables.
+//!
+//! This module provides utilities for calculating partition values from record batches
+//! based on a partition specification.
+
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, RecordBatch, StructArray};
+use arrow_schema::DataType;
+
+use super::record_batch_projector::RecordBatchProjector;
+use super::type_to_arrow_type;
+use crate::spec::{PartitionSpec, Schema, StructType, Type};
+use crate::transform::{BoxedTransformFunction, create_transform_function};
+use crate::{Error, ErrorKind, Result};
+
+/// Calculator for partition values in Iceberg tables.
+///
+/// This struct handles the projection of source columns and application of
+/// partition transforms to compute partition values for a given record batch.
+#[derive(Debug)]
+pub struct PartitionValueCalculator {
+    projector: RecordBatchProjector,
+    transform_functions: Vec<BoxedTransformFunction>,
+    partition_type: StructType,
+    partition_arrow_type: DataType,
+}
+
+impl PartitionValueCalculator {
+    /// Create a new PartitionValueCalculator.
+    ///
+    /// # Arguments
+    ///
+    /// * `partition_spec` - The partition specification
+    /// * `table_schema` - The Iceberg table schema
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `PartitionValueCalculator` instance or an error if initialization fails.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - The partition spec is unpartitioned
+    /// - Transform function creation fails
+    /// - Projector initialization fails
+    pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) -> Result<Self> {
+        if partition_spec.is_unpartitioned() {
+            return Err(Error::new(
+                ErrorKind::DataInvalid,
+                "Cannot create partition calculator for unpartitioned table",
+            ));
+        }
+
+        // Create transform functions for each partition field
+        let transform_functions: Vec<BoxedTransformFunction> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| create_transform_function(&pf.transform))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Extract source field IDs for projection
+        let source_field_ids: Vec<i32> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| pf.source_id)
+            .collect();
+
+        // Create projector for extracting source columns
+        let projector = RecordBatchProjector::from_iceberg_schema(
+            Arc::new(table_schema.clone()),
+            &source_field_ids,
+        )?;
+
+        // Get partition type information
+        let partition_type = partition_spec.partition_type(table_schema)?;
+        let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
+
+        Ok(Self {
+            projector,
+            transform_functions,
+            partition_type,
+            partition_arrow_type,
+        })
+    }
+
+    /// Get the partition type as an Iceberg StructType.
+    pub fn partition_type(&self) -> &StructType {
+        &self.partition_type
+    }
+
+    /// Get the partition type as an Arrow DataType.
+    pub fn partition_arrow_type(&self) -> &DataType {
+        &self.partition_arrow_type
+    }
+
+    /// Calculate partition values for a record batch.
+    ///
+    /// This method:
+    /// 1. Projects the source columns from the batch
+    /// 2. Applies partition transforms to each source column
+    /// 3. Constructs a StructArray containing the partition values
+    ///
+    /// # Arguments
+    ///
+    /// * `batch` - The record batch to calculate partition values for
+    ///
+    /// # Returns
+    ///
+    /// Returns an ArrayRef containing a StructArray of partition values, or an error if calculation fails.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - Column projection fails
+    /// - Transform application fails
+    /// - StructArray construction fails
+    pub fn calculate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+        // Project source columns from the batch
+        let source_columns = self.projector.project_column(batch.columns())?;
+
+        // Get expected struct fields for the result
+        let expected_struct_fields = match &self.partition_arrow_type {
+            DataType::Struct(fields) => fields.clone(),
+            _ => {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    "Expected partition type must be a struct",
+                ));
+            }
+        };
+
+        // Apply transforms to each source column
+        let mut partition_values = Vec::with_capacity(self.transform_functions.len());
+        for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) {
+            let partition_value = transform_fn.transform(source_column.clone())?;
+            partition_values.push(partition_value);
+        }
+
+        // Construct the StructArray
+        let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None)
+            .map_err(|e| {
+                Error::new(
+                    ErrorKind::DataInvalid,
+                    format!("Failed to create partition struct array: {}", e),
+                )
+            })?;
+
+        Ok(Arc::new(struct_array))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow_array::{Int32Array, RecordBatch, StringArray};
+    use arrow_schema::{Field, Schema as ArrowSchema};
+
+    use super::*;
+    use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, Transform};
+
+    #[test]
+    fn test_partition_calculator_identity_transform() {
+        let table_schema = Schema::builder()
+            .with_schema_id(0)
+            .with_fields(vec![
+                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
+            ])
+            .build()
+            .unwrap();
+
+        let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
+            .add_partition_field("id", "id_partition", Transform::Identity)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
+
+        // Verify partition type
+        assert_eq!(calculator.partition_type().fields().len(), 1);
+        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
+
+        // Create test batch
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let batch = RecordBatch::try_new(arrow_schema, vec![
+            Arc::new(Int32Array::from(vec![10, 20, 30])),
+            Arc::new(StringArray::from(vec!["a", "b", "c"])),
+        ])
+        .unwrap();
+
+        // Calculate partition values
+        let result = calculator.calculate(&batch).unwrap();
+        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
+
+        let id_partition = struct_array
+            .column_by_name("id_partition")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        assert_eq!(id_partition.value(0), 10);
+        assert_eq!(id_partition.value(1), 20);
+        assert_eq!(id_partition.value(2), 30);
+    }
+
+    #[test]
+    fn test_partition_calculator_unpartitioned_error() {
+        let table_schema = Schema::builder()
+            .with_schema_id(0)
+            .with_fields(vec![
+                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+            ])
+            .build()
+            .unwrap();
+
+        let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
+            .build()
+            .unwrap();
+
+        let result = PartitionValueCalculator::try_new(&partition_spec, &table_schema);
+        assert!(result.is_err());
+        assert!(
+            result
+                .unwrap_err()
+                .to_string()
+                .contains("unpartitioned table")
+        );
+    }
+}
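
Reviewer note: the calculator's happy path mirrors the identity-transform test above. A minimal sketch of the intended call sequence, built only from the API introduced in this file (`partition_values_for` and `id_part` are illustrative names; error handling via the crate's `Result` alias):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, RecordBatch};
use iceberg::Result;
use iceberg::arrow::PartitionValueCalculator;
use iceberg::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, Schema, Transform, Type};

/// Compute the partition-values struct for a batch of an `id: int` table
/// partitioned by identity(id).
fn partition_values_for(batch: &RecordBatch) -> Result<ArrayRef> {
    let schema = Schema::builder()
        .with_fields(vec![
            NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
        ])
        .build()?;
    let spec = PartitionSpecBuilder::new(Arc::new(schema.clone()))
        .add_partition_field("id", "id_part", Transform::Identity)?
        .build()?;

    // try_new rejects unpartitioned specs; calculate() projects the source
    // columns, applies the transforms, and returns a StructArray with one
    // child per partition field.
    let calculator = PartitionValueCalculator::try_new(&spec, &schema)?;
    calculator.calculate(batch)
}
```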
diff --git a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
index 704a4e9c1..66371fac1 100644
--- a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
+++ b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
@@ -19,137 +19,169 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
 use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StructArray};
-use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
 use arrow_select::filter::filter_record_batch;
-use itertools::Itertools;
-use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 
 use super::arrow_struct_to_literal;
-use super::record_batch_projector::RecordBatchProjector;
-use crate::arrow::type_to_arrow_type;
-use crate::spec::{Literal, PartitionSpecRef, SchemaRef, Struct, StructType, Type};
-use crate::transform::{BoxedTransformFunction, create_transform_function};
+use super::partition_value_calculator::PartitionValueCalculator;
+use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, StructType};
 use crate::{Error, ErrorKind, Result};
 
+/// Column name for the projected partition values struct
+pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition";
+
 /// The splitter used to split the record batch into multiple record batches by the partition spec.
 /// 1. It will project and transform the input record batch based on the partition spec, get the partitioned record batch.
 /// 2. Split the input record batch into multiple record batches based on the partitioned record batch.
+///
+/// # Partition Value Modes
+///
+/// The splitter supports two modes for obtaining partition values:
+/// - **Computed mode** (`calculator` is `Some`): Computes partition values from source columns using transforms
+/// - **Pre-computed mode** (`calculator` is `None`): Expects a `_partition` column in the input batch
 // # TODO
 // Remove this after partition writer supported.
 #[allow(dead_code)]
 pub struct RecordBatchPartitionSplitter {
     schema: SchemaRef,
     partition_spec: PartitionSpecRef,
-    projector: RecordBatchProjector,
-    transform_functions: Vec<BoxedTransformFunction>,
-
+    calculator: Option<PartitionValueCalculator>,
     partition_type: StructType,
-    partition_arrow_type: DataType,
 }
 
 // # TODO
 // Remove this after partition writer supported.
 #[allow(dead_code)]
 impl RecordBatchPartitionSplitter {
+    /// Create a new RecordBatchPartitionSplitter.
+    ///
+    /// # Arguments
+    ///
+    /// * `iceberg_schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    /// * `calculator` - Optional calculator for computing partition values from source columns.
+    ///   - `Some(calculator)`: Compute partition values from source columns using transforms
+    ///   - `None`: Expect a pre-computed `_partition` column in the input batch
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails.
     pub fn new(
-        input_schema: ArrowSchemaRef,
         iceberg_schema: SchemaRef,
         partition_spec: PartitionSpecRef,
+        calculator: Option<PartitionValueCalculator>,
     ) -> Result<Self> {
-        let projector = RecordBatchProjector::new(
-            input_schema,
-            &partition_spec
-                .fields()
-                .iter()
-                .map(|field| field.source_id)
-                .collect::<Vec<_>>(),
-            // The source columns, selected by ids, must be a primitive type and cannot be contained in a map or list, but may be nested in a struct.
-            // ref: https://iceberg.apache.org/spec/#partitioning
-            |field| {
-                if !field.data_type().is_primitive() {
-                    return Ok(None);
-                }
-                field
-                    .metadata()
-                    .get(PARQUET_FIELD_ID_META_KEY)
-                    .map(|s| {
-                        s.parse::<i64>()
-                            .map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string()))
-                    })
-                    .transpose()
-            },
-            |_| true,
-        )?;
-        let transform_functions = partition_spec
-            .fields()
-            .iter()
-            .map(|field| create_transform_function(&field.transform))
-            .collect::<Result<Vec<_>>>()?;
-
         let partition_type = partition_spec.partition_type(&iceberg_schema)?;
-        let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
 
         Ok(Self {
             schema: iceberg_schema,
             partition_spec,
-            projector,
-            transform_functions,
+            calculator,
             partition_type,
-            partition_arrow_type,
         })
     }
 
-    fn partition_columns_to_struct(&self, partition_columns: Vec<ArrayRef>) -> Result<Vec<Struct>> {
-        let arrow_struct_array = {
-            let partition_arrow_fields = {
-                let DataType::Struct(fields) = &self.partition_arrow_type else {
-                    return Err(Error::new(
-                        ErrorKind::DataInvalid,
-                        "The partition arrow type is not a struct type",
-                    ));
-                };
-                fields.clone()
-            };
-            Arc::new(StructArray::try_new(
-                partition_arrow_fields,
-                partition_columns,
-                None,
-            )?) as ArrayRef
-        };
-        let struct_array = {
-            let struct_array = arrow_struct_to_literal(&arrow_struct_array, &self.partition_type)?;
+    /// Create a new RecordBatchPartitionSplitter with computed partition values.
+    ///
+    /// This is a convenience method that creates a calculator and initializes the splitter
+    /// to compute partition values from source columns.
+    ///
+    /// # Arguments
+    ///
+    /// * `iceberg_schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails.
+    pub fn new_with_computed_values(
+        iceberg_schema: SchemaRef,
+        partition_spec: PartitionSpecRef,
+    ) -> Result<Self> {
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &iceberg_schema)?;
+        Self::new(iceberg_schema, partition_spec, Some(calculator))
+    }
+
+    /// Create a new RecordBatchPartitionSplitter expecting pre-computed partition values.
+    ///
+    /// This is a convenience method that initializes the splitter to expect a `_partition`
+    /// column in the input batches.
+    ///
+    /// # Arguments
+    ///
+    /// * `iceberg_schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails.
+    pub fn new_with_precomputed_values(
+        iceberg_schema: SchemaRef,
+        partition_spec: PartitionSpecRef,
+    ) -> Result<Self> {
+        Self::new(iceberg_schema, partition_spec, None)
+    }
+
+    /// Split the record batch into multiple record batches based on the partition spec.
+    pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(PartitionKey, RecordBatch)>> {
+        let partition_structs = if let Some(calculator) = &self.calculator {
+            // Compute partition values from source columns using calculator
+            let partition_array = calculator.calculate(batch)?;
+            let struct_array = arrow_struct_to_literal(&partition_array, &self.partition_type)?;
+
+            struct_array
                 .into_iter()
                 .map(|s| {
-                    if let Some(s) = s {
-                        if let Literal::Struct(s) = s {
-                            Ok(s)
-                        } else {
-                            Err(Error::new(
-                                ErrorKind::DataInvalid,
-                                "The struct is not a struct literal",
-                            ))
-                        }
+                    if let Some(Literal::Struct(s)) = s {
+                        Ok(s)
                     } else {
-                        Err(Error::new(ErrorKind::DataInvalid, "The struct is null"))
+                        Err(Error::new(
+                            ErrorKind::DataInvalid,
+                            "Partition value is not a struct literal or is null",
+                        ))
                     }
                 })
                 .collect::<Result<Vec<_>>>()?
-        };
+        } else {
+            // Extract partition values from pre-computed partition column
+            let partition_column = batch
+                .column_by_name(PROJECTED_PARTITION_VALUE_COLUMN)
+                .ok_or_else(|| {
+                    Error::new(
+                        ErrorKind::DataInvalid,
+                        format!(
+                            "Partition column '{}' not found in batch",
+                            PROJECTED_PARTITION_VALUE_COLUMN
+                        ),
+                    )
+                })?;
 
-        Ok(struct_array)
-    }
+            let partition_struct_array = partition_column
+                .as_any()
+                .downcast_ref::<StructArray>()
+                .ok_or_else(|| {
+                    Error::new(
+                        ErrorKind::DataInvalid,
+                        "Partition column is not a StructArray",
+                    )
+                })?;
 
-    /// Split the record batch into multiple record batches based on the partition spec.
-    pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(Struct, RecordBatch)>> {
-        let source_columns = self.projector.project_column(batch.columns())?;
-        let partition_columns = source_columns
-            .into_iter()
-            .zip_eq(self.transform_functions.iter())
-            .map(|(source_column, transform_function)| transform_function.transform(source_column))
-            .collect::<Result<Vec<_>>>()?;
+            let arrow_struct_array = Arc::new(partition_struct_array.clone()) as ArrayRef;
+            let struct_array = arrow_struct_to_literal(&arrow_struct_array, &self.partition_type)?;
 
-        let partition_structs = self.partition_columns_to_struct(partition_columns)?;
+            struct_array
+                .into_iter()
+                .map(|s| {
+                    if let Some(Literal::Struct(s)) = s {
+                        Ok(s)
+                    } else {
+                        Err(Error::new(
+                            ErrorKind::DataInvalid,
+                            "Partition value is not a struct literal or is null",
+                        ))
+                    }
+                })
+                .collect::<Result<Vec<_>>>()?
+        };
 
         // Group the batch by row value.
         let mut group_ids = HashMap::new();
@@ -172,8 +204,15 @@ impl RecordBatchPartitionSplitter {
                 filter.into()
             };
 
+            // Create PartitionKey from the partition struct
+            let partition_key = PartitionKey::new(
+                self.partition_spec.as_ref().clone(),
+                self.schema.clone(),
+                row,
+            );
+
             // filter the RecordBatch
-            partition_batches.push((row, filter_record_batch(batch, &filter_array)?));
+            partition_batches.push((partition_key, filter_record_batch(batch, &filter_array)?));
         }
 
         Ok(partition_batches)
@@ -185,11 +224,13 @@ mod tests {
     use std::sync::Arc;
 
     use arrow_array::{Int32Array, RecordBatch, StringArray};
+    use arrow_schema::DataType;
+    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 
     use super::*;
     use crate::arrow::schema_to_arrow_schema;
    use crate::spec::{
-        NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Transform,
+        NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Struct, Transform, Type,
        UnboundPartitionField,
    };
@@ -227,14 +268,14 @@ mod tests {
             .build()
             .unwrap(),
         );
-        let input_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
 
        let partition_splitter =
-            RecordBatchPartitionSplitter::new(input_schema.clone(), schema.clone(), partition_spec)
+            RecordBatchPartitionSplitter::new_with_computed_values(schema.clone(), partition_spec)
                .expect("Failed to create splitter");
 
+        let arrow_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
         let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
         let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]);
-        let batch = RecordBatch::try_new(input_schema.clone(), vec![
+        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
             Arc::new(id_array),
             Arc::new(data_array),
         ])
@@ -243,8 +284,8 @@ mod tests {
         let mut partitioned_batches = partition_splitter
             .split(&batch)
             .expect("Failed to split RecordBatch");
-        partitioned_batches.sort_by_key(|(row, _)| {
-            if let PrimitiveLiteral::Int(i) = row.fields()[0]
+        partitioned_batches.sort_by_key(|(partition_key, _)| {
+            if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
                 .as_ref()
                 .unwrap()
                 .as_primitive_literal()
@@ -260,7 +301,7 @@ mod tests {
         // check the first partition
         let expected_id_array = Int32Array::from(vec![1, 1, 1]);
         let expected_data_array = StringArray::from(vec!["a", "c", "g"]);
-        let expected_batch = RecordBatch::try_new(input_schema.clone(), vec![
+        let expected_batch = RecordBatch::try_new(arrow_schema.clone(), vec![
             Arc::new(expected_id_array),
             Arc::new(expected_data_array),
         ])
@@ -271,7 +312,7 @@ mod tests {
         // check the second partition
         let expected_id_array = Int32Array::from(vec![2, 2]);
         let expected_data_array = StringArray::from(vec!["b", "e"]);
-        let expected_batch = RecordBatch::try_new(input_schema.clone(), vec![
+        let expected_batch = RecordBatch::try_new(arrow_schema.clone(), vec![
             Arc::new(expected_id_array),
             Arc::new(expected_data_array),
         ])
@@ -282,7 +323,7 @@ mod tests {
         // check the third partition
         let expected_id_array = Int32Array::from(vec![3, 3]);
         let expected_data_array = StringArray::from(vec!["d", "f"]);
-        let expected_batch = RecordBatch::try_new(input_schema.clone(), vec![
+        let expected_batch = RecordBatch::try_new(arrow_schema.clone(), vec![
             Arc::new(expected_id_array),
             Arc::new(expected_data_array),
         ])
@@ -292,7 +333,7 @@ mod tests {
 
         let partition_values = partitioned_batches
             .iter()
-            .map(|(row, _)| row.clone())
+            .map(|(partition_key, _)| partition_key.data().clone())
            .collect::<Vec<_>>();
         // check partition value is struct(1), struct(2), struct(3)
         assert_eq!(partition_values, vec![
@@ -301,4 +342,144 @@ mod tests {
             Struct::from_iter(vec![Some(Literal::int(3))]),
         ]);
     }
+
+    #[test]
+    fn test_record_batch_partition_split_with_partition_column() {
+        use arrow_array::StructArray;
+        use arrow_schema::{Field, Schema as ArrowSchema};
+
+        let schema = Arc::new(
+            Schema::builder()
+                .with_fields(vec![
+                    NestedField::required(
+                        1,
+                        "id",
+                        Type::Primitive(crate::spec::PrimitiveType::Int),
+                    )
+                    .into(),
+                    NestedField::required(
+                        2,
+                        "name",
+                        Type::Primitive(crate::spec::PrimitiveType::String),
+                    )
+                    .into(),
+                ])
+                .build()
+                .unwrap(),
+        );
+        let partition_spec = Arc::new(
+            PartitionSpecBuilder::new(schema.clone())
+                .with_spec_id(1)
+                .add_unbound_field(UnboundPartitionField {
+                    source_id: 1,
+                    field_id: None,
+                    name: "id_bucket".to_string(),
+                    transform: Transform::Identity,
+                })
+                .unwrap()
+                .build()
+                .unwrap(),
+        );
+
+        // Create input schema with _partition column
+        // Note: partition field IDs start from 1000 by default
+        let partition_field = Field::new("id_bucket", DataType::Int32, false).with_metadata(
+            HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1000".to_string())]),
+        );
+        let partition_struct_field = Field::new(
+            PROJECTED_PARTITION_VALUE_COLUMN,
+            DataType::Struct(vec![partition_field.clone()].into()),
+            false,
+        );
+
+        let input_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+            partition_struct_field,
+        ]));
+
+        // Create splitter expecting pre-computed partition column
+        let partition_splitter = RecordBatchPartitionSplitter::new_with_precomputed_values(
+            schema.clone(),
+            partition_spec,
+        )
+        .expect("Failed to create splitter");
+
+        // Create test data with pre-computed partition column
+        let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
+        let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]);
+
+        // Create partition column (same values as id for Identity transform)
+        let partition_values = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
+        let partition_struct = StructArray::from(vec![(
+            Arc::new(partition_field),
+            Arc::new(partition_values) as ArrayRef,
+        )]);
+
+        let batch = RecordBatch::try_new(input_schema.clone(), vec![
+            Arc::new(id_array),
+            Arc::new(data_array),
+            Arc::new(partition_struct),
+        ])
+        .expect("Failed to create RecordBatch");
+
+        // Split using the pre-computed partition column
+        let mut partitioned_batches = partition_splitter
+            .split(&batch)
+            .expect("Failed to split RecordBatch");
+
+        partitioned_batches.sort_by_key(|(partition_key, _)| {
+            if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
+                .as_ref()
+                .unwrap()
+                .as_primitive_literal()
+                .unwrap()
+            {
+                i
+            } else {
+                panic!("The partition value is not an int");
+            }
+        });
+
+        assert_eq!(partitioned_batches.len(), 3);
+
+        // Helper to extract id and name values from a batch
+        let extract_values = |batch: &RecordBatch| -> (Vec<i32>, Vec<String>) {
+            let id_col = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap();
+            let name_col = batch
+                .column(1)
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap();
+            (
+                id_col.values().to_vec(),
+                name_col.iter().map(|s| s.unwrap().to_string()).collect(),
+            )
+        };
+
+        // Verify partition 1: id=1, names=["a", "c", "g"]
+        let (key, batch) = &partitioned_batches[0];
+        assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(1))]));
+        let (ids, names) = extract_values(batch);
+        assert_eq!(ids, vec![1, 1, 1]);
+        assert_eq!(names, vec!["a", "c", "g"]);
+
+        // Verify partition 2: id=2, names=["b", "e"]
+        let (key, batch) = &partitioned_batches[1];
+        assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(2))]));
+        let (ids, names) = extract_values(batch);
+        assert_eq!(ids, vec![2, 2]);
+        assert_eq!(names, vec!["b", "e"]);
+
+        // Verify partition 3: id=3, names=["d", "f"]
+        let (key, batch) = &partitioned_batches[2];
+        assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(3))]));
+        let (ids, names) = extract_values(batch);
+        assert_eq!(ids, vec![3, 3]);
+        assert_eq!(names, vec!["d", "f"]);
+    }
 }
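
Reviewer note: the two convenience constructors correspond to the two partition-value modes documented on the struct. A minimal sketch of the split flow under the new API (the schema and spec are assumed to be built as in the tests above; `split_computed` is an illustrative name):

```rust
use arrow_array::RecordBatch;
use iceberg::Result;
use iceberg::arrow::RecordBatchPartitionSplitter;
use iceberg::spec::{PartitionKey, PartitionSpecRef, SchemaRef};

/// Split a batch into one (PartitionKey, RecordBatch) pair per distinct
/// partition value, deriving the values from the source columns.
fn split_computed(
    iceberg_schema: SchemaRef,
    partition_spec: PartitionSpecRef,
    batch: &RecordBatch,
) -> Result<Vec<(PartitionKey, RecordBatch)>> {
    let splitter =
        RecordBatchPartitionSplitter::new_with_computed_values(iceberg_schema, partition_spec)?;
    splitter.split(batch)
    // The pre-computed variant has the same call shape:
    //   RecordBatchPartitionSplitter::new_with_precomputed_values(schema, spec)?
    // but requires the batch to carry a `_partition` struct column
    // (PROJECTED_PARTITION_VALUE_COLUMN), as exercised by the second test above.
}
```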
diff --git a/crates/integrations/datafusion/src/physical_plan/project.rs b/crates/integrations/datafusion/src/physical_plan/project.rs
index 4bfe8192b..17492176a 100644
--- a/crates/integrations/datafusion/src/physical_plan/project.rs
+++ b/crates/integrations/datafusion/src/physical_plan/project.rs
@@ -19,24 +19,19 @@
 
 use std::sync::Arc;
 
-use datafusion::arrow::array::{ArrayRef, RecordBatch, StructArray};
+use datafusion::arrow::array::RecordBatch;
 use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
 use datafusion::common::Result as DFResult;
-use datafusion::error::DataFusionError;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_expr::expressions::Column;
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
-use iceberg::arrow::record_batch_projector::RecordBatchProjector;
-use iceberg::spec::{PartitionSpec, Schema};
+use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator};
+use iceberg::spec::PartitionSpec;
 use iceberg::table::Table;
-use iceberg::transform::BoxedTransformFunction;
 
 use crate::to_datafusion_error;
 
-/// Column name for the combined partition values struct
-const PARTITION_VALUES_COLUMN: &str = "_partition";
-
 /// Extends an ExecutionPlan with partition value calculations for Iceberg tables.
 ///
 /// This function takes an input ExecutionPlan and extends it with an additional column
@@ -65,12 +60,9 @@ pub fn project_with_partition(
     let input_schema = input.schema();
     // TODO: Validate that input_schema matches the Iceberg table schema.
     // See: https://github.com/apache/iceberg-rust/issues/1752
 
-    let partition_type = build_partition_type(partition_spec, table_schema.as_ref())?;
-    let calculator = PartitionValueCalculator::new(
-        partition_spec.as_ref().clone(),
-        table_schema.as_ref().clone(),
-        partition_type,
-    )?;
+    let calculator =
+        PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
+            .map_err(to_datafusion_error)?;
 
     let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
         Vec::with_capacity(input_schema.fields().len() + 1);
@@ -80,8 +72,8 @@ pub fn project_with_partition(
         projection_exprs.push((column_expr, field.name().clone()));
     }
 
-    let partition_expr = Arc::new(PartitionExpr::new(calculator));
-    projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string()));
+    let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
+    projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
 
     let projection = ProjectionExec::try_new(projection_exprs, input)?;
     Ok(Arc::new(projection))
@@ -91,21 +83,24 @@ pub fn project_with_partition(
 #[derive(Debug, Clone)]
 struct PartitionExpr {
     calculator: Arc<PartitionValueCalculator>,
+    partition_spec: Arc<PartitionSpec>,
 }
 
 impl PartitionExpr {
-    fn new(calculator: PartitionValueCalculator) -> Self {
+    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
         Self {
             calculator: Arc::new(calculator),
+            partition_spec,
         }
     }
 }
 
 // Manual PartialEq/Eq implementations for pointer-based equality
-// (two PartitionExpr are equal if they share the same calculator instance)
+// (two PartitionExpr are equal if they share the same calculator and partition_spec instances)
 impl PartialEq for PartitionExpr {
     fn eq(&self, other: &Self) -> bool {
         Arc::ptr_eq(&self.calculator, &other.calculator)
+            && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
     }
 }
 
@@ -117,7 +112,7 @@ impl PhysicalExpr for PartitionExpr {
     }
 
     fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
-        Ok(self.calculator.partition_type.clone())
+        Ok(self.calculator.partition_arrow_type().clone())
     }
 
     fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
@@ -125,7 +120,10 @@ impl PhysicalExpr for PartitionExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
-        let array = self.calculator.calculate(batch)?;
+        let array = self
+            .calculator
+            .calculate(batch)
+            .map_err(to_datafusion_error)?;
         Ok(ColumnarValue::Array(array))
     }
 
@@ -142,7 +140,6 @@ impl PhysicalExpr for PartitionExpr {
 
     fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let field_names: Vec<String> = self
-            .calculator
             .partition_spec
             .fields()
             .iter()
@@ -155,7 +152,6 @@ impl PhysicalExpr for PartitionExpr {
 impl std::fmt::Display for PartitionExpr {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let field_names: Vec<&str> = self
-            .calculator
             .partition_spec
             .fields()
             .iter()
@@ -167,110 +163,18 @@ impl std::fmt::Display for PartitionExpr {
 
 impl std::hash::Hash for PartitionExpr {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        // Two PartitionExpr are equal if they share the same calculator Arc
+        // Two PartitionExpr are equal if they share the same calculator and partition_spec Arcs
         Arc::as_ptr(&self.calculator).hash(state);
+        Arc::as_ptr(&self.partition_spec).hash(state);
     }
 }
 
-/// Calculator for partition values in Iceberg tables
-#[derive(Debug)]
-struct PartitionValueCalculator {
-    partition_spec: PartitionSpec,
-    partition_type: DataType,
-    projector: RecordBatchProjector,
-    transform_functions: Vec<BoxedTransformFunction>,
-}
-
-impl PartitionValueCalculator {
-    fn new(
-        partition_spec: PartitionSpec,
-        table_schema: Schema,
-        partition_type: DataType,
-    ) -> DFResult<Self> {
-        if partition_spec.is_unpartitioned() {
-            return Err(DataFusionError::Internal(
-                "Cannot create partition calculator for unpartitioned table".to_string(),
-            ));
-        }
-
-        let transform_functions: Result<Vec<_>, _> = partition_spec
-            .fields()
-            .iter()
-            .map(|pf| iceberg::transform::create_transform_function(&pf.transform))
-            .collect();
-
-        let transform_functions = transform_functions.map_err(to_datafusion_error)?;
-
-        let source_field_ids: Vec<i32> = partition_spec
-            .fields()
-            .iter()
-            .map(|pf| pf.source_id)
-            .collect();
-
-        let projector = RecordBatchProjector::from_iceberg_schema(
-            Arc::new(table_schema.clone()),
-            &source_field_ids,
-        )
-        .map_err(to_datafusion_error)?;
-
-        Ok(Self {
-            partition_spec,
-            partition_type,
-            projector,
-            transform_functions,
-        })
-    }
-
-    fn calculate(&self, batch: &RecordBatch) -> DFResult<ArrayRef> {
-        let source_columns = self
-            .projector
-            .project_column(batch.columns())
-            .map_err(to_datafusion_error)?;
-
-        let expected_struct_fields = match &self.partition_type {
-            DataType::Struct(fields) => fields.clone(),
-            _ => {
-                return Err(DataFusionError::Internal(
-                    "Expected partition type must be a struct".to_string(),
-                ));
-            }
-        };
-
-        let mut partition_values = Vec::with_capacity(self.partition_spec.fields().len());
-
-        for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) {
-            let partition_value = transform_fn
-                .transform(source_column.clone())
-                .map_err(to_datafusion_error)?;
-
-            partition_values.push(partition_value);
-        }
-
-        let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None)
-            .map_err(|e| DataFusionError::ArrowError(e, None))?;
-
-        Ok(Arc::new(struct_array))
-    }
-}
-
-fn build_partition_type(
-    partition_spec: &PartitionSpec,
-    table_schema: &Schema,
-) -> DFResult<DataType> {
-    let partition_struct_type = partition_spec
-        .partition_type(table_schema)
-        .map_err(to_datafusion_error)?;
-
-    iceberg::arrow::type_to_arrow_type(&iceberg::spec::Type::Struct(partition_struct_type))
-        .map_err(to_datafusion_error)
-}
-
 #[cfg(test)]
 mod tests {
-    use datafusion::arrow::array::Int32Array;
+    use datafusion::arrow::array::{ArrayRef, Int32Array, StructArray};
     use datafusion::arrow::datatypes::{Field, Fields};
     use datafusion::physical_plan::empty::EmptyExec;
-    use iceberg::spec::{NestedField, PrimitiveType, StructType, Transform, Type};
+    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};
 
     use super::*;
 
@@ -291,20 +195,11 @@ mod tests {
             .build()
             .unwrap();
 
-        let _arrow_schema = Arc::new(ArrowSchema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("name", DataType::Utf8, false),
-        ]));
-
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator = PartitionValueCalculator::new(
-            partition_spec.clone(),
-            table_schema,
-            partition_type.clone(),
-        )
-        .unwrap();
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
 
-        assert_eq!(calculator.partition_type, partition_type);
+        // Verify partition type
+        assert_eq!(calculator.partition_type().fields().len(), 1);
+        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
     }
 
     #[test]
@@ -318,11 +213,13 @@ mod tests {
             .build()
             .unwrap();
 
-        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
-            .add_partition_field("id", "id_partition", Transform::Identity)
-            .unwrap()
-            .build()
-            .unwrap();
+        let partition_spec = Arc::new(
+            iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
+                .add_partition_field("id", "id_partition", Transform::Identity)
+                .unwrap()
+                .build()
+                .unwrap(),
+        );
 
         let arrow_schema = Arc::new(ArrowSchema::new(vec![
             Field::new("id", DataType::Int32, false),
@@ -331,9 +228,7 @@ mod tests {
 
         let input = Arc::new(EmptyExec::new(arrow_schema.clone()));
 
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator =
-            PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap();
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
 
         let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
             Vec::with_capacity(arrow_schema.fields().len() + 1);
@@ -342,8 +237,8 @@ mod tests {
             projection_exprs.push((column_expr, field.name().clone()));
         }
 
-        let partition_expr = Arc::new(PartitionExpr::new(calculator));
-        projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string()));
+        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
+        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
 
         let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
         let result = Arc::new(projection);
@@ -384,11 +279,10 @@ mod tests {
         ])
         .unwrap();
 
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator =
-            PartitionValueCalculator::new(partition_spec, table_schema, partition_type.clone())
-                .unwrap();
-        let expr = PartitionExpr::new(calculator);
+        let partition_spec = Arc::new(partition_spec);
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
+        let partition_type = calculator.partition_arrow_type().clone();
+        let expr = PartitionExpr::new(calculator, partition_spec);
 
         assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
         assert!(!expr.nullable(&arrow_schema).unwrap());
@@ -469,9 +363,7 @@ mod tests {
         ])
         .unwrap();
 
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator =
-            PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap();
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
 
         let array = calculator.calculate(&batch).unwrap();
         let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
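
Reviewer note on the equality/hash design kept (and extended) by this change: `PartitionExpr` compares by `Arc` identity rather than by value, so two expressions are equal only when they share the same calculator (and now partition spec) allocations, and `Hash` is kept consistent by hashing the same pointers. A self-contained sketch of the pattern, using a standalone type since `PartitionExpr` is private to this module:

```rust
use std::sync::Arc;

#[derive(Clone)]
struct PtrEqExpr {
    state: Arc<String>,
}

impl PartialEq for PtrEqExpr {
    fn eq(&self, other: &Self) -> bool {
        // Identity, not structural, equality: equal only when both sides
        // point at the same allocation.
        Arc::ptr_eq(&self.state, &other.state)
    }
}

fn main() {
    let a = PtrEqExpr { state: Arc::new("calc".into()) };
    let b = a.clone(); // shares the Arc -> equal
    let c = PtrEqExpr { state: Arc::new("calc".into()) }; // same contents, new allocation -> not equal
    assert!(a == b);
    assert!(a != c);
}
```

This keeps `PartialEq` and `Hash` cheap and mutually consistent for planner caching, at the cost of treating structurally identical expressions built separately as distinct.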