Skip to content

Commit

Permalink
fix(cubesql) Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldheinrichs authored and mcheshkov committed Jan 7, 2025
1 parent 2145a19 commit 30128c0
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 23 deletions.
32 changes: 15 additions & 17 deletions rust/cubesql/cubesql/src/compile/engine/udf/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3174,9 +3174,10 @@ pub fn create_col_description_udf() -> ScalarUDF {

ScalarUDF::new(
"col_description",
&Signature::one_of(vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
], Volatility::Immutable),
&Signature::one_of(
vec![TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8])],
Volatility::Immutable,
),
&return_type,
&fun,
)
Expand Down Expand Up @@ -3233,31 +3234,32 @@ pub fn create_format_udf() -> ScalarUDF {
let str_arr = downcast_string_arg!(arg, "arg", i32);
if str_arr.is_null(i) {
return Err(DataFusionError::Execution(
"NULL values cannot be formatted as identifiers".to_string(),
"NULL values cannot be formatted as identifiers"
.to_string(),
));
}
str_arr.value(i).to_string()
}
_ => {
// For other types, try to convert to string
let str_arr = cast(&arg, &DataType::Utf8)?;
let str_arr = str_arr
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let str_arr =
str_arr.as_any().downcast_ref::<StringArray>().unwrap();
if str_arr.is_null(i) {
return Err(DataFusionError::Execution(
"NULL values cannot be formatted as identifiers".to_string(),
"NULL values cannot be formatted as identifiers"
.to_string(),
));
}
str_arr.value(i).to_string()
}
};

// Quote identifier if necessary
let needs_quoting = !value.chars().all(|c| {
c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_'
}) || value.is_empty();
let needs_quoting = !value
.chars()
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
|| value.is_empty();

if needs_quoting {
result.push('"');
Expand Down Expand Up @@ -3296,16 +3298,12 @@ pub fn create_format_udf() -> ScalarUDF {

ScalarUDF::new(
"format",
&Signature::variadic(
vec![DataType::Utf8],
Volatility::Immutable,
),
&Signature::variadic(vec![DataType::Utf8], Volatility::Immutable),
&return_type,
&fun,
)
}


pub fn create_json_build_object_udf() -> ScalarUDF {
let fun = make_scalar_function(move |_args: &[ArrayRef]| {
// TODO: Implement
Expand Down
37 changes: 31 additions & 6 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16355,27 +16355,52 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
#[tokio::test]
async fn test_format_function() -> Result<(), CubeError> {
// Test: Basic usage with a single identifier
let result = execute_query("SELECT format('%I', 'column_name') AS formatted_identifier".to_string(), DatabaseProtocol::PostgreSQL).await?;
let result = execute_query(
"SELECT format('%I', 'column_name') AS formatted_identifier".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("formatted_identifier", result);

// Test: Using multiple identifiers
let result = execute_query("SELECT format('%I, %I', 'table_name', 'column_name') AS formatted_identifiers".to_string(), DatabaseProtocol::PostgreSQL).await?;
let result = execute_query(
"SELECT format('%I, %I', 'table_name', 'column_name') AS formatted_identifiers"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("formatted_identifiers", result);

// Test: Unsupported format specifier
let result = execute_query("SELECT format('%X', 'value') AS unsupported_specifier".to_string(), DatabaseProtocol::PostgreSQL).await;
let result = execute_query(
"SELECT format('%X', 'value') AS unsupported_specifier".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;
assert!(result.is_err());

// Test: Format string ending with %
let result = execute_query("SELECT format('%', 'value') AS invalid_format".to_string(), DatabaseProtocol::PostgreSQL).await;
let result = execute_query(
"SELECT format('%', 'value') AS invalid_format".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;
assert!(result.is_err());

// Test: Quoting necessary for special characters
let result = execute_query("SELECT format('%I', 'column-name') AS quoted_identifier".to_string(), DatabaseProtocol::PostgreSQL).await?;
let result = execute_query(
"SELECT format('%I', 'column-name') AS quoted_identifier".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("quoted_identifier", result);

// Test: Quoting necessary for reserved keywords
let result = execute_query("SELECT format('%I', 'select') AS quoted_keyword".to_string(), DatabaseProtocol::PostgreSQL).await?;
let result = execute_query(
"SELECT format('%I', 'select') AS quoted_keyword".to_string(),
DatabaseProtocol::PostgreSQL,
)
.await?;
insta::assert_snapshot!("quoted_keyword", result);

Ok(())
Expand Down

0 comments on commit 30128c0

Please sign in to comment.