Validate checksum during checkpoint
dhruvarya-db committed Nov 6, 2024
1 parent c60437b commit 284427c
Showing 11 changed files with 381 additions and 15 deletions.
@@ -74,6 +74,12 @@ trait CheckpointProvider extends UninitializedCheckpointProvider {
* for the checkpoint.
*/
def allActionsFileIndexes(): Seq[DeltaLogFileIndex]

/**
* The type of checkpoint (V2 vs Classic). This will be None when no checkpoint is available.
* This is only intended to be used for logging and metrics.
*/
def checkpointPolicy: Option[CheckpointPolicy.Policy]
}

object CheckpointProvider extends DeltaLogging {
@@ -285,6 +291,8 @@ case class PreloadedCheckpointProvider(
override def allActionsFileIndexes(): Seq[DeltaLogFileIndex] = Seq(fileIndex)

override lazy val topLevelFileIndex: Option[DeltaLogFileIndex] = Some(fileIndex)

override def checkpointPolicy: Option[CheckpointPolicy.Policy] = Some(CheckpointPolicy.Classic)
}

/**
@@ -302,6 +310,7 @@ object EmptyCheckpointProvider extends CheckpointProvider {
override def effectiveCheckpointSizeInBytes(): Long = 0L
override def allActionsFileIndexes(): Seq[DeltaLogFileIndex] = Nil
override def topLevelFileIndex: Option[DeltaLogFileIndex] = None
override def checkpointPolicy: Option[CheckpointPolicy.Policy] = None
}

/** A trait representing a v2 [[UninitializedCheckpointProvider]] */
@@ -413,6 +422,9 @@ abstract class LazyCompleteCheckpointProvider(

override def allActionsFileIndexes(): Seq[DeltaLogFileIndex] =
underlyingCheckpointProvider.allActionsFileIndexes()

override def checkpointPolicy: Option[CheckpointPolicy.Policy] =
underlyingCheckpointProvider.checkpointPolicy
}

/**
@@ -458,6 +470,8 @@ case class V2CheckpointProvider(
override def allActionsFileIndexes(): Seq[DeltaLogFileIndex] =
topLevelFileIndex ++: fileIndexesForSidecarFiles

override def checkpointPolicy: Option[CheckpointPolicy.Policy] = Some(CheckpointPolicy.V2)

}

object V2CheckpointProvider {
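For context, a brief sketch of how the new checkpointPolicy field might be surfaced in a log line, mirroring the usage-log event added in ValidateChecksum further down; the snapshot value and the DeltaLogging mixin providing logInfo are assumptions and not part of this diff:

// Sketch only (assumed context, not part of this commit): report which checkpoint
// flavor backed a snapshot, using CheckpointPolicy.Policy#name as the usage-log event does.
val policyName = snapshot.checkpointProvider.checkpointPolicy
  .map(_.name)          // the policy's string identifier
  .getOrElse("none")    // EmptyCheckpointProvider reports None
logInfo(s"Snapshot v${snapshot.version} was loaded from a '$policyName' checkpoint")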
@@ -608,6 +608,10 @@ object Checkpoints
deltaLog: DeltaLog,
snapshot: Snapshot): LastCheckpointInfo = recordFrameProfile(
"Delta", "Checkpoints.writeCheckpoint") {
if (spark.conf.get(DeltaSQLConf.DELTA_WRITE_CHECKSUM_ENABLED)) {
snapshot.validateChecksum(Map("context" -> "writeCheckpoint"))
}

val hadoopConf = deltaLog.newDeltaHadoopConf()

// The writing of checkpoints doesn't go through log store, so we need to check with the
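A rough usage sketch of the new check above: with checksum writing enabled, explicitly checkpointing a table now validates the snapshot against its CRC before any checkpoint file is written. The table path is hypothetical and deltaLog.checkpoint() is assumed to be the no-argument checkpoint entry point:

// Sketch only (assumed session and hypothetical path, not part of this commit).
spark.sessionState.conf.setConf(DeltaSQLConf.DELTA_WRITE_CHECKSUM_ENABLED, true)
val deltaLog = DeltaLog.forTable(spark, "/tmp/delta/example-table") // hypothetical path
// Runs Snapshot.validateChecksum(Map("context" -> "writeCheckpoint")) before writing files.
deltaLog.checkpoint()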
173 changes: 169 additions & 4 deletions spark/src/main/scala/org/apache/spark/sql/delta/Checksum.scala
@@ -20,6 +20,7 @@ import java.io.FileNotFoundException
import java.nio.charset.StandardCharsets.UTF_8

// scalastyle:off import.ordering.noEmptyLine
import scala.collection.immutable.ListMap
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@@ -35,10 +36,11 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkEnv
import org.apache.spark.internal.MDC
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.util.Utils
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* Stats calculated within a snapshot, which we store along individual transactions for
@@ -76,13 +78,17 @@ trait RecordChecksum extends DeltaLogging {
private lazy val writer =
CheckpointFileManager.create(deltaLog.logPath, deltaLog.newDeltaHadoopConf())

private def getChecksum(snapshot: Snapshot): VersionChecksum = {
snapshot.checksumOpt.getOrElse(snapshot.computeChecksum)
}

protected def writeChecksumFile(txnId: String, snapshot: Snapshot): Unit = {
if (!spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_WRITE_CHECKSUM_ENABLED)) {
return
}

val version = snapshot.version
val checksumWithoutTxnId = snapshot.checksumOpt.getOrElse(snapshot.computeChecksum)
val checksumWithoutTxnId = getChecksum(snapshot)
val checksum = checksumWithoutTxnId.copy(txnId = Some(txnId))
val eventData = mutable.Map[String, Any]("operationSucceeded" -> false)
eventData("numAddFileActions") = checksum.allFiles.map(_.size).getOrElse(-1)
@@ -168,8 +174,7 @@ trait RecordChecksum extends DeltaLogging {
// then we cannot incrementally derive a new checksum for the new snapshot.
logInfo(log"Incremental commit: starting with snapshot version " +
log"${MDC(DeltaLogKeys.VERSION, expectedVersion)}")
val snapshotChecksum = snapshot.checksumOpt.getOrElse(snapshot.computeChecksum)
snapshotChecksum.copy(numMetadata = 1, numProtocol = 1) -> Some(snapshot)
getChecksum(snapshot).copy(numMetadata = 1, numProtocol = 1) -> Some(snapshot)
case _ =>
previousVersionState.swap.foreach { snapshot =>
// Occurs when snapshot is no longer fresh due to concurrent writers.
@@ -366,3 +371,163 @@ trait ReadChecksum extends DeltaLogging { self: DeltaLog =>
}
}

/**
* Verify the state of the table using the checksum information.
*/
trait ValidateChecksum extends DeltaLogging { self: Snapshot =>

/**
* Validate checksum (if any) by comparing it against the snapshot's state reconstruction.
* @param contextInfo caller context that will be added to the logging if validation fails
* @return True iff validation succeeded.
* @throws IllegalStateException if validation failed and corruption is configured as fatal.
*/
def validateChecksum(contextInfo: Map[String, String] = Map.empty): Boolean = {
val contextSuffix = contextInfo.get("context").map(c => s".context-$c").getOrElse("")
val computedStateAccessor = s"ValidateChecksum.checkMismatch$contextSuffix"
val computedStateToCompareAgainst = computedState
val (mismatchErrorMap, detailedErrorMapForUsageLogs) = checksumOpt
.map(checkMismatch(_, computedStateToCompareAgainst))
.getOrElse((Map.empty[String, String], Map.empty[String, String]))
logAndThrowValidationFailure(mismatchErrorMap, detailedErrorMapForUsageLogs, contextInfo)
}

private def logAndThrowValidationFailure(
mismatchErrorMap: Map[String, String],
detailedErrorMapForUsageLogs: Map[String, String],
contextInfo: Map[String, String]): Boolean = {
if (mismatchErrorMap.isEmpty) return true
val mismatchString = mismatchErrorMap.values.mkString("\n")

// We get the active SparkSession, which may be different from the SparkSession that
// created this Snapshot, since we cache `DeltaLog`s.
val sparkOpt = SparkSession.getActiveSession

// Report the failure to usage logs.
recordDeltaEvent(
this.deltaLog,
"delta.checksum.invalid",
data = Map(
"error" -> mismatchString,
"mismatchingFields" -> mismatchErrorMap.keys.toSeq,
"detailedErrorMap" -> detailedErrorMapForUsageLogs,
"v2CheckpointEnabled" ->
CheckpointProvider.isV2CheckpointEnabled(this),
"checkpointProviderCheckpointPolicy" ->
checkpointProvider.checkpointPolicy.map(_.name).getOrElse("")
) ++ contextInfo)

val spark = sparkOpt.getOrElse {
throw DeltaErrors.sparkSessionNotSetException()
}
if (spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_CHECKSUM_MISMATCH_IS_FATAL)) {
throw DeltaErrors.logFailedIntegrityCheck(version, mismatchString)
}
false
}

/**
* Validates the given `checksum` against [[Snapshot.computedState]].
* Returns a tuple of Maps:
* - the first Map maps each field name to a user-facing error message.
* - the second Map is only for usage logs and may contain more details for individual fields.
*   Adding info to this map is optional.
*/
private def checkMismatch(
checksum: VersionChecksum,
computedStateToCheckAgainst: SnapshotState
): (Map[String, String], Map[String, String]) = {
var errorMap = ListMap[String, String]()
var detailedErrorMapForUsageLogs = ListMap[String, String]()

def compare(expected: Long, found: Long, title: String, field: String): Unit = {
if (expected != found) {
errorMap += (field -> s"$title - Expected: $expected Computed: $found")
}
}
def compareAction(expected: Action, found: Action, title: String, field: String): Unit = {
// Only compare when expected is not null, to stay backward compatible with checksums
// written without protocol and metadata.
Option(expected).filterNot(_.equals(found)).foreach { expected =>
errorMap += (field -> s"$title - Expected: $expected Computed: $found")
}
}

def compareSetTransactions(
setTransactionsInCRC: Seq[SetTransaction],
setTransactionsComputed: Seq[SetTransaction]): Unit = {
val appIdsFromCrc = setTransactionsInCRC.map(_.appId)
val repeatedEntriesForSameAppId = appIdsFromCrc.size != appIdsFromCrc.toSet.size
val setTransactionsInCRCSet = setTransactionsInCRC.toSet
val setTransactionsFromComputeStateSet = setTransactionsComputed.toSet
val exactMatchFailed = setTransactionsInCRCSet != setTransactionsFromComputeStateSet
if (repeatedEntriesForSameAppId || exactMatchFailed) {
val repeatedAppIds = appIdsFromCrc.groupBy(identity).filter(_._2.size > 1).keySet.toSeq
val matchedActions = setTransactionsInCRCSet.intersect(setTransactionsFromComputeStateSet)
val unmatchedActionsInCrc = setTransactionsInCRCSet -- matchedActions
val unmatchedActionsInComputed = setTransactionsFromComputeStateSet -- matchedActions
val eventData = Map(
"unmatchedSetTransactionsCRC" -> unmatchedActionsInCrc,
"unmatchedSetTransactionsComputedState" -> unmatchedActionsInComputed,
"version" -> version,
"minSetTransactionRetentionTimestamp" -> minSetTransactionRetentionTimestamp,
"repeatedEntriesForSameAppId" -> repeatedAppIds,
"exactMatchFailed" -> exactMatchFailed)
errorMap += ("setTransactions" -> s"SetTransaction mismatch")
detailedErrorMapForUsageLogs += ("setTransactions" -> JsonUtils.toJson(eventData))
}
}

def compareDomainMetadata(
domainMetadataInCRC: Seq[DomainMetadata],
domainMetadataComputed: Seq[DomainMetadata]): Unit = {
val domainMetadataInCRCSet = domainMetadataInCRC.toSet
// Remove any tombstones from the reconstructed set before comparison.
val domainMetadataInComputeStateSet = domainMetadataComputed.filterNot(_.removed).toSet
val exactMatchFailed = domainMetadataInCRCSet != domainMetadataInComputeStateSet
if (exactMatchFailed) {
val matchedActions = domainMetadataInCRCSet.intersect(domainMetadataInComputeStateSet)
val unmatchedActionsInCRC = domainMetadataInCRCSet -- matchedActions
val unmatchedActionsInComputed = domainMetadataInComputeStateSet -- matchedActions
val eventData = Map(
"unmatchedDomainMetadataInCRC" -> unmatchedActionsInCRC,
"unmatchedDomainMetadataInComputedState" -> unmatchedActionsInComputed,
"version" -> version)
errorMap += ("domainMetadata" -> "domainMetadata mismatch")
detailedErrorMapForUsageLogs += ("domainMetadata" -> JsonUtils.toJson(eventData))
}
}

compareAction(checksum.metadata, computedStateToCheckAgainst.metadata, "Metadata", "metadata")
compareAction(checksum.protocol, computedStateToCheckAgainst.protocol, "Protocol", "protocol")
compare(
checksum.tableSizeBytes,
computedStateToCheckAgainst.sizeInBytes,
title = "Table size (bytes)",
field = "tableSizeBytes")
compare(
checksum.numFiles,
computedStateToCheckAgainst.numOfFiles,
title = "Number of files",
field = "numFiles")
compare(
checksum.numMetadata,
computedStateToCheckAgainst.numOfMetadata,
title = "Metadata updates",
field = "numOfMetadata")
compare(
checksum.numProtocol,
computedStateToCheckAgainst.numOfProtocol,
title = "Protocol updates",
field = "numOfProtocol")

checksum.setTransactions.foreach { setTransactionsInCRC =>
compareSetTransactions(setTransactionsInCRC, computedStateToCheckAgainst.setTransactions)
}

checksum.domainMetadata.foreach(
compareDomainMetadata(_, computedStateToCheckAgainst.domainMetadata))

(errorMap, detailedErrorMapForUsageLogs)
}
}
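The new validateChecksum hook can also be invoked directly, e.g. from a suite inside the org.apache.spark.sql.delta package; a minimal sketch, where the "manualCheck" context string is arbitrary and only decorates the usage-log event:

// Sketch only (assumed deltaLog in scope, not part of this commit).
val snapshot = deltaLog.update()
// Returns true when the CRC matches the computed state; on mismatch it either throws
// (DELTA_CHECKSUM_MISMATCH_IS_FATAL = true) or returns false.
val consistent = snapshot.validateChecksum(Map("context" -> "manualCheck"))
assert(consistent, s"CRC mismatch detected at version ${snapshot.version}")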
11 changes: 6 additions & 5 deletions spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala
@@ -143,7 +143,12 @@ class DeltaLog private(
spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_MAX_SNAPSHOT_LINEAGE_LENGTH)

private[delta] def incrementalCommitEnabled: Boolean = {
DeltaLog.incrementalCommitEnableConfigs.forall(conf => spark.conf.get(conf))
spark.conf.get(DeltaSQLConf.INCREMENTAL_COMMIT_ENABLED)
}

private[delta] def shouldVerifyIncrementalCommit: Boolean = {
spark.conf.get(DeltaSQLConf.INCREMENTAL_COMMIT_VERIFY) ||
(Utils.isTesting && spark.conf.get(DeltaSQLConf.INCREMENTAL_COMMIT_FORCE_VERIFY_IN_TESTS))
}

/** The unique identifier for this table. */
@@ -724,10 +729,6 @@ object DeltaLog extends DeltaLogging {
.maximumSize(cacheSize)
}

val incrementalCommitEnableConfigs = Seq(
DeltaSQLConf.INCREMENTAL_COMMIT_ENABLED
)


/**
* Creates a [[LogicalRelation]] for a given [[DeltaLogFileIndex]], with all necessary file source
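For illustration, a sketch of how the two verification flags introduced above interact, written as it might appear in package-internal (e.g. test) code; the session setup and table path are assumptions:

// Sketch only (assumed setup, not part of this commit). INCREMENTAL_COMMIT_VERIFY drives
// verification everywhere, while INCREMENTAL_COMMIT_FORCE_VERIFY_IN_TESTS is consulted
// only when Utils.isTesting is true.
spark.sessionState.conf.setConf(DeltaSQLConf.INCREMENTAL_COMMIT_ENABLED, true)
spark.sessionState.conf.setConf(DeltaSQLConf.INCREMENTAL_COMMIT_VERIFY, true)
val deltaLog = DeltaLog.forTable(spark, "/tmp/delta/example-table") // hypothetical path
assert(deltaLog.shouldVerifyIncrementalCommit) // private[delta]: only callable in-package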
@@ -250,16 +250,30 @@ object OptimisticTransaction {
trait OptimisticTransactionImpl extends TransactionalWrite
with SQLMetricsReporting
with DeltaScanGenerator
with DeltaLogging
with RecordChecksum {
with RecordChecksum
with DeltaLogging {

import org.apache.spark.sql.delta.util.FileNames._

// Intentionally cache the values of these configs to ensure a stable commit code path
// and to avoid races between committing and dynamic config changes.
protected val incrementalCommitEnabled = deltaLog.incrementalCommitEnabled
protected val shouldVerifyIncrementalCommit = deltaLog.shouldVerifyIncrementalCommit

val deltaLog: DeltaLog
val catalogTable: Option[CatalogTable]
val snapshot: Snapshot
def clock: Clock = deltaLog.clock

// This is a quick operation if the checksum has already been validated;
// otherwise, we should at least perform the validation here.
// NOTE: When incremental commits are enabled, skip validation unless it was specifically
// requested. This allows us to maintain test coverage internally, while avoiding the extreme
// overhead of those checks in prod or benchmark settings.
if (!incrementalCommitEnabled || shouldVerifyIncrementalCommit) {
snapshot.validateChecksum(Map("context" -> "transactionInitialization"))
}

protected def spark = SparkSession.active

/** Tracks the appIds that have been seen by this transaction. */
@@ -88,6 +88,7 @@ class Snapshot(
with StateCache
with StatisticsCollection
with DataSkippingReader
with ValidateChecksum
with DeltaLogging {

import Snapshot._
@@ -1250,12 +1250,37 @@ trait SnapshotManagement { self: DeltaLog =>
logInfo(
log"Creating a new snapshot v${MDC(DeltaLogKeys.VERSION, initSegment.version)} " +
log"for commit version ${MDC(DeltaLogKeys.VERSION2, committedVersion)}")
createSnapshot(

val newSnapshot = createSnapshot(
initSegment,
tableCommitCoordinatorClientOpt = tableCommitCoordinatorClientOpt,
tableIdentifierOpt = tableIdentifierOpt,
checksumOpt = newChecksumOpt
)
// Verify when enabled, or when tests run, to help future-proof incremental commit (IC).
if (shouldVerifyIncrementalCommit) {
val crcIsValid = try {
// NOTE: Validation is a no-op with incremental commit disabled.
newSnapshot.validateChecksum(Map("context" -> "incrementalCommit"))
} catch {
case _: IllegalStateException if !Utils.isTesting => false
}

if (!crcIsValid) {
// Create the snapshot without the incremental checksum. This will fall back to computing
// a checksum from state reconstruction. Disable incremental commit to avoid
// triggering further errors in this session.
spark.sessionState.conf.setConf(DeltaSQLConf.INCREMENTAL_COMMIT_ENABLED, false)
return createSnapshotAfterCommit(
initSegment,
newChecksumOpt = None,
tableCommitCoordinatorClientOpt = tableCommitCoordinatorClientOpt,
tableIdentifierOpt,
committedVersion)
}
}

newSnapshot
}

/**
(Remaining changed files not loaded.)