Skip to content
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,16 @@
"Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
]
},
"MODEL_SIZE_OVERFLOW_EXCEPTION" : {
"message" : [
"<message>"
]
},
"MODEL_CACHE_SIZE_OVERFLOW_EXCEPTION" : {
"message" : [
"<message>"
]
},
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.SizeEstimator


private[spark] object GradientBoostedTrees extends Logging {
Expand Down Expand Up @@ -292,6 +293,7 @@ private[spark] object GradientBoostedTrees extends Logging {
featureSubsetStrategy: String,
instr: Option[Instrumentation] = None):
(Array[DecisionTreeRegressionModel], Array[Double]) = {
val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
Expand Down Expand Up @@ -365,7 +367,8 @@ private[spark] object GradientBoostedTrees extends Logging {
val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
parentUID = None)
parentUID = None,
earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)
.head.asInstanceOf[DecisionTreeRegressionModel]

firstCounts.unpersist()
Expand Down Expand Up @@ -400,11 +403,18 @@ private[spark] object GradientBoostedTrees extends Logging {
timer.stop("init validation")
}

var bestM = 1
var accTreeSize = SizeEstimator.estimate(firstTreeModel)
var validM = 1

var m = 1
var doneLearning = false
while (m < numIterations && !doneLearning) {
var earlyStop = false
if (
earlyStopModelSizeThresholdInBytes > 0
&& accTreeSize > earlyStopModelSizeThresholdInBytes
) {
earlyStop = true
}
while (m < numIterations && !earlyStop) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
Expand Down Expand Up @@ -434,14 +444,16 @@ private[spark] object GradientBoostedTrees extends Logging {
val model = RandomForest.runBagged(baggedInput = bagged,
metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
seed = seed + m, instr = None, parentUID = None)
seed = seed + m, instr = None, parentUID = None,
earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes - accTreeSize)
.head.asInstanceOf[DecisionTreeRegressionModel]

labelWithCounts.unpersist()

timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
accTreeSize += SizeEstimator.estimate(model)
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
Expand All @@ -466,10 +478,19 @@ private[spark] object GradientBoostedTrees extends Logging {
val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
if (bestValidateError - currentValidateError < validationTol * Math.max(
currentValidateError, 0.01)) {
doneLearning = true
earlyStop = true
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
bestM = m + 1
validM = m + 1
}
}
if (!earlyStop) {
if (
earlyStopModelSizeThresholdInBytes > 0
&& accTreeSize > earlyStopModelSizeThresholdInBytes
) {
earlyStop = true
validM = m + 1
}
}
m += 1
Expand All @@ -490,8 +511,16 @@ private[spark] object GradientBoostedTrees extends Logging {
validatePredErrorCheckpointer.deleteAllCheckpoints()
}

if (validate) {
(baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
if (earlyStop) {
// Early stop occurs in one of the 2 cases:
// - validation error increases
// - the accumulated size of trees exceeds the value of `earlyStopModelSizeThresholdInBytes`
if (accTreeSize > earlyStopModelSizeThresholdInBytes) {
logWarning(
"The boosting tree training stops early because the model size exceeds threshold."
)
}
(baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM))
} else {
(baseLearners, baseLearnerWeights)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.SizeEstimator
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}

Expand Down Expand Up @@ -109,6 +110,7 @@ private[spark] object RandomForest extends Logging with Serializable {
* @param metadata Learning and dataset metadata for DecisionTree.
* @return an unweighted set of trees
*/
// scalastyle:off
def runBagged(
baggedInput: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
Expand All @@ -119,7 +121,8 @@ private[spark] object RandomForest extends Logging with Serializable {
seed: Long,
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
parentUID: Option[String] = None,
earlyStopModelSizeThresholdInBytes: Long = 0): Array[DecisionTreeModel] = {
val timer = new TimeTracker()
timer.start("total")

Expand Down Expand Up @@ -190,7 +193,8 @@ private[spark] object RandomForest extends Logging with Serializable {

timer.stop("init")

while (nodeStack.nonEmpty) {
var earlyStop = false
while (nodeStack.nonEmpty && !earlyStop) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
Expand All @@ -214,6 +218,16 @@ private[spark] object RandomForest extends Logging with Serializable {
}

timer.stop("findBestSplits")

if (
earlyStopModelSizeThresholdInBytes > 0
&& SizeEstimator.estimate(topNodes) > earlyStopModelSizeThresholdInBytes
) {
earlyStop = true
logWarning(
"The random forest training stops early because the model size exceeds threshold."
)
}
}

timer.stop("total")
Expand Down Expand Up @@ -253,6 +267,7 @@ private[spark] object RandomForest extends Logging with Serializable {
}
}
}
// scalastyle:on

/**
* Train a random forest.
Expand All @@ -269,6 +284,7 @@ private[spark] object RandomForest extends Logging with Serializable {
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
val timer = new TimeTracker()

timer.start("build metadata")
Expand Down Expand Up @@ -301,7 +317,8 @@ private[spark] object RandomForest extends Logging with Serializable {

val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits,
strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy,
seed = seed, instr = instr, prune = prune, parentUID = parentUID)
seed = seed, instr = instr, prune = prune, parentUID = parentUID,
earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)

baggedInput.unpersist()
bcSplits.destroy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,7 @@ private[ml] object EnsembleModelReadWrite {
}
}
}

private[spark] object TreeConfig {
private[spark] var trainingEarlyStopModelSizeThresholdInBytes: Long = 0L
}
Original file line number Diff line number Diff line change
Expand Up @@ -353,4 +353,20 @@ object Connect {
.internal()
.timeConf(TimeUnit.MINUTES)
.createWithDefault(15)

val CONNECT_SESSION_CONNECT_MODEL_MAX_SIZE =
buildConf("spark.connect.session.connectML.model.maxSize")
.doc("Maximum size of the SparkML model. The size is in bytes.")
.version("4.1.0")
.internal()
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("1g")

val CONNECT_SESSION_CONNECT_MODEL_CACHE_MAX_SIZE =
buildConf("spark.connect.session.connectML.modelCache.maxSize")
.doc("Maximum size of the model cache. The size is in bytes.")
.version("4.1.0")
.internal()
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("10g")
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,31 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {

private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)

private[ml] val totalModelCacheSizeBytes: AtomicLong = new AtomicLong(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "totalSizeBytes" above is tracking the total size in memory. Why do we need to use a different totalModelCacheSizeBytes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

totalModelCacheSizeBytes is for counting in-memory + offloaded data size
totalSizeBytes is for counting in-memory size, I will rename it to totalInMemorySizeBytes

private[spark] def getModelCacheMaxSize: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_MODEL_CACHE_MAX_SIZE)
}
private[spark] def getModelMaxSize: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_MODEL_MAX_SIZE)
}

def checkModelSize(estimatedModelSize: Long): Unit = {
if (totalModelCacheSizeBytes.get() + estimatedModelSize > getModelCacheMaxSize) {
throw MLModelCacheSizeOverflowException(
"The model cache size in current session is about to exceed" +
f"$getModelCacheMaxSize bytes. " +
"Please delete existing cached model by executing 'del model' in python client " +
"before fitting new model or loading new model, or increase " +
"Spark config 'spark.connect.session.connectML.model.maxSize'.")
}
if (estimatedModelSize > getModelMaxSize) {
throw MLModelSizeOverflowException(
f"The fitted or loaded model size exceeds $getModelMaxSize bytes. " +
f"Please fit or load a model smaller than $getModelMaxSize bytes. " +
f"or increase Spark config 'spark.connect.session.connectML.modelCache.maxSize'.")
}
}

private def estimateObjectSize(obj: Object): Long = {
obj match {
case model: Model[_] =>
Expand All @@ -81,7 +106,11 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
def register(obj: Object): String = {
val objectId = UUID.randomUUID().toString
val sizeBytes = estimateObjectSize(obj)
checkModelSize(sizeBytes)

totalSizeBytes.addAndGet(sizeBytes)
totalModelCacheSizeBytes.addAndGet(sizeBytes)

cachedModel.put(objectId, CacheItem(obj, sizeBytes))
objectId
}
Expand All @@ -108,6 +137,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
*/
def remove(refId: String): Boolean = {
val removed = cachedModel.remove(refId)
totalModelCacheSizeBytes.addAndGet(-removed.sizeBytes)
// remove returns null if the key is not present
removed != null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,15 @@ private[spark] case class MLCacheInvalidException(objectName: String)
errorClass = "CONNECT_ML.CACHE_INVALID",
messageParameters = Map("objectName" -> objectName),
cause = null)

private[spark] case class MLModelSizeOverflowException(message: String)
extends SparkException(
errorClass = "CONNECT_ML.MODEL_SIZE_OVERFLOW_EXCEPTION",
messageParameters = Map("message" -> message),
cause = null)

private[spark] case class MLModelCacheSizeOverflowException(message: String)
extends SparkException(
errorClass = "CONNECT_ML.MODEL_CACHE_SIZE_OVERFLOW_EXCEPTION",
messageParameters = Map("message" -> message),
cause = null)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.tree.TreeConfig
import org.apache.spark.ml.util.{MLWritable, Summary}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
Expand Down Expand Up @@ -116,6 +117,14 @@ private[connect] object MLHandler extends Logging {
mlCommand: proto.MlCommand): proto.MlCommandResult = {

val mlCache = sessionHolder.mlCache
val maxModelSize = sessionHolder.mlCache.getModelCacheMaxSize

// Note: Tree training stops early when the growing tree model exceeds
// `TreeConfig.trainingEarlyStopModelSizeThresholdInBytes`, to ensure the final
// model size is lower than `maxModelSize`, set early-stop threshold to
// half of `maxModelSize`, because in each tree training iteration, the tree
// size will grow up to 2 times size.
TreeConfig.trainingEarlyStopModelSizeThresholdInBytes = maxModelSize / 2

mlCommand.getCommandCase match {
case proto.MlCommand.CommandCase.FIT =>
Expand All @@ -126,6 +135,14 @@ private[connect] object MLHandler extends Logging {
val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
val estimator =
MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))

try {
val estimatedModelSize = estimator.estimateModelSize(dataset)
mlCache.checkModelSize(estimatedModelSize)
} catch {
case _: UnsupportedOperationException => ()
}

val model = estimator.fit(dataset).asInstanceOf[Model[_]]
val id = mlCache.register(model)
proto.MlCommandResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,21 @@ class MLSuite extends MLHelper {
assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
}
}

test("Model size limit") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
sessionHolder.session.conf
.set(Connect.CONNECT_SESSION_CONNECT_MODEL_MAX_SIZE.key, "4000")
intercept[MLModelSizeOverflowException] {
trainLogisticRegressionModel(sessionHolder)
}
sessionHolder.session.conf
.set(Connect.CONNECT_SESSION_CONNECT_MODEL_MAX_SIZE.key, "8000")
sessionHolder.session.conf
.set(Connect.CONNECT_SESSION_CONNECT_MODEL_CACHE_MAX_SIZE.key, "10000")
trainLogisticRegressionModel(sessionHolder)
intercept[MLModelCacheSizeOverflowException] {
trainLogisticRegressionModel(sessionHolder)
}
}
}