-
Notifications
You must be signed in to change notification settings - Fork 216
Add support for collated strings in OpConverter #802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
93f0fec
e4e92e3
a1a246a
74b5ab5
0523ee2
fd07671
3568168
0fdfafc
268bbe7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| /* | ||
| * Copyright (2021) The Delta Lake Project Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package io.delta.sharing.filters | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.{Expression => SqlExpression} | ||
| import org.apache.spark.sql.types.{StringType => SqlStringType} | ||
|
|
||
/**
 * Derives the collation identifier shared by the two operands of a comparison.
 *
 * This variant is the one built for the 2.12 client, which (per the original
 * implementation notes) depends on Spark 3.5 and therefore has no collation support.
 * It never yields an identifier; collated string operands are rejected instead.
 */
object CollationExtractor {

  /**
   * Inspects the data types of `left` and `right` and reports their common collation.
   *
   * @param left  left-hand operand of the comparison
   * @param right right-hand operand of the comparison
   * @return always `None` in this variant, for string and non-string operands alike
   * @throws IllegalArgumentException if both operands are strings of different string
   *                                  types, or if their shared string type is collated
   */
  def extractCollationIdentifier(
      left: SqlExpression,
      right: SqlExpression): Option[String] = {
    (left.dataType, right.dataType) match {
      case (lhs: SqlStringType, rhs: SqlStringType) if lhs != rhs =>
        // Strings may only be compared when both sides use the same collation.
        throw new IllegalArgumentException(
          s"Cannot compare strings with different collations: " +
            s"'${lhs.typeName}' vs '${rhs.typeName}'"
        )

      case (lhs: SqlStringType, _: SqlStringType) =>
        // No identifier can be extracted here, so a collated operand must abort the
        // conversion: that keeps an incorrect filter from being applied and avoids
        // returning wrong results.
        rejectCollatedType(lhs.typeName)
        None

      case _ =>
        None
    }
  }

