#26 add create checkpoint function #79

Merged: 15 commits, Nov 1, 2023
Changes from 4 commits
2 changes: 1 addition & 1 deletion LICENSE.md
@@ -199,4 +199,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
6 changes: 4 additions & 2 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumAgent.scala
@@ -36,7 +36,8 @@ class AtumAgent private[agent] () {

/**
* Sends `CheckpointDTO` to the AtumService API
* @param checkpoint
*
* @param checkpoint Already initialized Checkpoint object to store
*/
def saveCheckpoint(checkpoint: CheckpointDTO): Unit = {
dispatcher.saveCheckpoint(checkpoint)
Expand All @@ -45,14 +46,15 @@ class AtumAgent private[agent] () {
/**
* Sends `Checkpoint` to the AtumService API
*
* @param checkpoint
* @param checkpoint Already initialized Checkpoint object to store
*/
def saveCheckpoint(checkpoint: Checkpoint): Unit = {
dispatcher.saveCheckpoint(checkpoint.toCheckpointDTO)
}

/**
* Provides an AtumContext given an `AtumPartitions` instance. Retrieves the data from the AtumService API.
*
* @param atumPartitions
* @return
*/
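For orientation, a minimal sketch (not part of this PR) of how the two `saveCheckpoint` overloads relate. The `agent` and `checkpoint` values here are assumed to exist; how they are constructed is outside this diff, and `toCheckpointDTO` is `private[agent]`, so the DTO overload is meant for code inside the agent.

```scala
import za.co.absa.atum.agent.AtumAgent
import za.co.absa.atum.agent.model.Checkpoint
import za.co.absa.atum.model.dto.CheckpointDTO

// Sketch under stated assumptions: `agent` and `checkpoint` are obtained elsewhere.
val agent: AtumAgent = ???            // construction is private[agent]; obtained via the agent's own factory
val checkpoint: Checkpoint = ???      // e.g. returned by AtumContext.createCheckpoint

// Checkpoint overload: converts to a CheckpointDTO internally before dispatching.
agent.saveCheckpoint(checkpoint)

// CheckpointDTO overload: sends an already-built DTO straight to the dispatcher.
val dto: CheckpointDTO = ???          // built inside the agent, since toCheckpointDTO is private[agent]
agent.saveCheckpoint(dto)
```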
45 changes: 33 additions & 12 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
@@ -17,9 +17,9 @@
package za.co.absa.atum.agent

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.model.{Checkpoint, Measure, Measurement, MeasuresMapper}
import AtumContext.AtumPartitions
import za.co.absa.atum.model.dto.{AtumContextDTO, PartitionDTO}
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.model.{Checkpoint, Measure, Measurement, MeasurementByAtum, MeasuresMapper}
import za.co.absa.atum.model.dto._

import java.time.ZonedDateTime
import scala.collection.immutable.ListMap
@@ -42,16 +42,34 @@ class AtumContext private[agent] (
agent.getOrCreateAtumSubContext(atumPartitions ++ subPartitions)(this)
}

def createCheckpoint(checkpointName: String, author: String, dataToMeasure: DataFrame) = {
??? // TODO #26
private def takeMeasurements(df: DataFrame): Set[Measurement] = {
measures.map { m =>
val measurementResult = m.function(df)
MeasurementByAtum(m, measurementResult.result, measurementResult.resultType)
}
}

def createCheckpoint(checkpointName: String, author: String, dataToMeasure: DataFrame): Checkpoint = {
val startTime = ZonedDateTime.now()
val measurements = takeMeasurements(dataToMeasure)
val endTime = ZonedDateTime.now()

Checkpoint(
name = checkpointName,
author = author,
measuredByAtumAgent = true,
atumPartitions = this.atumPartitions,
processStartTime = startTime,
processEndTime = Some(endTime),
measurements = measurements.toSeq
)
}

def createCheckpointOnProvidedData(
checkpointName: String,
author: String,
measurements: Seq[Measurement]
): Checkpoint = {
checkpointName: String, author: String, measurements: Seq[Measurement]
): Checkpoint = {
val zonedDateTimeNow = ZonedDateTime.now()

Checkpoint(
name = checkpointName,
author = author,
@@ -122,13 +140,16 @@ object AtumContext {
implicit class DatasetWrapper(df: DataFrame) {

/**
* Set a point in the pipeline to execute calculation.
* Sets a point in the pipeline to execute the calculations and store the resulting checkpoint.
* @param checkpointName The key assigned to this checkpoint
* @param author Author of the checkpoint
* @param atumContext Contains the calculations to be done and publishes the result
* @return
*/
def createCheckpoint(checkpointName: String, author: String)(implicit atumContext: AtumContext): DataFrame = {
// todo: implement checkpoint creation
def createAndSaveCheckpoint(checkpointName: String, author: String)(implicit atumContext: AtumContext): DataFrame = {
Review comment (Contributor): Why the rename? I would just keep the existing name.
The saving is kind of implicit, plus mentioned in the comment.

val checkpoint = atumContext.createCheckpoint(checkpointName, author, df)
val checkpointDTO = checkpoint.toCheckpointDTO
atumContext.agent.saveCheckpoint(checkpointDTO)
df
}

@@ -18,6 +18,7 @@ package za.co.absa.atum.agent.core
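A hedged usage sketch of the new API surface in this file, using the method names as they stand at this point of the PR (`createAndSaveCheckpoint` is the name the reviewer questions above). The checkpoint name, author, and the way the `AtumContext` is obtained are illustrative assumptions; only `createCheckpoint` and `createAndSaveCheckpoint` come from the diff.

```scala
import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.AtumContext
import za.co.absa.atum.agent.AtumContext.DatasetWrapper
import za.co.absa.atum.agent.model.Checkpoint

// Assumption: an AtumContext is already available, e.g. retrieved from the agent
// for a given AtumPartitions instance (that call is not part of this diff).
implicit val atumContext: AtumContext = ???

def process(df: DataFrame): DataFrame = {
  // Explicit variant: measure the DataFrame and get the Checkpoint back without saving it.
  val cp: Checkpoint = atumContext.createCheckpoint("after-filtering", "jane.doe", df)

  // Implicit variant: measure, save through the agent, and return the DataFrame for chaining.
  df.createAndSaveCheckpoint("after-filtering", "jane.doe")
}
```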

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

trait MeasurementProcessor {

@@ -26,6 +27,13 @@
}

object MeasurementProcessor {
type MeasurementFunction = DataFrame => String
/**
* The raw result of a measurement is always a string, because we want to avoid floating-point issues
* (overflows, consistent representation of numbers whether they come from the Java or the Scala world, and more),
* but the actual type is stored alongside the value because we don't want to lose this information.
*/
final case class ResultOfMeasurement(result: String, resultType: ResultValueType.ResultValueType)

type MeasurementFunction = DataFrame => ResultOfMeasurement

}
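To illustrate the new contract, a small sketch of a custom measurement function. The "amount" column and the chosen measure are illustrative assumptions; only `MeasurementFunction` and `ResultOfMeasurement` are from the diff.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, max}
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

// A function measuring the maximum of a (hypothetical) "amount" column.
// The value is rendered as a String, and the intended type travels alongside it.
val maxAmount: MeasurementFunction = (df: DataFrame) => {
  val value = df.agg(max(col("amount"))).collect()(0)(0)
  val asString = if (value == null) "" else value.toString
  ResultOfMeasurement(asString, ResultValueType.BigDecimal)
}
```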
@@ -33,4 +33,5 @@ class ConsoleDispatcher extends Dispatcher with Logging {
override def saveCheckpoint(checkpoint: CheckpointDTO): Unit = {
println(s"Saving checkpoint to server. $checkpoint")
}

}
@@ -45,4 +45,5 @@ class HttpDispatcher(config: Config) extends Dispatcher with Logging {
.post(serverUri)
.send(backend)
}

}
@@ -23,15 +23,22 @@ import java.time.ZonedDateTime
import java.util.UUID

case class Checkpoint(
name: String,
author: String,
measuredByAtumAgent: Boolean = false,
atumPartitions: AtumPartitions,
processStartTime: ZonedDateTime,
processEndTime: Option[ZonedDateTime],
measurements: Seq[Measurement]
) {
name: String,
author: String,
measuredByAtumAgent: Boolean = false,
atumPartitions: AtumPartitions,
processStartTime: ZonedDateTime,
processEndTime: Option[ZonedDateTime],
measurements: Seq[Measurement]
) {
private [agent] def toCheckpointDTO: CheckpointDTO = {
val measurementDTOs = measurements.map {
case provided: MeasurementProvided =>
MeasurementBuilder.buildMeasurementDTO(provided)
case byAtum: MeasurementByAtum =>
MeasurementBuilder.buildMeasurementDTO(byAtum)
}

CheckpointDTO(
id = UUID.randomUUID(),
name = name,
@@ -40,7 +47,7 @@
partitioning = AtumPartitions.toSeqPartitionDTO(atumPartitions),
processStartTime = processStartTime,
processEndTime = processEndTime,
measurements = measurements.map(MeasurementBuilder.buildMeasurementDTO)
measurements = measurementDTOs
)
}
}
22 changes: 16 additions & 6 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
@@ -20,7 +20,8 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.atum.agent.core.MeasurementProcessor
import za.co.absa.atum.agent.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

/**
@@ -47,7 +48,10 @@ object Measure {
case class RecordCount private (controlCol: String, measureName: String, onlyForNumeric: Boolean) extends Measure {

override def function: MeasurementFunction =
(ds: DataFrame) => ds.select(col(controlCol)).count().toString
(ds: DataFrame) => {
val result = ds.select(col(controlCol)).count().toString
ResultOfMeasurement(result, ResultValueType.Long)
}
}
object RecordCount extends MeasureType {
def apply(controlCol: String): RecordCount = {
@@ -62,7 +66,10 @@
extends Measure {

override def function: MeasurementFunction =
(ds: DataFrame) => ds.select(col(controlCol)).distinct().count().toString
(ds: DataFrame) => {
val result = ds.select(col(controlCol)).distinct().count().toString
ResultOfMeasurement(result, ResultValueType.Long)
}
}

object DistinctRecordCount extends MeasureType {
@@ -79,7 +86,8 @@

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(col(valueColumnName))
aggregateColumn(ds, controlCol, aggCol)
val result = aggregateColumn(ds, controlCol, aggCol)
ResultOfMeasurement(result, ResultValueType.BigDecimal)
}
}

@@ -97,7 +105,8 @@

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(abs(col(valueColumnName)))
aggregateColumn(ds, controlCol, aggCol)
val result = aggregateColumn(ds, controlCol, aggCol)
ResultOfMeasurement(result, ResultValueType.Double)
}
}

@@ -120,7 +129,8 @@
.withColumn(aggregatedColumnName, crc32(col(controlCol).cast("String")))
.agg(sum(col(aggregatedColumnName)))
.collect()(0)(0)
if (value == null) "" else value.toString
val result = if (value == null) "" else value.toString
ResultOfMeasurement(result, ResultValueType.String)
}
}

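A short sketch of what the updated measures now return when evaluated. The DataFrame and the "id" column are illustrative; `RecordCount` and `ResultOfMeasurement` come from the diff.

```scala
import org.apache.spark.sql.SparkSession
import za.co.absa.atum.agent.core.MeasurementProcessor.ResultOfMeasurement
import za.co.absa.atum.agent.model.Measure.RecordCount

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Hypothetical input data with an "id" column.
val df = Seq(1, 2, 3).toDF("id")

// Each measure now yields both the stringified value and its declared type.
val measured: ResultOfMeasurement = RecordCount("id").function(df)
// measured.result     == "3"
// measured.resultType == ResultValueType.Long
```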
@@ -16,4 +16,22 @@

package za.co.absa.atum.agent.model

case class Measurement(measure: Measure, result: Any)
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType.ResultValueType

trait Measurement {
val measure: Measure
val result: Any
}

/**
* When the application or user of the Atum Agent provides the actual results themselves, the type is precise
* and no adjustments are needed.
*/
case class MeasurementProvided(measure: Measure, result: Any) extends Measurement

/**
* When the Atum Agent itself performs the measurements using Spark, some adjustments may be needed; therefore
* the results are always converted to strings, while the information about the actual type is kept as well.
*/
case class MeasurementByAtum(measure: Measure, result: String, resultType: ResultValueType) extends Measurement
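To make the distinction concrete, a sketch of constructing both kinds of measurement. The measure instance and values are illustrative assumptions.

```scala
import za.co.absa.atum.agent.model.{Measure, MeasurementByAtum, MeasurementProvided}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

val measure: Measure = Measure.RecordCount("id")   // hypothetical control column

// Result supplied by the caller: the original type is kept as-is.
val provided = MeasurementProvided(measure, 42L)

// Result computed by the agent via Spark: stringified value plus its declared type.
val byAtum = MeasurementByAtum(measure, "42", ResultValueType.Long)
```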
@@ -22,49 +22,40 @@ import za.co.absa.atum.model.dto.MeasureResultDTO.{ResultValueType, TypedValue}

private [agent] object MeasurementBuilder {

def buildMeasurementDTO(measurement: Measurement): MeasurementDTO = {
private [agent] def buildMeasurementDTO(measurement: MeasurementByAtum): MeasurementDTO = {
val measureName = measurement.measure.measureName
measurement.result match {
case l: Long =>
buildLongMeasurement(measureName, Seq(measurement.measure.controlCol), l)
case d: Double =>
buildDoubleMeasureResult(measureName, Seq(measurement.measure.controlCol), d)
case bd: BigDecimal =>
buildBigDecimalMeasureResult(measureName, Seq(measurement.measure.controlCol), bd)
case s: String =>
buildStringMeasureResult(measureName, Seq(measurement.measure.controlCol), s)
case unsupportedType =>
val className = unsupportedType.getClass.getSimpleName
throw UnsupportedMeasureResultType(s"Unsupported type of measure $measureName: $className")
}
}
val controlCols = Seq(measurement.measure.controlCol)

private def buildLongMeasurement(functionName: String, controlCols: Seq[String], resultValue: Long): MeasurementDTO = {
MeasurementDTO(
MeasureDTO(functionName, controlCols),
MeasureResultDTO(TypedValue(resultValue.toString, ResultValueType.Long))
MeasureDTO(measureName, controlCols),
MeasureResultDTO(TypedValue(measurement.result, measurement.resultType))
)
}

private def buildDoubleMeasureResult(functionName: String, controlCols: Seq[String], resultValue: Double): MeasurementDTO = {
MeasurementDTO(
MeasureDTO(functionName, controlCols),
MeasureResultDTO(TypedValue(resultValue.toString, ResultValueType.Double))
)
}
private [agent] def buildMeasurementDTO(measurement: MeasurementProvided): MeasurementDTO = {
val measureName = measurement.measure.measureName
val controlCols = Seq(measurement.measure.controlCol)

private def buildBigDecimalMeasureResult(functionName: String, controlCols: Seq[String], resultValue: BigDecimal): MeasurementDTO = {
MeasurementDTO(
MeasureDTO(functionName, controlCols),
MeasureResultDTO(TypedValue(resultValue.toString, ResultValueType.BigDecimal))
MeasureDTO(measureName, controlCols),
buildMeasureResultDTO(measureName, measurement.result)
)
}

private def buildStringMeasureResult(functionName: String, controlCols: Seq[String], resultValue: String): MeasurementDTO = {
MeasurementDTO(
MeasureDTO(functionName, controlCols),
MeasureResultDTO(TypedValue(resultValue, ResultValueType.String))
)
private [agent] def buildMeasureResultDTO(measureName: String, result: Any): MeasureResultDTO = {
result match {
case l: Long =>
MeasureResultDTO(TypedValue(l.toString, ResultValueType.Long))
case d: Double =>
MeasureResultDTO(TypedValue(d.toString, ResultValueType.Double))
case bd: BigDecimal =>
MeasureResultDTO(TypedValue(bd.toString, ResultValueType.BigDecimal))
case s: String =>
MeasureResultDTO(TypedValue(s, ResultValueType.String))
case unsupportedType =>
val className = unsupportedType.getClass.getSimpleName
throw UnsupportedMeasureResultType(s"Unsupported type of measure $measureName: $className for result: $result")
}
}

}
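A brief sketch of how the consolidated conversion behaves for the supported result types. The measure names and values are illustrative, and since `MeasurementBuilder` is `private[agent]`, the sketch is placed inside the agent's model package.

```scala
package za.co.absa.atum.agent.model

import za.co.absa.atum.model.dto.MeasureResultDTO

object MeasurementBuilderSketch {
  // Long, Double, BigDecimal and String results are wrapped into a TypedValue carrying
  // the matching ResultValueType; any other type raises UnsupportedMeasureResultType.
  val asLong: MeasureResultDTO    = MeasurementBuilder.buildMeasureResultDTO("recordCount", 3L)
  val asDecimal: MeasureResultDTO = MeasurementBuilder.buildMeasureResultDTO("aggregatedTotal", BigDecimal("12.50"))
  // MeasurementBuilder.buildMeasureResultDTO("recordCount", Seq(1, 2)) // would throw UnsupportedMeasureResultType
}
```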
@@ -22,7 +22,7 @@ import za.co.absa.atum.model.dto

private [agent] object MeasuresMapper {

def mapToMeasures(measures: Set[dto.MeasureDTO]): Set[za.co.absa.atum.agent.model.Measure] = {
private [agent] def mapToMeasures(measures: Set[dto.MeasureDTO]): Set[za.co.absa.atum.agent.model.Measure] = {
measures.map(createMeasure)
}
