8 changes: 8 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -544,6 +544,14 @@
],
"sqlState" : "56000"
},
"CHECK_CONSTRAINT_VIOLATION" : {
"message" : [
"CHECK constraint <constraintName> <expression> violated by row with values:",
"<values>",
""
],
"sqlState" : "23001"
},
"CIRCULAR_CLASS_REFERENCE" : {
"message" : [
"Cannot have circular references in class, but got the circular reference of class <t>."
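With the values exercised by ConstraintExpressionSuite later in this diff (constraint c1, expression a > 0, row value a = -1), the CHECK_CONSTRAINT_VIOLATION template above renders as:

[CHECK_CONSTRAINT_VIOLATION] CHECK constraint c1 a > 0 violated by row with values:
 - a : -1
 SQLSTATE: 23001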
@@ -400,6 +400,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
ResolveOutputRelation ::
new ResolveTableConstraint(catalogManager) ::
new ResolveDataFrameDropColumns(catalogManager) ::
new ResolveSetVariable(catalogManager) ::
ExtractWindowExpressions ::
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.constraints.Check
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
_.containsPattern(COMMAND), ruleId) {
case v2Write: V2WriteCommand
if v2Write.table.resolved && v2Write.query.resolved &&
!v2Write.query.isInstanceOf[Validate] && v2Write.outputResolved =>
v2Write.table match {
case r: DataSourceV2Relation
if r.table.constraints() != null && r.table.constraints().nonEmpty =>
val checks = r.table.constraints().collect {
Contributor:
Do we reject others? What if we get back a PK that must be enforced?

Member (Author):
PK can't be enforced. There will be a parser error

case c: Check => c
}
val checkInvariants = checks.map { c =>
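// Re-parse the constraint's stored predicate SQL; the resulting expression is still
// unresolved at this point and is resolved by later analyzer passes once it is
// wrapped in the Validate node.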
val parsed =
catalogManager.v1SessionCatalog.parser.parseExpression(c.predicateSql())
val columnExtractors = mutable.Map[String, Expression]()
buildColumnExtractors(parsed, columnExtractors)
CheckInvariant(parsed, columnExtractors.toSeq, c.name(), c.predicateSql())
}.toSeq
v2Write.withNewQuery(Validate(checkInvariants, v2Write.query))
case _ =>
v2Write
}
}

private def buildColumnExtractors(
expr: Expression,
columnExtractors: mutable.Map[String, Expression]): Unit = {
expr match {
case u: UnresolvedExtractValue =>
// When extracting a value from a Map or Array type, we display only the specific extracted
// value rather than the entire Map or Array structure for clarity and readability.
columnExtractors(u.sql) = u
case u: UnresolvedAttribute =>
columnExtractors(u.name) = u

case other =>
other.children.foreach(buildColumnExtractors(_, columnExtractors))
}
}
}
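For illustration: given a DSv2 table whose catalog reports a CHECK constraint c1 with predicate SQL a > 0 (the same example the tests below use), the rule rewrites a resolved write roughly as follows (a sketch of the plan shape, not output copied from Spark):

AppendData(relation, query)
  ==>
AppendData(relation, Validate(Seq(CheckInvariant(a > 0, Seq("a" -> a), "c1", "a > 0")), query))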
@@ -18,13 +18,17 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.UUID

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{DataType, NullType}

trait TableConstraint extends Expression with Unevaluable {
/** Convert to a data source v2 constraint */
@@ -259,3 +263,94 @@ case class ForeignKeyConstraint(
copy(userProvidedCharacteristic = c)
}
}

/**
* An expression that validates a specific invariant on a column, before writing into table.
*
* @param child The fully resolved expression to be evaluated to check the constraint.
* @param columnExtractors Extractors for each referenced column. Used to generate readable errors.
* @param constraintName The name of the constraint.
* @param predicateSql The SQL representation of the constraint.
*/
case class CheckInvariant(
child: Expression,
columnExtractors: Seq[(String, Expression)],
constraintName: String,
predicateSql: String)
extends Expression with NonSQLExpression {

override def children: Seq[Expression] = child +: columnExtractors.map(_._2)
override def dataType: DataType = NullType
override def foldable: Boolean = false
override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
val result = child.eval(input)
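// Per SQL semantics, a NULL (UNKNOWN) predicate result is not a violation: only an
// explicit false triggers the error, so rows with null inputs pass the check.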
if (result == false) {
val values = columnExtractors.map {
case (column, extractor) => column -> extractor.eval(input)
}.toMap
throw QueryExecutionErrors.checkViolation(constraintName, predicateSql, values)
}
null
}

/**
* Generate the code to extract values for the columns referenced in a violated CHECK constraint.
* We build parallel lists of full column names and their extracted values in the row which
* violates the constraint, to be passed to [[QueryExecutionErrors.checkViolationJava()]]
* in [[generateExpressionValidationCode()]].
*
* Note that this code is a bit expensive, so it shouldn't be run until we already
* know the constraint has been violated.
*/
private def generateColumnValuesCode(
colList: String, valList: String, ctx: CodegenContext): Block = {
val start =
code"""
|java.util.List<String> $colList = new java.util.ArrayList<String>();
|java.util.List<Object> $valList = new java.util.ArrayList<Object>();
|""".stripMargin
columnExtractors.map {
case (name, extractor) =>
val colValue = extractor.genCode(ctx)
code"""
|$colList.add("$name");
|${colValue.code}
|if (${colValue.isNull}) {
| $valList.add(null);
|} else {
| $valList.add(${colValue.value});
|}
|""".stripMargin
}.fold(start)(_ + _)
}

private def generateExpressionValidationCode(ctx: CodegenContext): Block = {
val elementValue = child.genCode(ctx)
val colListName = ctx.freshName("colList")
val valListName = ctx.freshName("valList")
val ret = code"""${elementValue.code}
|
|if (!${elementValue.isNull} && ${elementValue.value} == false) {
| ${generateColumnValuesCode(colListName, valListName, ctx)}
| throw org.apache.spark.sql.errors.QueryExecutionErrors.checkViolationJava(
| "$constraintName", "$predicateSql", $colListName, $valListName);
|}
""".stripMargin
ret
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val code = generateExpressionValidationCode(ctx)
ev.copy(code = code, isNull = TrueLiteral, value = JavaCode.literal("null", NullType))
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
copy(
child = newChildren.head,
columnExtractors = columnExtractors.map(_._1).zip(newChildren.tail)
)
}
}
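A minimal sketch of the interpreted evaluation semantics, built from the same expression that ConstraintExpressionSuite below constructs:

val a = BoundReference(0, IntegerType, nullable = true)
val check = CheckInvariant(GreaterThan(a, Literal(0)), Seq(("a", a)), "c1", "a > 0")
check.eval(InternalRow(1))    // predicate holds: returns null
check.eval(InternalRow(null)) // predicate is NULL (UNKNOWN), not false: row passes
check.eval(InternalRow(-1))   // throws SparkRuntimeException with CHECK_CONSTRAINT_VIOLATION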
@@ -1664,3 +1664,16 @@ case class Call(
override protected def withNewChildInternal(newChild: LogicalPlan): Call =
copy(procedure = newChild)
}

case class Validate(
conditions: Seq[CheckInvariant],
child: LogicalPlan) extends UnaryNode {

assert(conditions.nonEmpty, "Validate must have at least one condition")

override def output: Seq[Attribute] = child.output

override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
copy(child = newChild)
}
}
@@ -112,6 +112,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.expressions.EliminatePipeOperators" ::
"org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
"org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" ::
"org.apache.spark.sql.catalyst.analysis.ResolveTableConstraint" ::
// Catalyst Optimizer rules
"org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
@@ -24,6 +24,8 @@ import java.time.DateTimeException
import java.util.Locale
import java.util.concurrent.TimeoutException

import scala.jdk.CollectionConverters._

import com.fasterxml.jackson.core.{JsonParser, JsonToken}
import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path}
import org.apache.hadoop.fs.permission.FsPermission
@@ -2993,4 +2995,35 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
def notAbsolutePathError(path: Path): SparkException = {
SparkException.internalError(s"$path is not absolute path.")
}

// Throws a SparkRuntimeException when a CHECK constraint is violated, including details of the
// violation.
def checkViolation(
constraintName: String,
sqlStr: String,
values: Map[String, Any]): SparkRuntimeException = {
// Sort by the column name to generate consistent error messages in Scala 2.12 and 2.13.
val valueLines = values.toSeq.sortBy(_._1).map {
case (column, value) =>
s" - $column : $value"
}.mkString("\n")
new SparkRuntimeException(
errorClass = "CHECK_CONSTRAINT_VIOLATION",
messageParameters = Map(
"constraintName" -> constraintName,
"expression" -> sqlStr,
"values" -> valueLines
)
)
}

// Throws a SparkRuntimeException when a CHECK constraint is violated, including details of the
// violation. This is a Java-friendly version of the above method.
def checkViolationJava(
constraintName: String,
sqlStr: String,
columns: java.util.List[String],
values: java.util.List[Any]): SparkRuntimeException = {
checkViolation(constraintName, sqlStr, columns.asScala.zip(values.asScala).toMap)
}
}
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.IntegerType

class ConstraintExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
private val boundRef = BoundReference(0, IntegerType, nullable = true)
private val expr =
CheckInvariant(GreaterThan(boundRef, Literal(0)), Seq(("a", boundRef)), "c1", "a > 0")

def expectedMessage(value: String): String =
s"""|[CHECK_CONSTRAINT_VIOLATION] CHECK constraint c1 a > 0 violated by row with values:
| - a : $value
| SQLSTATE: 23001""".stripMargin

test("CheckInvariant: returns null if column 'a' > 0") {
checkEvaluation(expr, null, InternalRow(1))
}

test("CheckInvariant: return null if column 'a' is null") {
checkEvaluation(expr, null, InternalRow(null))
}

test("CheckInvariant: throws exception if column 'a' <= 0") {
checkExceptionInExpression[SparkRuntimeException](
expr, InternalRow(-1), expectedMessage("-1"))
}
}
@@ -573,6 +573,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case c: Call =>
ExplainOnlySparkPlan(c) :: Nil

case c: Validate =>
ValidateExec(planLater(c.child), c.conditions) :: Nil

case _ => Nil
}
}
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, CheckInvariant, SortOrder, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

case class ValidateExec(
child: SparkPlan,
constraints: Seq[CheckInvariant]) extends UnaryExecNode {

override def output: Seq[Attribute] = child.output

override protected def doExecute(): RDD[InternalRow] = {
if (constraints.isEmpty) return child.execute()

child.execute().mapPartitionsInternal { rows =>
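// Code-generate a projection over the CheckInvariant expressions: evaluating it throws
// on the first violated constraint, and its (all-null) output is discarded so the
// original row flows through unchanged.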
val assertions = UnsafeProjection.create(constraints, child.output)
rows.map { row =>
assertions(row)
row
}
}
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): ValidateExec =
copy(child = newChild)
}