Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 55 additions & 17 deletions datafusion/substrait/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,26 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::common::{DataFusionError, HashMap, plan_err};
use substrait::proto::extensions::SimpleExtensionDeclaration;
use datafusion::common::{HashMap, plan_err};
use substrait::proto::extensions::{SimpleExtensionDeclaration, SimpleExtensionUri};
use substrait::proto::extensions::simple_extension_declaration::{
ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
};

/// Arrow's official Substrait extension types URI
pub const ARROW_EXTENSION_TYPES_URI: &str =
"https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml";

/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define
/// behavior of plans in addition to what's supported directly by the protobuf definitions.
/// That includes functions, but also provides support for custom types and variations for existing
/// types. This struct facilitates the use of these extensions in DataFusion.
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
#[derive(Default, Debug, PartialEq)]
pub struct Extensions {
pub uris: HashMap<u32, String>, // anchor -> URI
pub functions: HashMap<u32, String>, // anchor -> function name
pub types: HashMap<u32, String>, // anchor -> type name
pub types: HashMap<u32, (String, u32)>, // anchor -> (type name, uri_anchor)
pub type_variations: HashMap<u32, String>, // anchor -> type variation name
}

Expand Down Expand Up @@ -62,39 +66,61 @@ impl Extensions {
}
}

/// Registers a type and returns the anchor (reference) to it. If the type has already
/// Registers an extension URI and returns the anchor (reference) to it.
///
/// If `uri` has already been registered, the existing anchor is returned and the
/// map is left untouched. Otherwise the URI is stored under a new, unused anchor.
///
/// The new anchor is the first free value at or above the current number of
/// registered URIs. For a densely numbered map (`0..len`) this is simply `len`,
/// but when the map holds arbitrary anchors (e.g. anchors copied from an incoming
/// plan) it avoids overwriting an unrelated URI that already occupies `len`.
pub fn register_uri(&mut self, uri: &str) -> u32 {
    // URI has been registered: hand back its anchor.
    if let Some((&existing_anchor, _)) = self
        .uris
        .iter()
        .find(|(_, registered)| registered.as_str() == uri)
    {
        return existing_anchor;
    }

    // URI has NOT been registered: pick an anchor that is not in use yet.
    let first_candidate = self.uris.len() as u32;
    let uri_anchor = (first_candidate..=u32::MAX)
        .find(|candidate| !self.uris.contains_key(candidate))
        .unwrap_or(first_candidate);
    self.uris.insert(uri_anchor, uri.to_string());
    uri_anchor
}

/// Registers an Arrow extension type and returns the anchor (reference) to it.
/// Uses the official Arrow extension types URI by default.
/// If the type has already been registered, it returns the existing anchor.
pub fn register_type(&mut self, type_name: &str) -> u32 {
let type_name = type_name.to_lowercase();
match self.types.iter().find(|(_, t)| *t == &type_name) {
let uri_anchor = self.register_uri(ARROW_EXTENSION_TYPES_URI);
match self.types.iter().find(|(_, (t, _))| *t == type_name) {
Some((type_anchor, _)) => *type_anchor, // Type has been registered
None => {
// Type has NOT been registered
let type_anchor = self.types.len() as u32;
self.types.insert(type_anchor, type_name.clone());
self.types.insert(type_anchor, (type_name.clone(), uri_anchor));
type_anchor
}
}
}
}

impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
type Error = DataFusionError;

