Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo}
*/
case class EvalContext(partitionValues: Map[String, String])

/**
* The expression context containing metadata about the operation.
*
* Its only field is optional and defaults to None, so an instance may carry no metadata at all.
*
* @param collationIdentifier The collation identifier for string comparisons, if applicable.
*                            None (the default) when no collation information is attached.
*/
case class ExprContext(collationIdentifier: Option[String] = None)

/**
* The data types supported by the filtering operations.
*/
Expand Down Expand Up @@ -243,33 +250,38 @@ case class IsNullOp(children: Seq[LeafOp]) extends NonLeafOp with UnaryOp {
* @param children Expected size is 2.
*/

case class EqualOp(children: Seq[LeafOp]) extends NonLeafOp with BinaryOp {
case class EqualOp(children: Seq[LeafOp], exprCtx: Option[ExprContext] = None)
  extends NonLeafOp with BinaryOp {
  // Child validation is delegated to validateChildren.
  override def validate(forV2: Boolean = false): Unit = {
    validateChildren(children, forV2)
  }

  // The result comes from EvalHelper.equal over the children and the evaluation context.
  // NOTE(review): exprCtx is not passed to the helper here — confirm that is intended.
  override def eval(ctx: EvalContext): Any = {
    EvalHelper.equal(children, ctx)
  }
}

case class LessThanOp(children: Seq[LeafOp]) extends NonLeafOp with BinaryOp {
case class LessThanOp(children: Seq[LeafOp], exprCtx: Option[ExprContext] = None)
  extends NonLeafOp with BinaryOp {
  // Child validation is delegated to validateChildren.
  override def validate(forV2: Boolean = false): Unit = {
    validateChildren(children, forV2)
  }

  // The result comes from EvalHelper.lessThan over the children and the evaluation context.
  // NOTE(review): exprCtx is not passed to the helper here — confirm that is intended.
  override def eval(ctx: EvalContext): Any = {
    EvalHelper.lessThan(children, ctx)
  }
}

case class LessThanOrEqualOp(children: Seq[LeafOp]) extends NonLeafOp with BinaryOp {
case class LessThanOrEqualOp(children: Seq[LeafOp], exprCtx: Option[ExprContext] = None)
  extends NonLeafOp with BinaryOp {
  // Child validation is delegated to validateChildren.
  override def validate(forV2: Boolean = false): Unit = {
    validateChildren(children, forV2)
  }

  // True when the children compare as less-than; otherwise falls back to the equality check.
  // The equality helper is only invoked when the less-than check is false.
  // NOTE(review): exprCtx is not passed to either helper here — confirm that is intended.
  override def eval(ctx: EvalContext): Any = {
    if (EvalHelper.lessThan(children, ctx)) true else EvalHelper.equal(children, ctx)
  }
}

case class GreaterThanOp(children: Seq[LeafOp]) extends NonLeafOp with BinaryOp {
case class GreaterThanOp(children: Seq[LeafOp], exprCtx: Option[ExprContext] = None)
  extends NonLeafOp with BinaryOp {
  // Child validation is delegated to validateChildren.
  override def validate(forV2: Boolean = false): Unit = {
    validateChildren(children, forV2)
  }

  // Greater-than is expressed as "neither less-than nor equal".
  // The equality helper is only invoked when the less-than check is false.
  // NOTE(review): exprCtx is not passed to either helper here — confirm that is intended.
  override def eval(ctx: EvalContext): Any = {
    !(EvalHelper.lessThan(children, ctx) || EvalHelper.equal(children, ctx))
  }
}

case class GreaterThanOrEqualOp(children: Seq[LeafOp]) extends NonLeafOp with BinaryOp {
case class GreaterThanOrEqualOp(children: Seq[LeafOp], exprCtx: Option[ExprContext] = None)
extends NonLeafOp with BinaryOp {
override def validate(forV2: Boolean = false): Unit = validateChildren(children, forV2)

override def eval(ctx: EvalContext): Any = !EvalHelper.lessThan(children, ctx)
Expand Down
81 changes: 73 additions & 8 deletions client/src/main/scala/io/delta/sharing/filters/OpConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package io.delta.sharing.filters

import scala.collection.mutable.ListBuffer

import com.ibm.icu.util.VersionInfo.ICU_VERSION
import org.apache.spark.sql.catalyst.expressions.{
And => SqlAnd,
Attribute => SqlAttribute,
Expand Down Expand Up @@ -93,15 +94,30 @@ object OpConverter {

// Convert comparison operators.
case SqlEqualTo(left, right) =>
EqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right)))
EqualOp(
Seq(convertAsLeaf(left), convertAsLeaf(right)),
extractExprContext(left, right)
)
case SqlLessThan(left, right) =>
LessThanOp(Seq(convertAsLeaf(left), convertAsLeaf(right)))
LessThanOp(
Seq(convertAsLeaf(left), convertAsLeaf(right)),
extractExprContext(left, right)
)
case SqlLessThanOrEqual(left, right) =>
LessThanOrEqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right)))
LessThanOrEqualOp(
Seq(convertAsLeaf(left), convertAsLeaf(right)),
extractExprContext(left, right)
)
case SqlGreaterThan(left, right) =>
GreaterThanOp(Seq(convertAsLeaf(left), convertAsLeaf(right)))
GreaterThanOp(
Seq(convertAsLeaf(left), convertAsLeaf(right)),
extractExprContext(left, right)
)
case SqlGreaterThanOrEqual(left, right) =>
GreaterThanOrEqualOp(Seq(convertAsLeaf(left), convertAsLeaf(right)))
GreaterThanOrEqualOp(
Seq(convertAsLeaf(left), convertAsLeaf(right)),
extractExprContext(left, right)
)

// Convert null operations.
case SqlIsNull(child) =>
Expand All @@ -118,7 +134,9 @@ object OpConverter {
)
}
val leafOp = convertAsLeaf(value)
list.map(e => EqualOp(Seq(leafOp, convertAsLeaf(e)))) match {
list.map(e =>
EqualOp(Seq(leafOp, convertAsLeaf(e)), extractExprContext(value, e))
) match {
case Seq() =>
throw new IllegalArgumentException("The In predicate must have at least one entry")
case Seq(child) => child
Expand All @@ -131,13 +149,14 @@ object OpConverter {
val rightOp = convertAsLeaf(right)
val leftIsNullOp = IsNullOp(Seq(leftOp))
val rightIsNullOp = IsNullOp(Seq(rightOp))
val exprCtx = extractExprContext(left, right)
// Either both are null, or none is null and they are equal.
OrOp(Seq(
AndOp(Seq(leftIsNullOp, rightIsNullOp)),
AndOp(Seq(
NotOp(Seq(leftIsNullOp)),
NotOp(Seq(rightIsNullOp)),
EqualOp(Seq(leftOp, rightOp))))
EqualOp(Seq(leftOp, rightOp), exprCtx)))
))

// Unsupported expressions.
Expand Down Expand Up @@ -186,7 +205,7 @@ object OpConverter {
case SqlBooleanType => OpDataTypes.BoolType
case SqlIntegerType => OpDataTypes.IntType
case SqlLongType => OpDataTypes.LongType
case SqlStringType => OpDataTypes.StringType
case _: SqlStringType => OpDataTypes.StringType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can revert

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't, actually. Without this change, this line matched only the `StringType` case object, but we want it to match that object as well as every instance we create for a collated type, e.g. `StringType("UTF8_LCASE")`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding this as a comment.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

case SqlDateType => OpDataTypes.DateType
case SqlDoubleType => OpDataTypes.DoubleType
case SqlFloatType => OpDataTypes.FloatType
Expand All @@ -207,4 +226,50 @@ object OpConverter {
case _ => lit.toString
}
}

// Builds the optional expression context for a comparison between `left` and `right`.
// Collation is the only dimension collected at present; further dimensions can be added
// here later. When no collation identifier is found there is nothing to report, so the
// result is None rather than an empty ExprContext.
private def extractExprContext(
    left: SqlExpression,
    right: SqlExpression): Option[ExprContext] = {
  extractCollationIdentifier(left, right) match {
    case found @ Some(_) => Some(ExprContext(collationIdentifier = found))
    case None => None
  }
}

// Extracts the collation identifier shared by two expressions, formatted as
// "<provider>.<collationName>.<icuMajor>.<icuMinor>".
//
// Returns None when either side is not a string, or when the (common) string type's name
// does not start with "string collate", i.e. no explicit collation is attached.
// Throws IllegalArgumentException when both sides are strings of different string types,
// because strings may only be compared under one and the same collation.
private def extractCollationIdentifier(
    left: SqlExpression,
    right: SqlExpression): Option[String] = {
  (left.dataType, right.dataType) match {
    case (leftStr: SqlStringType, rightStr: SqlStringType) =>
      // Spark needs to make sure to only compare strings of the same collation.
      if (leftStr != rightStr) {
        throw new IllegalArgumentException(
          s"Cannot compare strings with different collations: " +
          s"'${leftStr.typeName}' vs '${rightStr.typeName}'"
        )
      }

      val collatePrefix = "string collate"
      val typeName = leftStr.typeName
      if (typeName.startsWith(collatePrefix)) {
        val collationName = typeName.stripPrefix(collatePrefix).trim
        // Every UTF8_* collation is provided by Spark itself, not only UTF8_LCASE: trim
        // variants such as UTF8_LCASE_RTRIM or UTF8_BINARY_RTRIM must not be reported as ICU
        // collations. ICU collation names (UNICODE*, locale names) never start with "UTF8_".
        val isSparkProvided =
          collationName.toUpperCase(java.util.Locale.ROOT).startsWith("UTF8_")
        val provider = if (isSparkProvided) "spark" else "icu"
        val version = s"${ICU_VERSION.getMajor}.${ICU_VERSION.getMinor}"
        Some(s"$provider.$collationName.$version")
      } else {
        None
      }

    case _ =>
      None
  }
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we guard this change via some config/param?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but since OpConverter has no access to SqlConfig, we would have to change a lot of code and public APIs to pipe it through to here. I don't think that is worth doing for this specific change.

}
189 changes: 189 additions & 0 deletions client/src/test/scala/io/delta/sharing/filters/OpConverterSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.delta.sharing.filters

import com.ibm.icu.util.VersionInfo.ICU_VERSION
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{
And => SqlAnd,
Expand Down Expand Up @@ -50,6 +51,8 @@ import org.apache.spark.sql.types.{

class OpConverterSuite extends SparkFunSuite {

val icuVersion: String = s"${ICU_VERSION.getMajor}.${ICU_VERSION.getMinor}"

test("equal test") {
val sqlColumn = SqlAttributeReference("userId", SqlIntegerType)()
val sqlLiteral = SqlLiteral(23, SqlIntegerType)
Expand Down Expand Up @@ -353,4 +356,190 @@ class OpConverterSuite extends SparkFunSuite {
convert(Seq.empty)
}.getMessage.contains("The In predicate must have at least one entry"))
}

test("UTF8_BINARY collation test") {
  // A default (non-collated) string column compared against a plain string literal.
  val emailColumn = SqlAttributeReference("email", SqlStringType)()
  val equalTo = SqlEqualTo(emailColumn, SqlLiteral("test@example.com"))

  val converted = OpConverter.convert(Seq(equalTo)).get.asInstanceOf[EqualOp]
  converted.validate()

  // Both operands must be reported as plain strings.
  val lhs = converted.children(0).asInstanceOf[ColumnOp]
  val rhs = converted.children(1).asInstanceOf[LiteralOp]
  assert(lhs.valueType == OpDataTypes.StringType)
  assert(rhs.valueType == OpDataTypes.StringType)

  // The default string type yields no expression context.
  assert(converted.exprCtx.isEmpty)
}

test("collated string UNICODE_CI equal test") {
  // Column and literal share the UNICODE_CI collation.
  val unicodeCi = SqlStringType("UNICODE_CI")
  val nameColumn = SqlAttributeReference("name", unicodeCi)()
  val nameLiteral = SqlLiteral.create("TestValue", unicodeCi)

  val converted =
    OpConverter.convert(Seq(SqlEqualTo(nameColumn, nameLiteral))).get.asInstanceOf[EqualOp]
  converted.validate()

  // The value type stays a plain string regardless of collation.
  val lhs = converted.children(0).asInstanceOf[ColumnOp]
  val rhs = converted.children(1).asInstanceOf[LiteralOp]
  assert(lhs.valueType == OpDataTypes.StringType)
  assert(rhs.valueType == OpDataTypes.StringType)

  // The collation travels in the expression context, with the "icu" provider.
  val context = converted.exprCtx
  assert(context.isDefined)
  val identifier = context.get.collationIdentifier
  assert(identifier.isDefined)
  assert(identifier.get == s"icu.UNICODE_CI.$icuVersion")
}

test("collated string UTF8_LCASE equal test") {
  // Column and literal share the UTF8_LCASE collation.
  val utf8Lcase = SqlStringType("UTF8_LCASE")
  val nameColumn = SqlAttributeReference("name", utf8Lcase)()
  val nameLiteral = SqlLiteral.create("TestValue", utf8Lcase)

  val converted =
    OpConverter.convert(Seq(SqlEqualTo(nameColumn, nameLiteral))).get.asInstanceOf[EqualOp]
  converted.validate()

  // The value type stays a plain string regardless of collation.
  val lhs = converted.children(0).asInstanceOf[ColumnOp]
  val rhs = converted.children(1).asInstanceOf[LiteralOp]
  assert(lhs.valueType == OpDataTypes.StringType)
  assert(rhs.valueType == OpDataTypes.StringType)

  // UTF8_LCASE must be reported with the "spark" provider.
  val context = converted.exprCtx
  assert(context.isDefined)
  val identifier = context.get.collationIdentifier
  assert(identifier.isDefined)
  assert(identifier.get == s"spark.UTF8_LCASE.$icuVersion")
}

test("collated string with cast test") {
  val unicodeCi = SqlStringType("UNICODE_CI")
  val nameColumn = SqlAttributeReference("name", unicodeCi)()
  val plainLiteral = SqlLiteral("TestValue")
  // The literal starts out as a plain string and is cast to the collated type.
  val equalTo = SqlEqualTo(nameColumn, SqlCast(plainLiteral, unicodeCi))

  val converted = OpConverter.convert(Seq(equalTo)).get.asInstanceOf[EqualOp]
  converted.validate()

  // The value type stays a plain string on both sides.
  val lhs = converted.children(0).asInstanceOf[ColumnOp]
  val rhs = converted.children(1).asInstanceOf[LiteralOp]
  assert(lhs.valueType == OpDataTypes.StringType)
  assert(rhs.valueType == OpDataTypes.StringType)

  // The collation of the cast target is picked up in the expression context.
  val context = converted.exprCtx
  assert(context.isDefined)
  val identifier = context.get.collationIdentifier
  assert(identifier.isDefined)
  assert(identifier.get == s"icu.UNICODE_CI.$icuVersion")
}

test("collated string comparison operations test") {
  val unicodeCi = SqlStringType("UNICODE_CI")
  val nameColumn = SqlAttributeReference("name", unicodeCi)()
  val nameLiteral = SqlLiteral.create("TestValue", unicodeCi)

  val expectedCollationId = s"icu.UNICODE_CI.$icuVersion"

  // Shared check: the context must exist and carry the expected collation identifier.
  def assertCarriesCollation(exprCtx: Option[ExprContext]): Unit = {
    assert(exprCtx.isDefined)
    assert(exprCtx.get.collationIdentifier.contains(expectedCollationId))
  }

  // LessThan
  val lessThan = OpConverter.convert(Seq(SqlLessThan(nameColumn, nameLiteral)))
    .get.asInstanceOf[LessThanOp]
  assertCarriesCollation(lessThan.exprCtx)

  // GreaterThan
  val greaterThan = OpConverter.convert(Seq(SqlGreaterThan(nameColumn, nameLiteral)))
    .get.asInstanceOf[GreaterThanOp]
  assertCarriesCollation(greaterThan.exprCtx)

  // LessThanOrEqual
  val lessThanOrEqual = OpConverter.convert(Seq(SqlLessThanOrEqual(nameColumn, nameLiteral)))
    .get.asInstanceOf[LessThanOrEqualOp]
  assertCarriesCollation(lessThanOrEqual.exprCtx)

  // GreaterThanOrEqual
  val greaterThanOrEqual =
    OpConverter.convert(Seq(SqlGreaterThanOrEqual(nameColumn, nameLiteral)))
      .get.asInstanceOf[GreaterThanOrEqualOp]
  assertCarriesCollation(greaterThanOrEqual.exprCtx)
}

test("collated string In expression test") {
  val unicodeCi = SqlStringType("UNICODE_CI")
  val nameColumn = SqlAttributeReference("name", unicodeCi)()
  val candidates = Seq("Value1", "Value2", "Value3").map { text =>
    SqlLiteral.create(text, unicodeCi)
  }

  val converted = OpConverter.convert(Seq(SqlIn(nameColumn, candidates))).get
  converted.validate()

  // A three-entry In predicate becomes an OrOp with one child per entry.
  val disjunction = converted.asInstanceOf[OrOp]
  assert(disjunction.children.size == 3)

  val expectedCollationId = s"icu.UNICODE_CI.$icuVersion"
  for (child <- disjunction.children) {
    // Each child is an EqualOp carrying the collation identifier.
    val equality = child.asInstanceOf[EqualOp]
    assert(equality.exprCtx.isDefined)
    assert(equality.exprCtx.get.collationIdentifier.contains(expectedCollationId))

    // The operands themselves are still reported as plain strings.
    val lhs = equality.children(0).asInstanceOf[ColumnOp]
    val rhs = equality.children(1).asInstanceOf[LiteralOp]
    assert(lhs.valueType == OpDataTypes.StringType)
    assert(rhs.valueType == OpDataTypes.StringType)
  }
}

test("mismatched collations throw error") {
  // Two string columns whose collations differ.
  val unicodeCiColumn = SqlAttributeReference("name1", SqlStringType("UNICODE_CI"))()
  val utf8LcaseColumn = SqlAttributeReference("name2", SqlStringType("UTF8_LCASE"))()

  // EqualTo: the message names the failure and both collations involved.
  val equalToMessage = intercept[IllegalArgumentException] {
    OpConverter.convert(Seq(SqlEqualTo(unicodeCiColumn, utf8LcaseColumn)))
  }.getMessage
  assert(equalToMessage.contains("Cannot compare strings with different collations"))
  assert(equalToMessage.contains("UNICODE_CI"))
  assert(equalToMessage.contains("UTF8_LCASE"))

  // LessThan is rejected the same way.
  val lessThanMessage = intercept[IllegalArgumentException] {
    OpConverter.convert(Seq(SqlLessThan(unicodeCiColumn, utf8LcaseColumn)))
  }.getMessage
  assert(lessThanMessage.contains("Cannot compare strings with different collations"))

  // GreaterThan is rejected the same way.
  val greaterThanMessage = intercept[IllegalArgumentException] {
    OpConverter.convert(Seq(SqlGreaterThan(unicodeCiColumn, utf8LcaseColumn)))
  }.getMessage
  assert(greaterThanMessage.contains("Cannot compare strings with different collations"))
}

test("collated vs non-collated string throws error") {
  // A collated string column compared against a default (non-collated) string column.
  val collatedType = SqlStringType("UNICODE_CI")
  val defaultType = SqlStringType

  val columnCollated = SqlAttributeReference("name1", collatedType)()
  val columnDefault = SqlAttributeReference("name2", defaultType)()

  val exception = intercept[IllegalArgumentException] {
    OpConverter.convert(Seq(SqlEqualTo(columnCollated, columnDefault)))
  }
  val message = exception.getMessage
  assert(message.contains("Cannot compare strings with different collations"))
  // Check both quoted type names in order. A bare contains("string") would pass trivially,
  // since the fixed prefix and the collated type name already contain that word.
  assert(message.contains("'string collate UNICODE_CI' vs 'string'"))
}
}
Loading