ColumnName tracks a path of field names instead of a simple string (#445)

## What changes are proposed in this pull request?

Previously, `ColumnName` tracked a simple string, which could be split
at `.` to obtain a path of field names. That approach breaks down if a
field name itself contains a dot (or other special characters), because
the split can no longer be interpreted reliably.

To solve this, update `ColumnName` to track an actual path of field
names instead. Update all call sites as needed to support the new idiom.

This PR also includes code for reliably parsing strings into
`ColumnName`, using a period as the field separator and backticks as
delimiters for field names that contain special characters, e.g.:
```rust
assert_eq!(ColumnName::new(["a", "b c", "d"]).to_string(), "a.`b c`.d");
``` 
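
For the parsing side, here is a round-trip sketch; it assumes the new parser is exposed through a `FromStr` impl on `ColumnName` (so `str::parse` works) and uses an illustrative crate path:
```rust
use delta_kernel::expressions::ColumnName;

fn main() {
    // Round-trip: backticks delimit the field name that contains a space.
    let parsed: ColumnName = "a.`b c`.d".parse().expect("valid column name string");
    assert_eq!(parsed, ColumnName::new(["a", "b c", "d"]));
    assert_eq!(parsed.to_string(), "a.`b c`.d");
}
```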

NOTE: This change does _not_ magically make all operations
nesting-aware. For example, code that loops over the top-level field
names of a `StructType` will continue to treat nested column names as
unmatched. Fixing those call sites is left as future work, though the
obvious ones are flagged here as TODO.
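
A minimal sketch of that pitfall, using only `ColumnName` equality (the field names below are made up):
```rust
use delta_kernel::expressions::ColumnName;

fn main() {
    // Suppose a schema has a top-level struct field `a` containing a nested field `b`.
    let top_level_fields = ["a", "d"];
    let nested = ColumnName::new(["a", "b"]);

    // A loop that only compares against top-level field names never matches the nested column.
    let matched = top_level_fields
        .iter()
        .any(|f| ColumnName::new([*f]) == nested);
    assert!(!matched);
}
```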

Resolves #443

### This PR affects the following public APIs

The "shape" of `ColumnName` changes from string-like to
slice-of-string-like, and its methods change accordingly. This change is
needed because otherwise we cannot reliably handle arbitrary field
names.
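
Conceptually, the new shape looks roughly like this (a sketch, not the actual definition):
```rust
// Not the real definition -- just the conceptual shape of the new API: a
// ColumnName owns a path of field names and can be viewed as a slice of them.
pub struct ColumnName {
    path: Vec<String>,
}

impl std::ops::Deref for ColumnName {
    type Target = [String];

    fn deref(&self) -> &[String] {
        &self.path
    }
}
```
Viewing the name as a slice of its parts is what lets call sites such as `extract_column(batch, name)` in the diff below accept a `ColumnName` directly.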

## How was this change tested?

Extensive new and existing unit tests.
scovich authored Nov 6, 2024
1 parent 0214a96 commit 4466509
Showing 13 changed files with 494 additions and 179 deletions.
22 changes: 11 additions & 11 deletions ffi/src/expressions/engine.rs
@@ -41,8 +41,8 @@ pub struct EnginePredicate {
         extern "C" fn(predicate: *mut c_void, state: &mut KernelExpressionVisitorState) -> usize,
 }
 
-fn wrap_expression(state: &mut KernelExpressionVisitorState, expr: Expression) -> usize {
-    state.inflight_expressions.insert(expr)
+fn wrap_expression(state: &mut KernelExpressionVisitorState, expr: impl Into<Expression>) -> usize {
+    state.inflight_expressions.insert(expr.into())
 }
 
 pub fn unwrap_kernel_expression(
@@ -149,7 +149,7 @@ fn visit_expression_column_impl(
     name: DeltaResult<&str>,
 ) -> DeltaResult<usize> {
     // TODO: FIXME: This is incorrect if any field name in the column path contains a period.
-    let name = ColumnName::new(name?.split('.')).into();
+    let name = ColumnName::from_naive_str_split(name?);
     Ok(wrap_expression(state, name))
 }
 
@@ -184,7 +184,7 @@ fn visit_expression_literal_string_impl(
     state: &mut KernelExpressionVisitorState,
     value: DeltaResult<String>,
 ) -> DeltaResult<usize> {
-    Ok(wrap_expression(state, Expression::literal(value?)))
+    Ok(wrap_expression(state, value?))
 }
 
 // We need to get parse.expand working to be able to macro everything below, see issue #255
@@ -194,53 +194,53 @@ pub extern "C" fn visit_expression_literal_int(
     state: &mut KernelExpressionVisitorState,
     value: i32,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
 
 #[no_mangle]
 pub extern "C" fn visit_expression_literal_long(
     state: &mut KernelExpressionVisitorState,
     value: i64,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
 
 #[no_mangle]
 pub extern "C" fn visit_expression_literal_short(
     state: &mut KernelExpressionVisitorState,
     value: i16,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
 
 #[no_mangle]
 pub extern "C" fn visit_expression_literal_byte(
     state: &mut KernelExpressionVisitorState,
     value: i8,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
 
 #[no_mangle]
 pub extern "C" fn visit_expression_literal_float(
     state: &mut KernelExpressionVisitorState,
     value: f32,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
 
 #[no_mangle]
 pub extern "C" fn visit_expression_literal_double(
     state: &mut KernelExpressionVisitorState,
     value: f64,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
 
 #[no_mangle]
 pub extern "C" fn visit_expression_literal_bool(
     state: &mut KernelExpressionVisitorState,
     value: bool,
 ) -> usize {
-    wrap_expression(state, Expression::literal(value))
+    wrap_expression(state, value)
 }
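
The hunk above is what lets every literal-visiting entry point pass its raw value straight through. A minimal sketch of the idea, using a hypothetical `wrap` helper in place of the FFI state machinery:
```rust
use delta_kernel::expressions::Expression;

// Hypothetical helper mirroring the new `wrap_expression` bound, minus the FFI state.
fn wrap(expr: impl Into<Expression>) -> Expression {
    expr.into()
}

fn main() {
    // Primitive values convert through Into<Expression>, so the literal-visiting
    // entry points no longer need to spell out Expression::literal(..) themselves.
    let _int = wrap(42i32);
    let _flag = wrap(true);
    let _text = wrap(String::from("hello"));
}
```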
2 changes: 1 addition & 1 deletion ffi/src/expressions/kernel.rs
@@ -323,7 +323,7 @@ pub unsafe extern "C" fn visit_expression(
             visit_expression_scalar(visitor, scalar, sibling_list_id)
         }
         Expression::Column(name) => {
-            let name = name.as_str();
+            let name = name.to_string();
             let name = kernel_string_slice!(name);
             call!(visitor, visit_column, sibling_list_id, name)
         }
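
For context on the `as_str` → `to_string` change above: a path-based `ColumnName` no longer holds a single borrowed string, so the FFI visitor renders one on demand. A small sketch (assuming `Display` quotes special field names as described in the PR summary):
```rust
use delta_kernel::expressions::ColumnName;

fn main() {
    let name = ColumnName::new(["a", "b c"]);
    // There is no longer a borrowed &str inside the name to hand across the FFI
    // boundary; the string form is produced on demand, quoting as needed.
    assert_eq!(name.to_string(), "a.`b c`");
}
```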
10 changes: 6 additions & 4 deletions kernel/src/actions/set_transaction.rs
@@ -2,9 +2,8 @@ use std::sync::{Arc, LazyLock};
 
 use crate::actions::visitors::SetTransactionVisitor;
 use crate::actions::{get_log_schema, SetTransaction, SET_TRANSACTION_NAME};
-use crate::expressions::column_expr;
 use crate::snapshot::Snapshot;
-use crate::{DeltaResult, Engine, EngineData, ExpressionRef, SchemaRef};
+use crate::{DeltaResult, Engine, EngineData, Expression as Expr, ExpressionRef, SchemaRef};
 
 pub use crate::actions::visitors::SetTransactionMap;
 pub struct SetTransactionScanner {
@@ -53,8 +52,11 @@ impl SetTransactionScanner {
         // checkpoint part when patitioned by `add.path` like the Delta spec requires. There's no
         // point filtering by a particular app id, even if we have one, because app ids are all in
         // the a single checkpoint part having large min/max range (because they're usually uuids).
-        static META_PREDICATE: LazyLock<Option<ExpressionRef>> =
-            LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null())));
+        static META_PREDICATE: LazyLock<Option<ExpressionRef>> = LazyLock::new(|| {
+            Some(Arc::new(
+                Expr::column([SET_TRANSACTION_NAME, "appId"]).is_not_null(),
+            ))
+        });
         self.snapshot
             .log_segment
             .replay(engine, schema.clone(), schema, META_PREDICATE.clone())
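
The predicate above is now assembled from the column's path parts rather than the dotted string `"txn.appId"`. A standalone sketch of the same construction (the crate path shown is illustrative):
```rust
use delta_kernel::expressions::Expression as Expr;

fn main() {
    // Building a predicate over a nested column from its path parts means a
    // field name that itself contains '.' can no longer be misread as two levels.
    let _predicate = Expr::column(["txn", "appId"]).is_not_null();
}
```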
10 changes: 3 additions & 7 deletions kernel/src/engine/arrow_expression.rs
@@ -165,10 +165,8 @@ impl ProvidesColumnByName for StructArray {
 // }
 // ```
 // The path ["b", "d", "f"] would retrieve the int64 column while ["a", "b"] would produce an error.
-fn extract_column<'a>(
-    mut parent: &dyn ProvidesColumnByName,
-    mut field_names: impl Iterator<Item = &'a str>,
-) -> DeltaResult<ArrayRef> {
+fn extract_column(mut parent: &dyn ProvidesColumnByName, col: &[String]) -> DeltaResult<ArrayRef> {
+    let mut field_names = col.iter();
     let Some(mut field_name) = field_names.next() else {
         return Err(ArrowError::SchemaError("Empty column path".to_string()))?;
     };
@@ -196,9 +194,7 @@ fn evaluate_expression(
     use Expression::*;
     match (expression, result_type) {
         (Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?),
-        // TODO properly handle nested columns
-        // https://github.com/delta-incubator/delta-kernel-rs/issues/86
-        (Column(name), _) => extract_column(batch, name.split('.')),
+        (Column(name), _) => extract_column(batch, name),
         (Struct(fields), Some(DataType::Struct(output_schema))) => {
             let columns = fields
                 .iter()
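
`extract_column` now walks an explicit `&[String]` path, one field name per nesting level. The toy sketch below (plain Rust, no arrow types) mirrors that walk under the nesting shape described in the doc comment above:
```rust
use std::collections::HashMap;

// A toy stand-in for arrow's nested lookup: consume one field name per level.
enum Node {
    Leaf(i64),
    Struct(HashMap<String, Node>),
}

fn extract(mut node: &Node, path: &[String]) -> Result<i64, String> {
    for field in path {
        node = match node {
            Node::Struct(children) => children
                .get(field)
                .ok_or_else(|| format!("no such field: {field}"))?,
            Node::Leaf(_) => return Err(format!("{field} is not nested under a struct")),
        };
    }
    match node {
        Node::Leaf(value) => Ok(*value),
        Node::Struct(_) => Err("path stops at a struct, not a leaf column".to_string()),
    }
}

fn main() {
    // Roughly { "b": { "d": { "f": 42 } } }
    let data = Node::Struct(HashMap::from([(
        "b".to_string(),
        Node::Struct(HashMap::from([(
            "d".to_string(),
            Node::Struct(HashMap::from([("f".to_string(), Node::Leaf(42))])),
        )])),
    )]));
    assert_eq!(extract(&data, &["b".into(), "d".into(), "f".into()]), Ok(42));
}
```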
28 changes: 16 additions & 12 deletions kernel/src/engine/parquet_row_group_skipping.rs
@@ -1,12 +1,12 @@
 //! An implementation of parquet row group skipping using data skipping predicates over footer stats.
-use crate::engine::parquet_stats_skipping::{col_name_to_path, ParquetStatsSkippingFilter};
-use crate::expressions::{Expression, Scalar};
+use crate::engine::parquet_stats_skipping::ParquetStatsSkippingFilter;
+use crate::expressions::{ColumnName, Expression, Scalar};
 use crate::schema::{DataType, PrimitiveType};
 use chrono::{DateTime, Days};
 use parquet::arrow::arrow_reader::ArrowReaderBuilder;
 use parquet::file::metadata::RowGroupMetaData;
 use parquet::file::statistics::Statistics;
-use parquet::schema::types::{ColumnDescPtr, ColumnPath};
+use parquet::schema::types::ColumnDescPtr;
 use std::collections::{HashMap, HashSet};
 use tracing::debug;
 
@@ -41,7 +41,7 @@ impl<T> ParquetRowGroupSkipping for ArrowReaderBuilder<T> {
 /// corresponding field index, for O(1) stats lookups.
 struct RowGroupFilter<'a> {
     row_group: &'a RowGroupMetaData,
-    field_indices: HashMap<ColumnPath, usize>,
+    field_indices: HashMap<ColumnName, usize>,
 }
 
 impl<'a> RowGroupFilter<'a> {
@@ -59,7 +59,7 @@ impl<'a> RowGroupFilter<'a> {
     }
 
     /// Returns `None` if the column doesn't exist and `Some(None)` if the column has no stats.
-    fn get_stats(&self, col: &ColumnPath) -> Option<Option<&Statistics>> {
+    fn get_stats(&self, col: &ColumnName) -> Option<Option<&Statistics>> {
         self.field_indices
             .get(col)
             .map(|&i| self.row_group.column(i).statistics())
@@ -93,7 +93,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> {
     // NOTE: This code is highly redundant with [`get_max_stat_value`] below, but parquet
     // ValueStatistics<T> requires T to impl a private trait, so we can't factor out any kind of
     // helper method. And macros are hard enough to read that it's not worth defining one.
-    fn get_min_stat_value(&self, col: &ColumnPath, data_type: &DataType) -> Option<Scalar> {
+    fn get_min_stat_value(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
         use PrimitiveType::*;
         let value = match (data_type.as_primitive_opt()?, self.get_stats(col)??) {
             (String, Statistics::ByteArray(s)) => s.min_opt()?.as_utf8().ok()?.into(),
@@ -135,7 +135,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> {
         Some(value)
     }
 
-    fn get_max_stat_value(&self, col: &ColumnPath, data_type: &DataType) -> Option<Scalar> {
+    fn get_max_stat_value(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
         use PrimitiveType::*;
         let value = match (data_type.as_primitive_opt()?, self.get_stats(col)??) {
             (String, Statistics::ByteArray(s)) => s.max_opt()?.as_utf8().ok()?.into(),
@@ -177,7 +177,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> {
         Some(value)
     }
 
-    fn get_nullcount_stat_value(&self, col: &ColumnPath) -> Option<i64> {
+    fn get_nullcount_stat_value(&self, col: &ColumnName) -> Option<i64> {
         // NOTE: Stats for any given column are optional, which may produce a NULL nullcount. But if
         // the column itself is missing, then we know all values are implied to be NULL.
         //
@@ -221,13 +221,13 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> {
 pub(crate) fn compute_field_indices(
     fields: &[ColumnDescPtr],
     expression: &Expression,
-) -> HashMap<ColumnPath, usize> {
-    fn do_recurse(expression: &Expression, cols: &mut HashSet<ColumnPath>) {
+) -> HashMap<ColumnName, usize> {
+    fn do_recurse(expression: &Expression, cols: &mut HashSet<ColumnName>) {
         use Expression::*;
         let mut recurse = |expr| do_recurse(expr, cols); // simplifies the call sites below
         match expression {
             Literal(_) => {}
-            Column(name) => cols.extend([col_name_to_path(name)]), // returns `()`, unlike `insert`
+            Column(name) => cols.extend([name.clone()]), // returns `()`, unlike `insert`
             Struct(fields) => fields.iter().for_each(recurse),
             UnaryOperation { expr, .. } => recurse(expr),
             BinaryOperation { left, right, .. } => [left, right].iter().for_each(|e| recurse(e)),
@@ -245,6 +245,10 @@ pub(crate) fn compute_field_indices(
     fields
         .iter()
         .enumerate()
-        .filter_map(|(i, f)| requested_columns.take(f.path()).map(|path| (path, i)))
+        .filter_map(|(i, f)| {
+            requested_columns
+                .take(f.path().parts())
+                .map(|path| (path, i))
+        })
         .collect()
 }
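
With the map keyed by `ColumnName`, lookups can use the parquet leaf path directly. A sketch of that relationship; it presumes `ColumnName` can be borrowed as a slice of its parts, which is what `take(f.path().parts())` above appears to rely on:
```rust
use std::collections::HashSet;

use delta_kernel::expressions::ColumnName;

fn main() {
    // The row-group filter collects the columns referenced by the predicate...
    let mut requested: HashSet<ColumnName> = HashSet::new();
    requested.insert(ColumnName::new(["txn", "appId"]));

    // ...and then matches them against each parquet leaf path (a &[String]).
    let leaf_path = vec!["txn".to_string(), "appId".to_string()];
    assert!(requested.take(leaf_path.as_slice()).is_some());
}
```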