-
Notifications
You must be signed in to change notification settings - Fork 216
Add support for collated strings in OpConverter #802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
93f0fec
e4e92e3
a1a246a
74b5ab5
0523ee2
fd07671
3568168
0fdfafc
268bbe7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ package io.delta.sharing.filters | |
|
|
||
| import scala.collection.mutable.ListBuffer | ||
|
|
||
| import com.ibm.icu.util.VersionInfo.ICU_VERSION | ||
| import org.apache.spark.sql.catalyst.expressions.{ | ||
| And => SqlAnd, | ||
| Attribute => SqlAttribute, | ||
|
|
@@ -93,15 +94,30 @@ object OpConverter { | |
|
|
||
| // Convert comparison operators. | ||
| case SqlEqualTo(left, right) => | ||
| EqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| EqualOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlLessThan(left, right) => | ||
| LessThanOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| LessThanOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlLessThanOrEqual(left, right) => | ||
| LessThanOrEqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| LessThanOrEqualOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlGreaterThan(left, right) => | ||
| GreaterThanOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| GreaterThanOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
| case SqlGreaterThanOrEqual(left, right) => | ||
| GreaterThanOrEqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right))) | ||
| GreaterThanOrEqualOp( | ||
| Seq(convertAsLeaf(left), convertAsLeaf(right)), | ||
| extractExprContext(left, right) | ||
| ) | ||
|
|
||
| // Convert null operations. | ||
| case SqlIsNull(child) => | ||
|
|
@@ -118,7 +134,9 @@ object OpConverter { | |
| ) | ||
| } | ||
| val leafOp = convertAsLeaf(value) | ||
| list.map(e => EqualOp(Seq(leafOp, convertAsLeaf(e)))) match { | ||
| list.map(e => | ||
| EqualOp(Seq(leafOp, convertAsLeaf(e)), extractExprContext(value, e)) | ||
| ) match { | ||
| case Seq() => | ||
| throw new IllegalArgumentException("The In predicate must have at least one entry") | ||
| case Seq(child) => child | ||
|
|
@@ -131,13 +149,14 @@ object OpConverter { | |
| val rightOp = convertAsLeaf(right) | ||
| val leftIsNullOp = IsNullOp(Seq(leftOp)) | ||
| val rightIsNullOp = IsNullOp(Seq(rightOp)) | ||
| val exprCtx = extractExprContext(left, right) | ||
| // Either both are null, or none is null and they are equal. | ||
| OrOp(Seq( | ||
| AndOp(Seq(leftIsNullOp, rightIsNullOp)), | ||
| AndOp(Seq( | ||
| NotOp(Seq(leftIsNullOp)), | ||
| NotOp(Seq(rightIsNullOp)), | ||
| EqualOp(Seq(leftOp, rightOp)))) | ||
| EqualOp(Seq(leftOp, rightOp), exprCtx))) | ||
| )) | ||
|
|
||
| // Unsupported expressions. | ||
|
|
@@ -186,7 +205,7 @@ object OpConverter { | |
| case SqlBooleanType => OpDataTypes.BoolType | ||
| case SqlIntegerType => OpDataTypes.IntType | ||
| case SqlLongType => OpDataTypes.LongType | ||
| case SqlStringType => OpDataTypes.StringType | ||
| case _: SqlStringType => OpDataTypes.StringType | ||
| case SqlDateType => OpDataTypes.DateType | ||
| case SqlDoubleType => OpDataTypes.DoubleType | ||
| case SqlFloatType => OpDataTypes.FloatType | ||
|
|
@@ -207,4 +226,50 @@ object OpConverter { | |
| case _ => lit.toString | ||
| } | ||
| } | ||
|
|
||
| // Extracts expression context from two expressions, including collation information | ||
| // if both are strings with the same collation. This is a generic function that can be | ||
| // extended to extract other dimensions of context in the future. | ||
| private def extractExprContext( | ||
| left: SqlExpression, | ||
| right: SqlExpression): Option[ExprContext] = { | ||
| val collationId = extractCollationIdentifier(left, right) | ||
|
|
||
| // If we have any context information, return an ExprContext | ||
| if (collationId.isDefined) { | ||
| Some(ExprContext(collationIdentifier = collationId)) | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
|
|
||
| // Extracts collation identifier from two expressions if both are strings | ||
| // with the same collation. | ||
| private def extractCollationIdentifier( | ||
| left: SqlExpression, | ||
| right: SqlExpression): Option[String] = { | ||
| (left.dataType, right.dataType) match { | ||
| case (leftStr: SqlStringType, rightStr: SqlStringType) => | ||
| // Spark needs to make sure to only compare strings of the same collation. | ||
| if (leftStr != rightStr) { | ||
| throw new IllegalArgumentException( | ||
| s"Cannot compare strings with different collations: " + | ||
| s"'${leftStr.typeName}' vs '${rightStr.typeName}'" | ||
| ) | ||
| } | ||
|
|
||
| val typeName = leftStr.typeName | ||
| if (typeName.startsWith("string collate")) { | ||
| val collationName = typeName.stripPrefix("string collate").trim | ||
| val provider = if (collationName.equalsIgnoreCase("UTF8_LCASE")) "spark" else "icu" | ||
| val version = s"${ICU_VERSION.getMajor}.${ICU_VERSION.getMinor}" | ||
| Some(s"$provider.$collationName.$version") | ||
| } else { | ||
| None | ||
| } | ||
|
|
||
| case _ => | ||
| None | ||
| } | ||
| } | ||
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can revert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't actually. Without this change this line was only matching the case object of
StringTypeclass, but we want it to match it as well as each instance we create for collated types eg.StringType("UTF8_LCASE")There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding this as a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!