  // Throws when the given type name denotes a collated string type.
  private def rejectCollatedType(typeName: String): Unit = {
    val isCollated = typeName.startsWith("string collate")
    if (isCollated) {
      throw new IllegalArgumentException(
        s"Cannot convert operand of unsupported type $typeName"
      )
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| /* | ||
| * Copyright (2021) The Delta Lake Project Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package io.delta.sharing.filters | ||
|
|
||
| import com.ibm.icu.util.VersionInfo.ICU_VERSION | ||
| import org.apache.spark.sql.catalyst.expressions.{Expression => SqlExpression} | ||
| import org.apache.spark.sql.types.{StringType => SqlStringType} | ||
|
|
||
/**
 * Derives the collation identifier shared by the two operands of a comparison.
 *
 * The identifier has the form `<provider>.<collationName>.<icuMajor>.<icuMinor>`.
 */
object CollationExtractor {

  /**
   * Inspects the data types of `left` and `right` and reports their common collation.
   *
   * @param left  left-hand operand of the comparison
   * @param right right-hand operand of the comparison
   * @return the collation identifier when both operands are strings of the same collated
   *         type; `None` when they are non-collated strings or not both strings
   * @throws IllegalArgumentException if both operands are strings of different string types
   */
  def extractCollationIdentifier(
      left: SqlExpression,
      right: SqlExpression): Option[String] = {
    (left.dataType, right.dataType) match {
      case (lhs: SqlStringType, rhs: SqlStringType) if lhs != rhs =>
        // Strings may only be compared when both sides use the same collation.
        throw new IllegalArgumentException(
          s"Cannot compare strings with different collations: " +
            s"'${lhs.typeName}' vs '${rhs.typeName}'"
        )

      case (lhs: SqlStringType, _: SqlStringType) =>
        identifierFor(lhs.typeName)

      case _ =>
        None
    }
  }

  // Builds the identifier from a string type name; type names without the
  // "string collate" prefix carry no collation and yield None.
  private def identifierFor(typeName: String): Option[String] = {
    val collatedPrefix = "string collate"
    if (!typeName.startsWith(collatedPrefix)) {
      None
    } else {
      val name = typeName.stripPrefix(collatedPrefix).trim
      // UTF8_LCASE is attributed to the "spark" provider, every other name to "icu".
      val provider = if (name.equalsIgnoreCase("UTF8_LCASE")) "spark" else "icu"
      val icuVersion = s"${ICU_VERSION.getMajor}.${ICU_VERSION.getMinor}"
      Some(s"$provider.$name.$icuVersion")
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,15 +93,30 @@ object OpConverter { | |
|
|
||
| // Convert comparison operators. | ||
| case SqlEqualTo(left, right) => | ||
| EqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| EqualOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlLessThan(left, right) => | ||
| LessThanOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| LessThanOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlLessThanOrEqual(left, right) => | ||
| LessThanOrEqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| LessThanOrEqualOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlGreaterThan(left, right) => | ||
| GreaterThanOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| GreaterThanOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlGreaterThanOrEqual(left, right) => | ||
| GreaterThanOrEqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| GreaterThanOrEqualOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
|
|
||
| // Convert null operations. | ||
| case SqlIsNull(child) => | ||
|
|
@@ -118,7 +133,9 @@ object OpConverter { | |
| ) | ||
| } | ||
| val leafOp = convertAsLeaf(value) | ||
| list.map(e => EqualOp(Seq(leafOp, convertAsLeaf(e)))) match { | ||
| list.map(e => | ||
| EqualOp(Seq(leafOp, convertAsLeaf(e)), extractExprContext(value, e)) | ||
| ) match { | ||
| case Seq() => | ||
| throw new IllegalArgumentException("The In predicate must have at least one entry") | ||
| case Seq(child) => child | ||
|
|
@@ -131,13 +148,14 @@ object OpConverter { | |
| val rightOp = convertAsLeaf(right) | ||
| val leftIsNullOp = IsNullOp(Seq(leftOp)) | ||
| val rightIsNullOp = IsNullOp(Seq(rightOp)) | ||
| val exprCtx = extractExprContext(left, right) | ||
| // Either both are null, or none is null and they are equal. | ||
| OrOp(Seq( | ||
| AndOp(Seq(leftIsNullOp, rightIsNullOp)), | ||
| AndOp(Seq( | ||
| NotOp(Seq(leftIsNullOp)), | ||
| NotOp(Seq(rightIsNullOp)), | ||
| EqualOp(Seq(leftOp, rightOp)))) | ||
| EqualOp(Seq(leftOp, rightOp), exprCtx))) | ||
| )) | ||
|
|
||
| // Unsupported expressions. | ||
|
|
@@ -186,7 +204,9 @@ object OpConverter { | |
| case SqlBooleanType => OpDataTypes.BoolType | ||
| case SqlIntegerType => OpDataTypes.IntType | ||
| case SqlLongType => OpDataTypes.LongType | ||
| case SqlStringType => OpDataTypes.StringType | ||
| // We need to match all string types (with different collations), | ||
| // and not just the case object which is UTF8_BINARY collated. | ||
| case _: SqlStringType => OpDataTypes.StringType | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: can revert
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can't actually. Without this change this line was only matching the case object of `StringType`, which is the UTF8_BINARY-collated string type.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider adding this as a comment.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done! |
||
| case SqlDateType => OpDataTypes.DateType | ||
| case SqlDoubleType => OpDataTypes.DoubleType | ||
| case SqlFloatType => OpDataTypes.FloatType | ||
|
|
@@ -207,4 +227,21 @@ object OpConverter { | |
| case _ => lit.toString | ||
| } | ||
| } | ||
|
|
||
/**
 * Builds the expression context for a comparison between two expressions.
 *
 * The only context carried today is the collation identifier reported by
 * `CollationExtractor.extractCollationIdentifier`; further dimensions of context can be
 * added here in the future.
 *
 * @param left  left-hand operand of the comparison
 * @param right right-hand operand of the comparison
 * @return an `ExprContext` holding the collation identifier when one is available,
 *         `None` when there is no context information
 */
private def extractExprContext(
    left: SqlExpression,
    right: SqlExpression): Option[ExprContext] = {
  CollationExtractor
    .extractCollationIdentifier(left, right)
    .map(id => ExprContext(collationIdentifier = Some(id)))
}
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| /* | ||
| * Copyright (2021) The Delta Lake Project Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package io.delta.sharing.filters | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.sql.catalyst.expressions.{ | ||
| AttributeReference => SqlAttributeReference, | ||
| EqualTo => SqlEqualTo, | ||
| Literal => SqlLiteral | ||
| } | ||
| import org.apache.spark.sql.types.{ | ||
| StringType => SqlStringType | ||
| } | ||
|
|
||
/**
 * Checks how `OpConverter` treats string comparisons in the Scala 2.12 build.
 *
 * Collated string types (`SqlStringType` with a collation parameter) do not exist in
 * Spark 3.5 (Scala 2.12); they were added in Spark 4.0 (Scala 2.13). Collation behavior
 * therefore cannot be tested here, and all collation-specific tests live in the
 * Scala 2.13 version of this suite.
 */
class OpConverterCollationSuite extends SparkFunSuite {

  test("UTF8_BINARY collation test") {
    // Equality predicate between a default (UTF8_BINARY) string column and a literal.
    val emailColumn = SqlAttributeReference("email", SqlStringType)()
    val predicate = SqlEqualTo(emailColumn, SqlLiteral("[email protected]"))

    val converted = OpConverter.convert(Seq(predicate)).get.asInstanceOf[EqualOp]
    converted.validate()

    // Both operands must be converted as string-typed leaves.
    val lhs = converted.children(0).asInstanceOf[ColumnOp]
    val rhs = converted.children(1).asInstanceOf[LiteralOp]
    assert(lhs.valueType == OpDataTypes.StringType)
    assert(rhs.valueType == OpDataTypes.StringType)

    // UTF8_BINARY (default) should work fine on Scala 2.12 and carry no context.
    assert(converted.exprCtx.isEmpty)
  }
}
Uh oh!
There was an error while loading. Please reload this page.