diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index d6fa7f58d61cf..2a92be3619eef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -22,12 +22,12 @@ import java.util.{Collections, Locale}
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkException, SparkIllegalArgumentException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Literal, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
 import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
@@ -598,10 +598,27 @@ private[sql] object CatalogV2Util {
       // data unchanged and let the data reader to return "exist default" for missing
       // columns.
       val existingDefault = Literal(default.getValue.value(), default.getValue.dataType()).sql
-      f.withExistenceDefaultValue(existingDefault).withCurrentDefaultValue(default.getSql)
+      f.withExistenceDefaultValue(existingDefault).withCurrentDefaultValue(toSql(defaultValue))
     }.getOrElse(f)
   }
 
+  private def toSql(defaultValue: DefaultValue): String = {
+    if (defaultValue.getExpression != null) {
+      V2ExpressionUtils.toCatalyst(defaultValue.getExpression) match {
+        case Some(catalystExpr) =>
+          catalystExpr.sql
+        case None if defaultValue.getSql != null =>
+          defaultValue.getSql
+        case _ =>
+          throw SparkException.internalError(
+            s"Can't generate SQL for $defaultValue. The connector expression couldn't be " +
+              "converted to Catalyst and there is no provided SQL representation.")
+      }
+    } else {
+      defaultValue.getSql
+    }
+  }
+
   /**
    * Converts a StructType to DS v2 columns, which decodes the StructField metadata to v2 column
    * comment and default value or generation expression. This is mainly used to generate DS v2
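A hedged, illustrative sketch (not part of the patch) of the kind of column definition the new `toSql` helper has to handle. The constructor shapes mirror the tests added to `DataSourceV2DataFrameSuite` further down; the `100 + 23` default and the column name are only examples:

```scala
import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.IntegerType

// A connector-supplied default: a v2 expression (100 + 23), an optional SQL fallback string,
// and the pre-evaluated "existence" default used to back-fill existing rows.
val defaultValue = new ColumnDefaultValue(
  "100 + 23",
  new GeneralScalarExpression(
    "+",
    Array(LiteralValue(100, IntegerType), LiteralValue(23, IntegerType))),
  LiteralValue(123, IntegerType))

// toSql above prefers the Catalyst translation of the v2 expression, falls back to the
// SQL string when the expression cannot be converted, and raises an internal error if
// neither representation is available.
val column = Column.create(
  "c2",
  IntegerType,
  false, /* not nullable */
  null, /* no comment */
  defaultValue,
  "{}")
```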
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index d6d397b94648d..4b4df98d5a54e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util.ArrayImplicits._
  */
 abstract class InMemoryBaseTable(
     val name: String,
-    val schema: StructType,
+    override val columns: Array[Column],
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String],
     val distribution: Distribution = Distributions.unspecified(),
@@ -88,6 +88,8 @@ abstract class InMemoryBaseTable(
     }
   }
 
+  override val schema: StructType = CatalogV2Util.v2ColumnsToStructType(columns)
+
   // purposely exposes a metadata column that conflicts with a data column in some tests
   override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
   private val metadataColumnNames = metadataColumns.map(_.name).toSet
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index c822e27ceb58e..aeb807768b076 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -38,8 +38,13 @@ class InMemoryRowLevelOperationTable(
     partitioning: Array[Transform],
     properties: util.Map[String, String],
     constraints: Array[Constraint] = Array.empty)
-  extends InMemoryTable(name, schema, partitioning, properties, constraints)
-  with SupportsRowLevelOperations {
+  extends InMemoryTable(
+    name,
+    CatalogV2Util.structTypeToV2Columns(schema),
+    partitioning,
+    properties,
+    constraints)
+  with SupportsRowLevelOperations {
 
   private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
   private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
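The in-memory test tables now accept DS v2 columns directly and derive their StructType schema via `CatalogV2Util.v2ColumnsToStructType`, so expression-based defaults survive table creation instead of being lost in a round trip through `StructType`. A rough, illustrative sketch of the two call-site styles after this change (the names `t1`/`t2` are made up; existing StructType callers can rely on the compatibility constructor added to `InMemoryTable` below):

```scala
import java.util.Collections

import org.apache.spark.sql.connector.catalog.{CatalogV2Util, InMemoryTable}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("id", IntegerType).add("data", StringType)

// Old-style call site: resolves to the StructType compatibility constructor added below.
val fromSchema = new InMemoryTable(
  "t1", schema, Array.empty[Transform], Collections.emptyMap[String, String]())

// New-style call site: passes v2 columns straight through, preserving column defaults.
val fromColumns = new InMemoryTable(
  "t2", CatalogV2Util.structTypeToV2Columns(schema),
  Array.empty[Transform], Collections.emptyMap[String, String]())
```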
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 50e2449623e5c..7d042ff321eac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -33,7 +33,7 @@ import org.apache.spark.util.ArrayImplicits._
  */
 class InMemoryTable(
     name: String,
-    schema: StructType,
+    columns: Array[Column],
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String],
     override val constraints: Array[Constraint] = Array.empty,
@@ -43,10 +43,22 @@ class InMemoryTable(
     advisoryPartitionSize: Option[Long] = None,
     isDistributionStrictlyRequired: Boolean = true,
     override val numRowsPerSplit: Int = Int.MaxValue)
-  extends InMemoryBaseTable(name, schema, partitioning, properties, distribution,
+  extends InMemoryBaseTable(name, columns, partitioning, properties, distribution,
     ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired,
     numRowsPerSplit) with SupportsDelete {
 
+  def this(
+      name: String,
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: util.Map[String, String]) = {
+    this(
+      name,
+      CatalogV2Util.structTypeToV2Columns(schema),
+      partitioning,
+      properties)
+  }
+
   override def canDeleteWhere(filters: Array[Filter]): Boolean = {
     InMemoryTable.supportsFilters(filters)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index 7d64cad2bb102..f2d427d975990 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -118,7 +118,6 @@ class BasicInMemoryTableCatalog extends TableCatalog {
       distributionStrictlyRequired: Boolean = true,
       numRowsPerSplit: Int = Int.MaxValue): Table = {
     // scalastyle:on argcount
-    val schema = CatalogV2Util.v2ColumnsToStructType(columns)
     if (tables.containsKey(ident)) {
       throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
     }
@@ -126,7 +125,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
     InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
 
     val tableName = s"$name.${ident.quoted}"
-    val table = new InMemoryTable(tableName, schema, partitions, properties, constraints,
+    val table = new InMemoryTable(tableName, columns, partitions, properties, constraints,
       distribution, ordering, requiredNumPartitions, advisoryPartitionSize,
       distributionStrictlyRequired, numRowsPerSplit)
     tables.put(ident, table)
@@ -152,7 +151,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
 
     val newTable = new InMemoryTable(
       name = table.name,
-      schema = schema,
+      columns = CatalogV2Util.structTypeToV2Columns(schema),
      partitioning = finalPartitioning,
       properties = properties,
       constraints = constraints)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
index 9b7a90774f91c..7c962ca1678f8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -31,10 +31,10 @@ import org.apache.spark.util.ArrayImplicits._
 
 class InMemoryTableWithV2Filter(
     name: String,
-    schema: StructType,
+    columns: Array[Column],
     partitioning: Array[Transform],
     properties: util.Map[String, String])
-  extends InMemoryBaseTable(name, schema, partitioning, properties) with SupportsDeleteV2 {
+  extends InMemoryBaseTable(name, columns, partitioning, properties) with SupportsDeleteV2 {
 
   override def canDeleteWhere(predicates: Array[Predicate]): Boolean = {
     InMemoryTableWithV2Filter.supportsPredicates(predicates)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
index 861badd390798..ef2f5e26f0029 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
@@ -37,8 +37,7 @@ class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
     InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
 
     val tableName = s"$name.${ident.quoted}"
-    val schema = CatalogV2Util.v2ColumnsToStructType(columns)
-    val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
+    val table = new InMemoryTableWithV2Filter(tableName, columns, partitions, properties)
     tables.put(ident, table)
     namespaces.putIfAbsent(ident.namespace.toList, Map())
     table
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index b6e27aea16c73..8fa18c690e0ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.connector
 
 import java.util.Collections
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
-import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, Identifier, InMemoryTableCatalog}
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, GeneralScalarExpression, LiteralValue, Transform}
+import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, Identifier, InMemoryTableCatalog, TableInfo}
+import org.apache.spark.sql.connector.expressions.{ApplyTransform, Cast => V2Cast, GeneralScalarExpression, LiteralValue, Transform}
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
 import org.apache.spark.sql.execution.datasources.v2.{CreateTableExec, DataSourceV2Relation, ReplaceTableExec}
@@ -51,6 +52,11 @@ class DataSourceV2DataFrameSuite
   override protected val catalogAndNamespace: String = "testcat.ns1.ns2.tbls"
   override protected val v2Format: String = classOf[FakeV2Provider].getName
 
+  protected def catalog(name: String): InMemoryTableCatalog = {
+    val catalog = spark.sessionState.catalogManager.catalog(name)
+    catalog.asInstanceOf[InMemoryTableCatalog]
+  }
+
   override def verifyTable(tableName: String, expected: DataFrame): Unit = {
     checkAnswer(spark.table(tableName), expected)
   }
@@ -483,6 +489,82 @@ class DataSourceV2DataFrameSuite
     }
   }
 
+  test("write with supported expression-based default values") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+      val columns = Array(
+        Column.create("c1", IntegerType),
+        Column.create(
+          "c2",
+          IntegerType,
+          false, /* not nullable */
+          null, /* no comment */
+          new ColumnDefaultValue(
+            new GeneralScalarExpression(
+              "+",
+              Array(LiteralValue(100, IntegerType), LiteralValue(23, IntegerType))),
+            LiteralValue(123, IntegerType)),
+          "{}"))
+      val tableInfo = new TableInfo.Builder().withColumns(columns).build()
+      catalog("testcat").createTable(Identifier.of(Array("ns1", "ns2"), "tbl"), tableInfo)
+      val df = Seq(1, 2, 3).toDF("c1")
+      df.writeTo(tableName).append()
+      checkAnswer(
+        spark.table(tableName),
+        Seq(Row(1, 123), Row(2, 123), Row(3, 123)))
+    }
+  }
+
+  test("write with unsupported expression-based default values (no SQL provided)") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+      val columns = Array(
+        Column.create("c1", IntegerType),
+        Column.create(
+          "c2",
+          IntegerType,
+          false, /* not nullable */
+          null, /* no comment */
+          new ColumnDefaultValue(
+            ApplyTransform(
+              "UNKNOWN_TRANSFORM",
+              Seq(LiteralValue(100, IntegerType), LiteralValue(23, IntegerType))),
+            LiteralValue(123, IntegerType)),
+          "{}"))
+      val e = intercept[SparkException] {
+        val tableInfo = new TableInfo.Builder().withColumns(columns).build()
+        catalog("testcat").createTable(Identifier.of(Array("ns1", "ns2"), "tbl"), tableInfo)
+        val df = Seq(1, 2, 3).toDF("c1")
+        df.writeTo(tableName).append()
+      }
+      assert(e.getMessage.contains("connector expression couldn't be converted to Catalyst"))
+    }
+  }
+
+  test("write with unsupported expression-based default values (with SQL provided)") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+      val columns = Array(
+        Column.create("c1", IntegerType),
+        Column.create(
+          "c2",
+          IntegerType,
+          false, /* not nullable */
+          null, /* no comment */
+          new ColumnDefaultValue(
+            "100 + 23",
+            ApplyTransform(
+              "INVALID_TRANSFORM",
+              Seq(LiteralValue(100, IntegerType), LiteralValue(23, IntegerType))),
+            LiteralValue(123, IntegerType)),
+          "{}"))
+      val tableInfo = new TableInfo.Builder().withColumns(columns).build()
+      catalog("testcat").createTable(Identifier.of(Array("ns1", "ns2"), "tbl"), tableInfo)
+      val df = Seq(1, 2, 3).toDF("c1")
+      df.writeTo(tableName).append()
+    }
+  }
+
   private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
     val qe = withQueryExecutionsCaptured(spark) {
       func