27 changes: 23 additions & 4 deletions sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -40,11 +40,13 @@ import org.apache.spark.util.ArrayImplicits._
* @since 1.3.0
*/
@Stable
sealed class Metadata private[types] (private[types] val map: Map[String, Any])
sealed class Metadata private[types] (
private[types] val map: Map[String, Any],
@transient private[types] val runtimeMap: Map[String, Any])

Contributor Author:

I am adding a runtime-only map of entries that is neither serialized nor exposed to the user. It lets me store alternative in-memory representations for certain keys. In particular, it allows storing both the SQL string and the expression itself for default values.
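
To make the behavior concrete, here is a minimal sketch (my own illustration, not part of this diff), assuming code that lives under org.apache.spark.sql since putExpression/getExpression are private[sql]; the key name "default" is arbitrary. It mirrors the round-trip checks in MetadataSuite further down:

import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

// The SQL string goes into the serializable map; the expression goes only
// into the transient runtime map.
val meta: Metadata = new MetadataBuilder()
  .putExpression("default", "1 + 2", Some(Add(Literal(1), Literal(2))))
  .build()

// In memory, both representations are available.
val (sql, expr) = meta.getExpression[Expression]("default")
assert(sql == "1 + 2" && expr.isDefined)

// After a JSON (or Java/Kryo) round trip, only the SQL string survives;
// the runtime-only expression is dropped.
val restored = Metadata.fromJson(meta.json)
assert(restored.getString("default") == "1 + 2")
assert(restored.getExpression[Expression]("default")._2.isEmpty)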

Contributor Author:

Below is an example of how it is being used:

val existsDefault = extractExistsDefault(default)
val (sql, expr) = extractCurrentDefault(default)
val newMetadata = new MetadataBuilder()
  .withMetadata(f.metadata)
  .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, existsDefault)
  .putExpression(CURRENT_DEFAULT_COLUMN_METADATA_KEY, sql, expr)
  .build()
f.copy(metadata = newMetadata)

Contributor Author:

This allows me to use the expression directly without generating/parsing SQL for it.
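
On the consumer side this becomes a preference order, as in the ResolveDefaultColumns change later in this diff (sketch only): use the in-memory expression when present, and fall back to parsing the stored SQL otherwise.

field.metadata.getExpression[Expression](CURRENT_DEFAULT_COLUMN_METADATA_KEY) match {
  case (sql, Some(expr)) =>
    // the expression is available in memory; sql is kept only for error messages
    analyze(field.name, field.dataType, expr, sql, statementType)
  case (sql, None) =>
    // no runtime expression: parse and analyze the SQL string as before
    analyze(field.name, field.dataType, sql, statementType)
}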

extends Serializable {

/** No-arg constructor for kryo. */
protected def this() = this(null)
protected def this() = this(null, null)

/** Tests whether this Metadata contains a binding for a key. */
def contains(key: String): Boolean = map.contains(key)
@@ -120,6 +122,12 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
map(key).asInstanceOf[T]
}

private[sql] def getExpression[E](key: String): (String, Option[E]) = {
val sql = getString(key)
val expr = Option(runtimeMap).flatMap(_.get(key).map(_.asInstanceOf[E]))
sql -> expr
}

private[sql] def jsonValue: JValue = Metadata.toJsonValue(this)
}

@@ -129,7 +137,7 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
@Stable
object Metadata {

private[this] val _empty = new Metadata(Map.empty)
private[this] val _empty = new Metadata(Map.empty, Map.empty)

/** Returns an empty Metadata. */
def empty: Metadata = _empty
@@ -248,13 +256,17 @@ object Metadata {
class MetadataBuilder {

private val map: mutable.Map[String, Any] = mutable.Map.empty
private val runtimeMap: mutable.Map[String, Any] = mutable.Map.empty

/** Returns the immutable version of this map. Used for java interop. */
protected def getMap = map.toMap

/** Include the content of an existing [[Metadata]] instance. */
def withMetadata(metadata: Metadata): this.type = {
map ++= metadata.map
if (metadata.runtimeMap != null) {
runtimeMap ++= metadata.runtimeMap
}
this
}

@@ -293,16 +305,23 @@ class MetadataBuilder

/** Builds the [[Metadata]] instance. */
def build(): Metadata = {
new Metadata(map.toMap)
new Metadata(map.toMap, runtimeMap.toMap)
}

private def put(key: String, value: Any): this.type = {
map.put(key, value)
this
}

private[sql] def putExpression[E](key: String, sql: String, expr: Option[E]): this.type = {
map.put(key, sql)
expr.foreach(runtimeMap.put(key, _))
this
}

def remove(key: String): this.type = {
map.remove(key)
runtimeMap.remove(key)
this
}
}
@@ -263,7 +263,12 @@ object ResolveDefaultColumns extends QueryErrorsBase
field: StructField,
statementType: String,
metadataKey: String = CURRENT_DEFAULT_COLUMN_METADATA_KEY): Expression = {
analyze(field.name, field.dataType, field.metadata.getString(metadataKey), statementType)
field.metadata.getExpression[Expression](metadataKey) match {
case (sql, Some(expr)) =>
analyze(field.name, field.dataType, expr, sql, statementType)

Contributor:

In this branch, will sql only be used for the error message?

Contributor Author:

Yes.

case (sql, _) =>
analyze(field.name, field.dataType, sql, statementType)
}
}

/**
@@ -22,12 +22,12 @@ import java.util.{Collections, Locale}

import scala.jdk.CollectionConverters._

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.CurrentUserContext
import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, TimeTravelSpec}
import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
@@ -597,11 +597,31 @@ private[sql] object CatalogV2Util {
// Note: the back-fill here is a logical concept. The data source can keep the existing
// data unchanged and let the data reader to return "exist default" for missing
// columns.
val existingDefault = Literal(default.getValue.value(), default.getValue.dataType()).sql
f.withExistenceDefaultValue(existingDefault).withCurrentDefaultValue(default.getSql)
val existsDefault = extractExistsDefault(default)
val (sql, expr) = extractCurrentDefault(default)
val newMetadata = new MetadataBuilder()
.withMetadata(f.metadata)
.putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, existsDefault)
.putExpression(CURRENT_DEFAULT_COLUMN_METADATA_KEY, sql, expr)
.build()
f.copy(metadata = newMetadata)
}.getOrElse(f)
}

private def extractExistsDefault(default: ColumnDefaultValue): String = {
Literal(default.getValue.value, default.getValue.dataType).sql
}

private def extractCurrentDefault(default: ColumnDefaultValue): (String, Option[Expression]) = {
val expr = Option(default.getExpression).flatMap(V2ExpressionUtils.toCatalyst)

[doubt] Presently, toCatalyst doesn't handle connector scalar UDFs. Is the plan to enhance this in the future?

Contributor Author:

Generally yes, but I am not sure we would want to allow them in default values.

val sql = Option(default.getSql).orElse(expr.map(_.sql)).getOrElse {

[doubt] My understanding was that .sql is not reliable (based on discussion here). I'm wondering whether this could lead to users calling getMap or map.get(key) directly and extracting the SQL from the map, skipping the check for whether there is an expression that should be used instead. Essentially, if there is an entry in the runtimeMap, should we let map.get fail?

Contributor Author:

I don't think these properties are meant to be accessible to users. We do that for built-in tables, but there are proper DSv2 APIs for this, such as ColumnDefaultValue, where the expression always takes precedence. The SQL is informational in this case.

throw SparkException.internalError(
s"Can't generate SQL for $default. The connector expression couldn't be " +
"converted to Catalyst and there is no provided SQL representation.")
}
(sql, expr)
}

/**
* Converts a StructType to DS v2 columns, which decodes the StructField metadata to v2 column
* comment and default value or generation expression. This is mainly used to generate DS v2
@@ -50,7 +50,7 @@ import org.apache.spark.util.ArrayImplicits._
*/
abstract class InMemoryBaseTable(
val name: String,
val schema: StructType,
override val columns: Array[Column],
override val partitioning: Array[Transform],
override val properties: util.Map[String, String],
override val constraints: Array[Constraint] = Array.empty,
@@ -114,6 +114,8 @@ abstract class InMemoryBaseTable(
}
}

override val schema: StructType = CatalogV2Util.v2ColumnsToStructType(columns)

// purposely exposes a metadata column that conflicts with a data column in some tests
override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
private val metadataColumnNames = metadataColumns.map(_.name).toSet
@@ -38,8 +38,13 @@ class InMemoryRowLevelOperationTable(
partitioning: Array[Transform],
properties: util.Map[String, String],
constraints: Array[Constraint] = Array.empty)
extends InMemoryTable(name, schema, partitioning, properties, constraints)
with SupportsRowLevelOperations {
extends InMemoryTable(
name,
CatalogV2Util.structTypeToV2Columns(schema),
partitioning,
properties,
constraints)
with SupportsRowLevelOperations {

private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
@@ -33,7 +33,7 @@ import org.apache.spark.util.ArrayImplicits._
*/
class InMemoryTable(
name: String,
schema: StructType,
columns: Array[Column],
override val partitioning: Array[Transform],
override val properties: util.Map[String, String],
override val constraints: Array[Constraint] = Array.empty,
@@ -43,10 +43,22 @@ class InMemoryTable(
advisoryPartitionSize: Option[Long] = None,
isDistributionStrictlyRequired: Boolean = true,
override val numRowsPerSplit: Int = Int.MaxValue)
extends InMemoryBaseTable(name, schema, partitioning, properties, constraints, distribution,
extends InMemoryBaseTable(name, columns, partitioning, properties, constraints, distribution,
ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired,
numRowsPerSplit) with SupportsDelete {

def this(
name: String,
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]) = {
this(
name,
CatalogV2Util.structTypeToV2Columns(schema),
partitioning,
properties)
}

override def canDeleteWhere(filters: Array[Filter]): Boolean = {
InMemoryTable.supportsFilters(filters)
}
@@ -118,15 +118,14 @@ class BasicInMemoryTableCatalog extends TableCatalog {
distributionStrictlyRequired: Boolean = true,
numRowsPerSplit: Int = Int.MaxValue): Table = {
// scalastyle:on argcount
val schema = CatalogV2Util.v2ColumnsToStructType(columns)
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
}

InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)

val tableName = s"$name.${ident.quoted}"
val table = new InMemoryTable(tableName, schema, partitions, properties, constraints,
val table = new InMemoryTable(tableName, columns, partitions, properties, constraints,
distribution, ordering, requiredNumPartitions, advisoryPartitionSize,
distributionStrictlyRequired, numRowsPerSplit)
tables.put(ident, table)
@@ -154,7 +153,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
val currentVersion = table.currentVersion()
val newTable = new InMemoryTable(
name = table.name,
schema = schema,
columns = CatalogV2Util.structTypeToV2Columns(schema),
partitioning = finalPartitioning,
properties = properties,
constraints = constraints)
@@ -31,10 +31,10 @@ import org.apache.spark.util.ArrayImplicits._

class InMemoryTableWithV2Filter(
name: String,
schema: StructType,
columns: Array[Column],
partitioning: Array[Transform],
properties: util.Map[String, String])
extends InMemoryBaseTable(name, schema, partitioning, properties) with SupportsDeleteV2 {
extends InMemoryBaseTable(name, columns, partitioning, properties) with SupportsDeleteV2 {

override def canDeleteWhere(predicates: Array[Predicate]): Boolean = {
InMemoryTableWithV2Filter.supportsPredicates(predicates)
@@ -37,8 +37,7 @@ class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)

val tableName = s"$name.${ident.quoted}"
val schema = CatalogV2Util.v2ColumnsToStructType(columns)
val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
val table = new InMemoryTableWithV2Filter(tableName, columns, partitions, properties)
tables.put(ident, table)
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
@@ -17,7 +17,9 @@

package org.apache.spark.sql.types

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}

class MetadataSuite extends SparkFunSuite {
test("String Metadata") {
@@ -76,4 +78,79 @@ class MetadataSuite extends SparkFunSuite {
assert(meta === Metadata.fromJson(meta.json))
intercept[NoSuchElementException](meta.getLong("no_such_key"))
}

test("Kryo serialization for expressions") {
val conf = new SparkConf()
val serializer = new KryoSerializer(conf).newInstance()
checkMetadataExpressions(serializer)
}

test("Java serialization for expressions") {
val conf = new SparkConf()
val serializer = new JavaSerializer(conf).newInstance()
checkMetadataExpressions(serializer)
}

test("JSON representation with expressions") {
val meta = new MetadataBuilder()
.putString("key", "value")
.putExpression("expr", "1 + 3", Some(Add(Literal(1), Literal(3))))
.build()
assert(meta.json == """{"expr":"1 + 3","key":"value"}""")
}

test("equals and hashCode with expressions") {
val meta1 = new MetadataBuilder()
.putString("key", "value")
.putExpression("expr", "1 + 2", Some(Add(Literal(1), Literal(2))))
.build()

val meta2 = new MetadataBuilder()
.putString("key", "value")
.putExpression("expr", "1 + 2", Some(Add(Literal(1), Literal(2))))
.build()

val meta3 = new MetadataBuilder()
.putString("key", "value")
.putExpression("expr", "2 + 3", Some(Add(Literal(2), Literal(3))))
.build()

val meta4 = new MetadataBuilder()
.putString("key", "value")
.putExpression("expr", "1 + 2", None)
.build()

// meta1 and meta2 are equivalent
assert(meta1 === meta2)
assert(meta1.hashCode === meta2.hashCode)

// meta1 and meta3 are different as they contain different expressions
assert(meta1 !== meta3)
assert(meta1.hashCode !== meta3.hashCode)

// meta1 and meta4 are equivalent even though meta4 only includes the SQL string
assert(meta1 == meta4)
assert(meta1.hashCode == meta4.hashCode)
}

private def checkMetadataExpressions(serializer: SerializerInstance): Unit = {
val meta = new MetadataBuilder()
.putString("key", "value")
.putExpression("tempKey", "1", Some(Literal(1)))
.build()
assert(meta.contains("key"))
assert(meta.getString("key") == "value")
assert(meta.contains("tempKey"))
assert(meta.getExpression[Expression]("tempKey")._1 == "1")
assert(meta.getExpression[Expression]("tempKey")._2.contains(Literal(1)))

val deserializedMeta = serializer.deserialize[Metadata](serializer.serialize(meta))
assert(deserializedMeta == meta)
assert(deserializedMeta.hashCode == meta.hashCode)
assert(deserializedMeta.contains("key"))
assert(deserializedMeta.getString("key") == "value")
assert(deserializedMeta.contains("tempKey"))
assert(deserializedMeta.getExpression[Expression]("tempKey")._1 == "1")
assert(deserializedMeta.getExpression[Expression]("tempKey")._2.isEmpty)
}
}