fn try_from(
value: &Vec<SimpleExtensionDeclaration>,
impl Extensions {
/// Parse extensions from Substrait plan components
pub fn from_substrait(
extension_uris: &[SimpleExtensionUri],
extensions: &[SimpleExtensionDeclaration],
) -> datafusion::common::Result<Self> {
let mut uris = HashMap::new();
let mut functions = HashMap::new();
let mut types = HashMap::new();
let mut type_variations = HashMap::new();

for ext in value {
for uri in extension_uris {
uris.insert(uri.extension_uri_anchor, uri.uri.clone());
}

for ext in extensions {
match &ext.mapping_type {
Some(MappingType::ExtensionFunction(ext_f)) => {
functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
}
Some(MappingType::ExtensionType(ext_t)) => {
types.insert(ext_t.type_anchor, ext_t.name.to_owned());
let uri_anchor = ext_t.extension_urn_reference;
types.insert(ext_t.type_anchor, (ext_t.name.to_owned(), uri_anchor));
}
Some(MappingType::ExtensionTypeVariation(ext_v)) => {
type_variations
Expand All @@ -105,11 +131,23 @@ impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
}

Ok(Extensions {
uris,
functions,
types,
type_variations,
})
}

/// Builds the `SimpleExtensionUri` entries for a Substrait plan from the
/// registered `anchor -> URI` map.
///
/// The entries are returned sorted by anchor. The backing map has no defined
/// iteration order, so sorting keeps the produced plan stable across runs.
pub fn to_extension_uris(&self) -> Vec<SimpleExtensionUri> {
    let mut extension_uris: Vec<SimpleExtensionUri> = self
        .uris
        .iter()
        .map(|(anchor, uri)| SimpleExtensionUri {
            extension_uri_anchor: *anchor,
            uri: uri.clone(),
        })
        .collect();
    // Anchors are unique map keys, so an unstable sort is sufficient.
    extension_uris.sort_unstable_by_key(|entry| entry.extension_uri_anchor);
    extension_uris
}
}

impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
Expand All @@ -131,10 +169,10 @@ impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
extensions.push(simple_extension);
}

for (t_anchor, t_name) in val.types {
for (t_anchor, (t_name, uri_anchor)) in val.types {
let type_extension = ExtensionType {
extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545
extension_urn_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545
extension_uri_reference: uri_anchor,
extension_urn_reference: uri_anchor,
type_anchor: t_anchor,
name: t_name,
};
Expand Down
88 changes: 86 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer/expr/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
use crate::logical_plan::consumer::SubstraitConsumer;
use crate::logical_plan::consumer::types::from_substrait_type;
use crate::logical_plan::consumer::utils::{DEFAULT_TIMEZONE, next_struct_field_name};
use crate::variation_const::FLOAT_16_TYPE_NAME;
use crate::variation_const::{
FLOAT_16_TYPE_NAME, LARGE_BINARY_TYPE_NAME, LARGE_STRING_TYPE_NAME, U16_TYPE_NAME,
U32_TYPE_NAME, U64_TYPE_NAME, U8_TYPE_NAME,
};
#[expect(deprecated)]
use crate::variation_const::{
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
Expand Down Expand Up @@ -474,7 +477,7 @@ pub(crate) fn from_substrait_literal(
)))
};

if let Some(name) = consumer
if let Some((name, _uri_anchor)) = consumer
.get_extensions()
.types
.get(&user_defined.type_reference)
Expand Down Expand Up @@ -515,6 +518,67 @@ pub(crate) fn from_substrait_literal(
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => {
interval_month_day_nano(user_defined)?
}
// Unsigned integer literals use google.protobuf.UInt64Value
U8_TYPE_NAME => {
let value = decode_uint64_literal(user_defined, "u8")?;
return Ok(ScalarValue::UInt8(Some(value as u8)));
}
U16_TYPE_NAME => {
let value = decode_uint64_literal(user_defined, "u16")?;
return Ok(ScalarValue::UInt16(Some(value as u16)));
}
U32_TYPE_NAME => {
let value = decode_uint64_literal(user_defined, "u32")?;
return Ok(ScalarValue::UInt32(Some(value as u32)));
}
U64_TYPE_NAME => {
let value = decode_uint64_literal(user_defined, "u64")?;
return Ok(ScalarValue::UInt64(Some(value)));
}
// Large string literals use google.protobuf.StringValue
LARGE_STRING_TYPE_NAME => {
let Some(value) = user_defined.val.as_ref() else {
return substrait_err!("large_string value is empty");
};
let Val::Value(value_any) = value else {
return substrait_err!("large_string value is not a value type literal");
};
if value_any.type_url != "google.protobuf.StringValue" {
return substrait_err!(
"large_string value is not a google.protobuf.StringValue"
);
}
let decoded_value =
pbjson_types::StringValue::decode(value_any.value.clone())
.map_err(|err| {
substrait_datafusion_err!(
"Failed to decode large_string value: {err}"
)
})?;
return Ok(ScalarValue::LargeUtf8(Some(decoded_value.value)));
}
// Large binary literals use google.protobuf.BytesValue
LARGE_BINARY_TYPE_NAME => {
let Some(value) = user_defined.val.as_ref() else {
return substrait_err!("large_binary value is empty");
};
let Val::Value(value_any) = value else {
return substrait_err!("large_binary value is not a value type literal");
};
if value_any.type_url != "google.protobuf.BytesValue" {
return substrait_err!(
"large_binary value is not a google.protobuf.BytesValue"
);
}
let decoded_value =
pbjson_types::BytesValue::decode(value_any.value.clone())
.map_err(|err| {
substrait_datafusion_err!(
"Failed to decode large_binary value: {err}"
)
})?;
return Ok(ScalarValue::LargeBinary(Some(decoded_value.value.into())));
}
_ => {
return not_impl_err!(
"Unsupported Substrait user defined type with ref {} and name {}",
Expand Down Expand Up @@ -580,6 +644,26 @@ pub(crate) fn from_substrait_literal(
Ok(scalar_value)
}

/// Decodes the payload of a user-defined literal as a `google.protobuf.UInt64Value`
/// and returns the contained `u64`.
///
/// `type_name` is only used to label error messages. An error is returned when the
/// literal carries no value, when the value is not the `Val::Value` variant, when its
/// `type_url` is not exactly `google.protobuf.UInt64Value`, or when decoding fails.
fn decode_uint64_literal(
    user_defined: &proto::expression::literal::UserDefined,
    type_name: &str,
) -> datafusion::common::Result<u64> {
    // Unwrap the optional payload and require the `Value` variant in one step.
    let any_payload = match user_defined.val.as_ref() {
        None => return substrait_err!("{type_name} value is empty"),
        Some(Val::Value(any)) => any,
        Some(_) => {
            return substrait_err!("{type_name} value is not a value type literal");
        }
    };

    if any_payload.type_url != "google.protobuf.UInt64Value" {
        return substrait_err!("{type_name} value is not a google.protobuf.UInt64Value");
    }

    match pbjson_types::UInt64Value::decode(any_payload.value.clone()) {
        Ok(wrapper) => Ok(wrapper.value),
        Err(err) => Err(substrait_datafusion_err!(
            "Failed to decode {type_name} value: {err}"
        )),
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/consumer/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub async fn from_substrait_extended_expr(
extended_expr: &ExtendedExpression,
) -> datafusion::common::Result<ExprContainer> {
// Parse the extension URIs and the function/type/type-variation declarations
let extensions = Extensions::try_from(&extended_expr.extensions)?;
let extensions = Extensions::from_substrait(&extended_expr.extension_uris, &extended_expr.extensions)?;
if !extensions.type_variations.is_empty() {
return not_impl_err!("Type variation extensions are not supported");
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/consumer/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub async fn from_substrait_plan(
plan: &Plan,
) -> datafusion::common::Result<LogicalPlan> {
// Parse the extension URIs and the function/type/type-variation declarations
let extensions = Extensions::try_from(&plan.extensions)?;
let extensions = Extensions::from_substrait(&plan.extension_uris, &plan.extensions)?;
if !extensions.type_variations.is_empty() {
return not_impl_err!("Type variation extensions are not supported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ use substrait::proto::{
///
/// // and handlers for user-defined types
/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result<DataType> {
/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap();
/// match type_string.as_str() {
/// let (type_name, _uri_anchor) = self.extensions.types.get(&typ.type_reference).unwrap();
/// match type_name.as_str() {
/// "u!foo" => not_impl_err!("handle foo conversion"),
/// "u!bar" => not_impl_err!("handle bar conversion"),
/// _ => substrait_err!("unexpected type")
Expand All @@ -141,8 +141,8 @@ use substrait::proto::{
///
/// // and user-defined literals
/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result<ScalarValue> {
/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap();
/// match type_string.as_str() {
/// let (type_name, _uri_anchor) = self.extensions.types.get(&literal.type_reference).unwrap();
/// match type_name.as_str() {
/// "u!foo" => not_impl_err!("handle foo conversion"),
/// "u!bar" => not_impl_err!("handle bar conversion"),
/// _ => substrait_err!("unexpected type")
Expand Down
Loading