diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index a6d6fe6b726b3..78f29dde6b8c8 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -811,6 +811,21 @@
         "Cannot retrieve from the ML cache. It is probably because the entry has been evicted."
       ]
     },
+    "ML_CACHE_SIZE_OVERFLOW_EXCEPTION" : {
+      "message" : [
+        "The model cache size in the current session is about to exceed",
+        "<mlCacheMaxSize> bytes.",
+        "Please delete existing cached models by executing 'del model' in the Python client before fitting or loading a new model,",
+        "or increase the Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxSize'."
+      ]
+    },
+    "MODEL_SIZE_OVERFLOW_EXCEPTION" : {
+      "message" : [
+        "The fitted or loaded model size is about <modelSize> bytes.",
+        "Please fit or load a model smaller than <modelMaxSize> bytes,",
+        "or increase the Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxModelSize'."
+      ]
+    },
     "UNSUPPORTED_EXCEPTION" : {
       "message" : [
         ""
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 534c54c567589..e400bf13eb8dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -195,6 +195,8 @@ class DecisionTreeClassificationModel private[ml] (
   // For ml connect only
   private[ml] def this() = this("", Node.dummyNode, -1, -1)
 
+  override def estimatedSize: Long = getEstimatedSize()
+
   override def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index 222cfbb80c3d1..c815174b8c42f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -224,15 +224,18 @@ class FMClassifier @Since("3.0.0") (
     val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
     val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
 
-    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    val summary = new FMClassificationTrainingSummaryImpl(
-      summaryModel.transform(dataset),
-      probabilityColName,
-      predictionColName,
-      $(labelCol),
-      weightColName,
-      objectiveHistory)
-    model.setSummary(Some(summary))
+    if (SummaryUtils.enableTrainingSummary) {
+      val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+      val summary = new FMClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        objectiveHistory)
+      model.setSummary(Some(summary))
+    }
+    model
   }
 
   @Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 9f2c2c85115ba..9ca3a16609586 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -276,6 +276,8 @@ class GBTClassificationModel private[ml](
   private[ml] def this() =
this("", Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1, -1) + override def estimatedSize: Long = getEstimatedSize() + @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index c5d1170318f79..6a3658537c3c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -277,15 +277,18 @@ class LinearSVC @Since("2.2.0") ( val model = copyValues(new LinearSVCModel(uid, coefficients, intercept)) val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() - val summary = new LinearSVCTrainingSummaryImpl( - summaryModel.transform(dataset), - rawPredictionColName, - predictionColName, - $(labelCol), - weightColName, - objectiveHistory) - model.setSummary(Some(summary)) + if (SummaryUtils.enableTrainingSummary) { + val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() + val summary = new LinearSVCTrainingSummaryImpl( + summaryModel.transform(dataset), + rawPredictionColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + model.setSummary(Some(summary)) + } + model } private def trainImpl( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d09cacf3fb5b5..fd9bc6f0131fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -711,27 +711,30 @@ class LogisticRegression @Since("1.2.0") ( numClasses, checkMultinomial(numClasses))) val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() - val logRegSummary = if (numClasses <= 2) { - new BinaryLogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } else { - new LogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) + if (SummaryUtils.enableTrainingSummary) { + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val logRegSummary = if (numClasses <= 2) { + new BinaryLogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } else { + new LogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } + model.setSummary(Some(logRegSummary)) } - model.setSummary(Some(logRegSummary)) + model } private def createBounds( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 2359749f8b48b..5f163c43b4027 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -249,14 +249,17 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = { val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights)) - val (summaryModel, _, predictionColName) = model.findSummaryModel() - val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - "", - objectiveHistory) - model.setSummary(Some(summary)) + if (SummaryUtils.enableTrainingSummary) { + val (summaryModel, _, predictionColName) = model.findSummaryModel() + val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + "", + objectiveHistory) + model.setSummary(Some(summary)) + } + model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 24dd7095513ac..391d61a9793f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -185,23 +185,26 @@ class RandomForestClassifier @Since("1.4.0") ( val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() - val rfSummary = if (numClasses <= 2) { - new BinaryRandomForestClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - weightColName, - Array(0.0)) - } else { - new RandomForestClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - weightColName, - Array(0.0)) + if (SummaryUtils.enableTrainingSummary) { + val rfSummary = if (numClasses <= 2) { + new BinaryRandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + } else { + new RandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + } + model.setSummary(Some(rfSummary)) } - model.setSummary(Some(rfSummary)) + model } @Since("1.4.1") @@ -258,6 +261,8 @@ class RandomForestClassificationModel private[ml] ( // For ml connect only private[ml] def this() = this("", Array(new DecisionTreeClassificationModel), -1, -1) + override def estimatedSize: Long = getEstimatedSize() + @Since("1.4.0") override def trees: Array[DecisionTreeClassificationModel] = _trees diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f1cd126a6406d..d63a0663690e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -303,16 +303,19 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.runWithWeight(instances, handlePersistence, Some(instr)) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) - val summary = new 
BisectingKMeansSummary( - model.transform(dataset), - $(predictionCol), - $(featuresCol), - $(k), - $(maxIter), - parentModel.trainingCost) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logNumFeatures(model.clusterCenters.head.size) - model.setSummary(Some(summary)) + if (SummaryUtils.enableTrainingSummary) { + val summary = new BisectingKMeansSummary( + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + $(maxIter), + parentModel.trainingCost) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + instr.logNumFeatures(model.clusterCenters.head.size) + model.setSummary(Some(summary)) + } + model } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ee0b19f8129dc..e937313af95d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -430,11 +430,14 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)) .setParent(this) - val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) - instr.logNamedValue("logLikelihood", logLikelihood) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - model.setSummary(Some(summary)) + if (SummaryUtils.enableTrainingSummary) { + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) + instr.logNamedValue("logLikelihood", logLikelihood) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + model.setSummary(Some(summary)) + } + model } private def trainImpl( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e87dc9eb040b6..9dcb43a36120c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -391,16 +391,18 @@ class KMeans @Since("1.5.0") ( } val model = copyValues(new KMeansModel(uid, oldModel).setParent(this)) - val summary = new KMeansSummary( - model.transform(dataset), - $(predictionCol), - $(featuresCol), - $(k), - oldModel.numIter, - oldModel.trainingCost) - - model.setSummary(Some(summary)) - instr.logNamedValue("clusterSizes", summary.clusterSizes) + if (SummaryUtils.enableTrainingSummary) { + val summary = new KMeansSummary( + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + oldModel.numIter, + oldModel.trainingCost) + + model.setSummary(Some(summary)) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + } model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 50de0c54b8c3f..f53d26882d3f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -190,6 +190,8 @@ class DecisionTreeRegressionModel private[ml] ( // For ml connect only private[ml] def this() = this("", Node.dummyNode, -1) + override def estimatedSize: Long = getEstimatedSize() + override def predict(features: Vector): Double = { 
rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f71eea6c62933..163921c703903 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -245,6 +245,8 @@ class GBTRegressionModel private[ml]( // For ml connect only private[ml] def this() = this("", Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1) + override def estimatedSize: Long = getEstimatedSize() + @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 777b70e7d0213..d1def59523960 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -418,9 +418,12 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, - wlsModel.diagInvAtWA.toArray, 1, getSolver) - model.setSummary(Some(trainingSummary)) + if (SummaryUtils.enableTrainingSummary) { + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + wlsModel.diagInvAtWA.toArray, 1, getSolver) + model.setSummary(Some(trainingSummary)) + } + model } else { val instances = validated.rdd.map { case Row(label: Double, weight: Double, offset: Double, features: Vector) => @@ -435,9 +438,12 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) - val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, - irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(Some(trainingSummary)) + if (SummaryUtils.enableTrainingSummary) { + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) + model.setSummary(Some(trainingSummary)) + } + model } model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 919cec4e16e4a..29fa3d5e123be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -431,13 +431,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - model, Array(0.0), objectiveHistory) - model.setSummary(Some(trainingSummary)) + if (SummaryUtils.enableTrainingSummary) { + // Handle possible 
missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + model, Array(0.0), objectiveHistory) + model.setSummary(Some(trainingSummary)) + } + model } private def trainWithNormal( diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 97d0f54d0eca4..04e36c7f27ff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -215,6 +215,8 @@ class RandomForestRegressionModel private[ml] ( // For ml connect only private[ml] def this() = this("", Array(new DecisionTreeRegressionModel), -1) + override def estimatedSize: Long = getEstimatedSize() + @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 5184109bd3a52..18273abbf0c7c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -273,6 +273,10 @@ private[spark] object GradientBoostedTrees extends Logging { errSum.map(_ / weightSum) } + // This member is only for testing code. + private[spark] var lastEarlyStoppedModelSize: Long = 0 + private[spark] val modelSizeHistory = new scala.collection.mutable.ArrayBuffer[Long]() + /** * Internal method for performing regression using trees as base learners. 
* @param input training dataset @@ -292,6 +296,8 @@ private[spark] object GradientBoostedTrees extends Logging { featureSubsetStrategy: String, instr: Option[Instrumentation] = None): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes + lastEarlyStoppedModelSize = 0 val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -365,7 +371,8 @@ private[spark] object GradientBoostedTrees extends Logging { val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged, metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1, featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr, - parentUID = None) + parentUID = None, + earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes) .head.asInstanceOf[DecisionTreeRegressionModel] firstCounts.unpersist() @@ -400,11 +407,22 @@ private[spark] object GradientBoostedTrees extends Logging { timer.stop("init validation") } - var bestM = 1 + var accTreeSize = firstTreeModel.estimatedSize + + var validM = 1 var m = 1 - var doneLearning = false - while (m < numIterations && !doneLearning) { + var earlyStop = false + modelSizeHistory.clear() + modelSizeHistory.append(accTreeSize) + if ( + earlyStopModelSizeThresholdInBytes > 0 + && accTreeSize > earlyStopModelSizeThresholdInBytes + ) { + lastEarlyStoppedModelSize = accTreeSize + earlyStop = true + } + while (m < numIterations && !earlyStop) { timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) @@ -434,7 +452,8 @@ private[spark] object GradientBoostedTrees extends Logging { val model = RandomForest.runBagged(baggedInput = bagged, metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1, featureSubsetStrategy = featureSubsetStrategy, - seed = seed + m, instr = None, parentUID = None) + seed = seed + m, instr = None, parentUID = None, + earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes - accTreeSize) .head.asInstanceOf[DecisionTreeRegressionModel] labelWithCounts.unpersist() @@ -442,6 +461,7 @@ private[spark] object GradientBoostedTrees extends Logging { timer.stop(s"building tree $m") // Update partial model baseLearners(m) = model + accTreeSize += model.estimatedSize // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. 
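[Editor's note] The size-based early stop threaded through `GradientBoostedTrees.boost` in the hunks above and below reduces to one pattern: keep a running total of each fitted tree's `estimatedSize`, and stop boosting once that total crosses `earlyStopModelSizeThresholdInBytes` (a threshold of 0, the default of `TreeConfig.trainingEarlyStopModelSizeThresholdInBytes`, disables the check). The following is a minimal, self-contained sketch of that accumulation pattern for readers who want to see it in isolation. It is not the PR's code: `FakeTree` and its size formula are hypothetical stand-ins for `DecisionTreeRegressionModel` and its `SizeEstimator`-based estimate.

```scala
import scala.collection.mutable.ArrayBuffer

object BoostingSizeEarlyStopSketch {
  // Stand-in for DecisionTreeRegressionModel; `estimatedSize` mimics the
  // SizeEstimator-based estimate that the real tree models expose.
  final case class FakeTree(numNodes: Int) {
    def estimatedSize: Long = 64L * numNodes
  }

  def boost(numIterations: Int, earlyStopModelSizeThresholdInBytes: Long): Seq[FakeTree] = {
    val learners = ArrayBuffer.empty[FakeTree]
    var accTreeSize = 0L
    var earlyStop = false
    var m = 0
    while (m < numIterations && !earlyStop) {
      // Each boosting round fits one more tree; later rounds tend to add bigger trees.
      val tree = FakeTree(numNodes = 32 * (m + 1))
      learners += tree
      accTreeSize += tree.estimatedSize
      // A threshold of 0 keeps the old behaviour: never stop for size reasons.
      if (earlyStopModelSizeThresholdInBytes > 0 &&
          accTreeSize > earlyStopModelSizeThresholdInBytes) {
        println(s"Stop at iteration $m: accumulated tree size $accTreeSize bytes " +
          s"exceeds the threshold of $earlyStopModelSizeThresholdInBytes bytes.")
        earlyStop = true
      }
      m += 1
    }
    learners.toSeq
  }

  def main(args: Array[String]): Unit = {
    val kept = boost(numIterations = 100, earlyStopModelSizeThresholdInBytes = 50000L)
    println(s"Kept ${kept.length} trees.")
  }
}
```

As in the diff, the tree that pushes the total over the threshold is still kept; the loop simply stops fitting further trees.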
@@ -466,10 +486,21 @@ private[spark] object GradientBoostedTrees extends Logging { val currentValidateError = computeWeightedError(validationTreePoints, validatePredError) if (bestValidateError - currentValidateError < validationTol * Math.max( currentValidateError, 0.01)) { - doneLearning = true + earlyStop = true } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError - bestM = m + 1 + validM = m + 1 + } + } + if (!earlyStop) { + modelSizeHistory.append(accTreeSize) + if ( + earlyStopModelSizeThresholdInBytes > 0 + && accTreeSize > earlyStopModelSizeThresholdInBytes + ) { + earlyStop = true + validM = m + 1 + lastEarlyStoppedModelSize = accTreeSize } } m += 1 @@ -490,8 +521,16 @@ private[spark] object GradientBoostedTrees extends Logging { validatePredErrorCheckpointer.deleteAllCheckpoints() } - if (validate) { - (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM)) + if (earlyStop) { + // Early stop occurs in one of the 2 cases: + // - validation error increases + // - the accumulated size of trees exceeds the value of `earlyStopModelSizeThresholdInBytes` + if (accTreeSize > earlyStopModelSizeThresholdInBytes) { + logWarning( + "The boosting tree training stops early because the model size exceeds threshold." + ) + } + (baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM)) } else { (baseLearners, baseLearnerWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 452532df5a2b6..79d0964b7eae9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.SizeEstimator import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -101,6 +102,9 @@ private[spark] object RandomForest extends Logging with Serializable { run(instances, strategy, numTrees, featureSubsetStrategy, seed, None) } + // This member is only for testing code. + private[spark] var lastEarlyStoppedModelSize: Long = 0 + /** * Train a random forest with metadata and splits. This method is mainly for GBT, * in which bagged input can be reused among trees. @@ -109,6 +113,7 @@ private[spark] object RandomForest extends Logging with Serializable { * @param metadata Learning and dataset metadata for DecisionTree. 
* @return an unweighted set of trees */ + // scalastyle:off def runBagged( baggedInput: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, @@ -119,7 +124,9 @@ private[spark] object RandomForest extends Logging with Serializable { seed: Long, instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned - parentUID: Option[String] = None): Array[DecisionTreeModel] = { + parentUID: Option[String] = None, + earlyStopModelSizeThresholdInBytes: Long = 0): Array[DecisionTreeModel] = { + lastEarlyStoppedModelSize = 0 val timer = new TimeTracker() timer.start("total") @@ -190,7 +197,8 @@ private[spark] object RandomForest extends Logging with Serializable { timer.stop("init") - while (nodeStack.nonEmpty) { + var earlyStop = false + while (nodeStack.nonEmpty && !earlyStop) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = @@ -214,8 +222,19 @@ private[spark] object RandomForest extends Logging with Serializable { } timer.stop("findBestSplits") - } + if (earlyStopModelSizeThresholdInBytes > 0) { + val nodes = topNodes.map(_.toNode(prune)) + val estimatedSize = SizeEstimator.estimate(nodes) + if (estimatedSize > earlyStopModelSizeThresholdInBytes){ + earlyStop = true + logWarning( + "The random forest training stops early because the model size exceeds threshold." + ) + lastEarlyStoppedModelSize = estimatedSize + } + } + } timer.stop("total") logInfo("Internal timing for DecisionTree:") @@ -253,6 +272,7 @@ private[spark] object RandomForest extends Logging with Serializable { } } } + // scalastyle:on /** * Train a random forest. @@ -269,6 +289,7 @@ private[spark] object RandomForest extends Logging with Serializable { instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { + val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes val timer = new TimeTracker() timer.start("build metadata") @@ -301,7 +322,8 @@ private[spark] object RandomForest extends Logging with Serializable { val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits, strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy, - seed = seed, instr = instr, prune = prune, parentUID = parentUID) + seed = seed, instr = instr, prune = prune, parentUID = parentUID, + earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes) baggedInput.unpersist() bcSplits.destroy() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 73a54405c752d..ea9a0de165638 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -116,6 +116,10 @@ private[spark] trait DecisionTreeModel { def predictLeaf(features: Vector): Double = { leafIndices(rootNode.predictImpl(features)).toDouble } + + def getEstimatedSize(): Long = { + org.apache.spark.util.SizeEstimator.estimate(rootNode) + } } /** @@ -168,6 +172,10 @@ private[spark] trait TreeEnsembleModel[M <: DecisionTreeModel] { private[ml] def getLeafField(leafCol: String) = { new AttributeGroup(leafCol, attrs = trees.map(_.leafAttr)).toStructField() } + + def getEstimatedSize(): Long = 
{ + org.apache.spark.util.SizeEstimator.estimate(trees.map(_.rootNode)) + } } private[ml] object TreeEnsembleModel { @@ -583,3 +591,7 @@ private[ml] object EnsembleModelReadWrite { } } } + +private[spark] object TreeConfig { + private[spark] var trainingEarlyStopModelSizeThresholdInBytes: Long = 0L +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala index 0ba8ce072ab4a..a6d09c2e9142e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -50,3 +50,9 @@ private[spark] trait HasTrainingSummary[T] { this } } + +private[spark] object SummaryUtils { + + // This flag is only used by Spark Connect + private[spark] var enableTrainingSummary: Boolean = true +} diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py index 9db783666d221..7e098b9b909e4 100644 --- a/python/pyspark/ml/tests/connect/test_parity_tuning.py +++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py @@ -29,9 +29,9 @@ class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConn @classmethod def conf(cls): conf = super().conf() - conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "true") - conf.set("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize", "1024") - conf.set("spark.connect.session.connectML.mlCache.offloading.timeout", "1") + conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "true") + conf.set("spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize", "1024") + conf.set("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout", "1") return conf diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index ce727e78605e6..667f2b9ada060 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -160,7 +160,7 @@ def conf(cls): conf.set("spark.connect.authenticate.token", "deadbeef") # Disable ml cache offloading, # offloading hasn't supported APIs like model.summary / model.evaluate - conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "false") + conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "false") return conf @classmethod diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 5b2daae1c0e47..33feca5b0c940 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -334,36 +334,51 @@ object Connect { } } - val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE = - buildConf("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize") + val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED = + buildConf("spark.connect.session.connectML.mlCache.memoryControl.enabled") + .doc("Enables ML cache memory control, it includes offloading model to disk, " + + "limiting model size, and limiting per-session ML cache size.") + .version("4.1.0") + .internal() + .booleanConf + .createWithDefault(true) + + val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_IN_MEMORY_SIZE = + 
buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize") .doc( - "In-memory maximum size of the MLCache per session. The cache will offload the least " + + "Maximum in-memory size of the MLCache per session. The cache will offload the least " + "recently used models to Spark driver local disk if the size exceeds this limit. " + - "The size is in bytes. This configuration only works when " + - "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.") + "The size is in bytes.") .version("4.1.0") .internal() .bytesConf(ByteUnit.BYTE) // By default, 1/3 of total designated memory (the configured -Xmx). .createWithDefault(Runtime.getRuntime.maxMemory() / 3) - val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT = - buildConf("spark.connect.session.connectML.mlCache.offloading.timeout") + val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_OFFLOADING_TIMEOUT = + buildConf("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout") .doc( "Timeout of model offloading in MLCache. Models will be offloaded to Spark driver local " + - "disk if they are not used for this amount of time. The timeout is in minutes. " + - "This configuration only works when " + - "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.") + "disk if they are not used for this amount of time. The timeout is in minutes.") .version("4.1.0") .internal() .timeConf(TimeUnit.MINUTES) .createWithDefault(5) - val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED = - buildConf("spark.connect.session.connectML.mlCache.offloading.enabled") - .doc("Enables ML cache offloading.") + val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE = + buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxModelSize") + .doc("Maximum size of a single SparkML model. The size is in bytes.") .version("4.1.0") .internal() - .booleanConf - .createWithDefault(true) + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1g") + + val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE = + buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxSize") + .doc("Maximum total size (including in-memory and offloaded data) of the ml cache. " + + "The size is in bytes.") + .version("4.1.0") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10g") } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index 2379284b62b02..d59525a6190c3 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -41,9 +41,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { private val helperID = "______ML_CONNECT_HELPER______" private val modelClassNameFile = "__model_class_name__" - // TODO: rename it to `totalInMemorySizeBytes` because it only counts the in-memory - // part data size. 
- private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0) + private[ml] val totalMLCacheInMemorySizeBytes: AtomicLong = new AtomicLong(0) val offloadedModelsDir: Path = { val path = Paths.get( @@ -52,27 +50,29 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { sessionHolder.sessionId) Files.createDirectories(path) } - private[spark] def getOffloadingEnabled: Boolean = { - sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED) + private[spark] def getMemoryControlEnabled: Boolean = { + sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED) } private def getMaxInMemoryCacheSizeKB: Long = { sessionHolder.session.conf.get( - Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE) / 1024 + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_IN_MEMORY_SIZE) / 1024 } private def getOffloadingTimeoutMinute: Long = { - sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT) + sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_OFFLOADING_TIMEOUT) } private[ml] case class CacheItem(obj: Object, sizeBytes: Long) private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = { - if (getOffloadingEnabled) { + if (getMemoryControlEnabled) { CacheBuilder .newBuilder() .softValues() .removalListener((removed: RemovalNotification[String, CacheItem]) => - totalSizeBytes.addAndGet(-removed.getValue.sizeBytes)) + totalMLCacheInMemorySizeBytes.addAndGet(-removed.getValue.sizeBytes)) .maximumWeight(getMaxInMemoryCacheSizeKB) .weigher((key: String, value: CacheItem) => { Math.ceil(value.sizeBytes.toDouble / 1024).toInt @@ -89,6 +89,25 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { new ConcurrentHashMap[String, Summary]() } + private[ml] val totalMLCacheSizeBytes: AtomicLong = new AtomicLong(0) + private[spark] def getMLCacheMaxSize: Long = { + sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE) + } + private[spark] def getModelMaxSize: Long = { + sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE) + } + + def checkModelSize(estimatedModelSize: Long): Unit = { + if (totalMLCacheSizeBytes.get() + estimatedModelSize > getMLCacheMaxSize) { + throw MLCacheSizeOverflowException(getMLCacheMaxSize) + } + if (estimatedModelSize > getModelMaxSize) { + throw MLModelSizeOverflowException(estimatedModelSize, getModelMaxSize) + } + } + private def estimateObjectSize(obj: Object): Long = { obj match { case model: Model[_] => @@ -100,12 +119,12 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } private[spark] def checkSummaryAvail(): Unit = { - if (getOffloadingEnabled) { + if (getMemoryControlEnabled) { throw MlUnsupportedException( "SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " + "Spark Connect session ML cache offloading is enabled. 
You can use APIs in " + "'pyspark.ml.evaluation' instead, or you can set Spark config " + - "'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " + + "'spark.connect.session.connectML.mlCache.memoryControl.enabled' to 'false' to " + "disable Spark Connect session ML cache offloading.") } } @@ -124,18 +143,21 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { checkSummaryAvail() cachedSummary.put(objectId, obj.asInstanceOf[Summary]) } else if (obj.isInstanceOf[Model[_]]) { - val sizeBytes = if (getOffloadingEnabled) { - estimateObjectSize(obj) + val sizeBytes = if (getMemoryControlEnabled) { + val _sizeBytes = estimateObjectSize(obj) + checkModelSize(_sizeBytes) + _sizeBytes } else { - 0L // Don't need to calculate size if disables offloading. + 0L // Don't need to calculate size if disables memory-control. } cachedModel.put(objectId, CacheItem(obj, sizeBytes)) - if (getOffloadingEnabled) { + if (getMemoryControlEnabled) { val savePath = offloadedModelsDir.resolve(objectId) obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString) Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName) + totalMLCacheInMemorySizeBytes.addAndGet(sizeBytes) + totalMLCacheSizeBytes.addAndGet(sizeBytes) } - totalSizeBytes.addAndGet(sizeBytes) } else { throw new RuntimeException("'MLCache.register' only accepts model or summary objects.") } @@ -155,7 +177,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } else { var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId)) - if (obj == null && getOffloadingEnabled) { + if (obj == null && getMemoryControlEnabled) { val loadPath = offloadedModelsDir.resolve(refId) if (Files.isDirectory(loadPath)) { val className = Files.readString(loadPath.resolve(modelClassNameFile)) @@ -166,7 +188,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { loadFromLocal = true) val sizeBytes = estimateObjectSize(obj) cachedModel.put(refId, CacheItem(obj, sizeBytes)) - totalSizeBytes.addAndGet(sizeBytes) + totalMLCacheInMemorySizeBytes.addAndGet(sizeBytes) } } obj @@ -176,7 +198,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { def _removeModel(refId: String): Boolean = { val removedModel = cachedModel.remove(refId) val removedFromMem = removedModel != null - val removedFromDisk = if (getOffloadingEnabled) { + val removedFromDisk = if (getMemoryControlEnabled) { + totalMLCacheSizeBytes.addAndGet(-removedModel.sizeBytes) val offloadingPath = new File(offloadedModelsDir.resolve(refId).toString) if (offloadingPath.exists()) { FileUtils.deleteDirectory(offloadingPath) @@ -213,7 +236,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { val size = cachedModel.size() cachedModel.clear() cachedSummary.clear() - if (getOffloadingEnabled) { + if (getMemoryControlEnabled) { FileUtils.cleanDirectory(new File(offloadedModelsDir.toString)) } size diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala index d1a7f232edf7c..4889e74d1e402 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala @@ -36,3 +36,16 @@ private[spark] case class MLCacheInvalidException(objectName: String) errorClass = 
"CONNECT_ML.CACHE_INVALID", messageParameters = Map("objectName" -> objectName), cause = null) + +private[spark] case class MLModelSizeOverflowException(modelSize: Long, modelMaxSize: Long) + extends SparkException( + errorClass = "CONNECT_ML.MODEL_SIZE_OVERFLOW_EXCEPTION", + messageParameters = + Map("modelSize" -> modelSize.toString, "modelMaxSize" -> modelMaxSize.toString), + cause = null) + +private[spark] case class MLCacheSizeOverflowException(mlCacheMaxSize: Long) + extends SparkException( + errorClass = "CONNECT_ML.ML_CACHE_SIZE_OVERFLOW_EXCEPTION", + messageParameters = Map("mlCacheMaxSize" -> mlCacheMaxSize.toString), + cause = null) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index 44595f3418318..3afc15c97a66d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -24,10 +24,10 @@ import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.ml.Model import org.apache.spark.ml.param.{ParamMap, Params} -import org.apache.spark.ml.util.{MLWritable, Summary} +import org.apache.spark.ml.tree.TreeConfig +import org.apache.spark.ml.util.{MLWritable, Summary, SummaryUtils} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.common.LiteralValueProtoConverter -import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments import org.apache.spark.sql.connect.service.SessionHolder @@ -120,11 +120,27 @@ private[connect] object MLHandler extends Logging { mlCommand: proto.MlCommand): proto.MlCommandResult = { val mlCache = sessionHolder.mlCache + val memoryControlEnabled = sessionHolder.mlCache.getMemoryControlEnabled + + // Disable model training summary when memory control is enabled + // because training summary can't support + // size estimation and offloading. + SummaryUtils.enableTrainingSummary = !memoryControlEnabled + + if (memoryControlEnabled) { + val maxModelSize = sessionHolder.mlCache.getModelMaxSize + + // Note: Tree training stops early when the growing tree model exceeds + // `TreeConfig.trainingEarlyStopModelSizeThresholdInBytes`, to ensure the final + // model size is lower than `maxModelSize`, set early-stop threshold to + // half of `maxModelSize`, because in each tree training iteration, the tree + // nodes will grow up to 2 times, the additional 0.5 is for buffer + // because the in-memory size is not exactly in direct proportion to the tree nodes. 
+ TreeConfig.trainingEarlyStopModelSizeThresholdInBytes = (maxModelSize.toDouble / 2.5).toLong + } mlCommand.getCommandCase match { case proto.MlCommand.CommandCase.FIT => - val offloadingEnabled = sessionHolder.session.conf.get( - Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED) val fitCmd = mlCommand.getFit val estimatorProto = fitCmd.getEstimator assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) @@ -132,7 +148,14 @@ private[connect] object MLHandler extends Logging { val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder) val estimator = MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams)) - if (offloadingEnabled) { + + if (memoryControlEnabled) { + try { + val estimatedModelSize = estimator.estimateModelSize(dataset) + mlCache.checkModelSize(estimatedModelSize) + } catch { + case _: UnsupportedOperationException => () + } if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") { throw MlUnsupportedException( "FPGrowth algorithm is not supported " + diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala index 5939b673501b5..802dc46c23a00 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala @@ -93,6 +93,35 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest { createLocalRelationProto(schema, inputRows) } + def createRelationProtoForTreeModel(numFeatures: Int, numInstances: Int): proto.Relation = { + val udt = new VectorUDT() + val rows = new Array[InternalRow](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0f + } else if (i < numInstances / 2) { + 1.0f + } else if (i < numInstances * 0.9) { + 0.0f + } else { + 1.0f + } + val features = Array.tabulate[Double](numFeatures)(index => + scala.math.cos(i.toDouble * scala.math.pow(2, index))) + rows(i) = InternalRow(label, udt.serialize(Vectors.dense(features))) + } + val schema = StructType( + Seq( + StructField("label", FloatType), + StructField("features", new VectorUDT(), false, Metadata.empty))) + + val inputRows = rows.toSeq.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + createLocalRelationProto(DataTypeUtils.toAttributes(schema), inputRows, "UTC", Some(schema)) + } + def getLogisticRegression: proto.MlOperator.Builder = proto.MlOperator .newBuilder() @@ -100,11 +129,37 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest { .setUid("LogisticRegression") .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) + def getDecisionTreeClassifier: proto.MlOperator.Builder = + proto.MlOperator + .newBuilder() + .setName("org.apache.spark.ml.classification.DecisionTreeClassifier") + .setUid("DecisionTreeClassifier") + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) + + def getRandomForestClassifier: proto.MlOperator.Builder = + proto.MlOperator + .newBuilder() + .setName("org.apache.spark.ml.classification.RandomForestClassifier") + .setUid("RandomForestClassifier") + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) + + def getGBTClassifier: proto.MlOperator.Builder = + proto.MlOperator + .newBuilder() + .setName("org.apache.spark.ml.classification.GBTClassifier") + .setUid("GBTClassifier") + 
.setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) + def getMaxIter: proto.MlParams.Builder = proto.MlParams .newBuilder() .putParams("maxIter", proto.Expression.Literal.newBuilder().setInteger(2).build()) + def getMaxDepth(maxDepth: Int): proto.MlParams.Builder = + proto.MlParams + .newBuilder() + .putParams("maxDepth", proto.Expression.Literal.newBuilder().setInteger(maxDepth).build()) + def getRegressorEvaluator: proto.MlOperator.Builder = proto.MlOperator .newBuilder() diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala index 36b198b0db31a..fcb721b51e64d 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala @@ -136,7 +136,7 @@ class MLSuite extends MLHelper { test("LogisticRegression works") { val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) sessionHolder.session.conf - .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key, "false") + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "false") // estimator read/write val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter) @@ -262,7 +262,7 @@ class MLSuite extends MLHelper { test("Exception: cannot retrieve object") { val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) sessionHolder.session.conf - .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key, "false") + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "false") val modelId = trainLogisticRegressionModel(sessionHolder) // Fetch summary attribute @@ -385,28 +385,63 @@ class MLSuite extends MLHelper { .toArray sameElements Array("a", "b", "c")) } + test("tree model training early stop for limiting model size") { + val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) + sessionHolder.session.conf + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "true") + + for (estimator <- Seq(getDecisionTreeClassifier, getRandomForestClassifier)) { + for (maxModelSize <- Seq(20000, 50000)) { + sessionHolder.session.conf.set( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE.key, + maxModelSize.toString) + trainTreeModel(sessionHolder, estimator) + val lastModelSize = org.apache.spark.ml.tree.impl.RandomForest.lastEarlyStoppedModelSize + assert(lastModelSize < maxModelSize) + assert(lastModelSize >= maxModelSize.toDouble / 2.5) + } + } + } + + test("GBT tree model training early stop for limiting model size") { + val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) + sessionHolder.session.conf + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "true") + + for (maxModelSize <- Seq(20000, 50000, 130000)) { + sessionHolder.session.conf.set( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE.key, + maxModelSize.toString) + trainTreeModel(sessionHolder, getGBTClassifier) + val lastModelSize = + org.apache.spark.ml.tree.impl.GradientBoostedTrees.lastEarlyStoppedModelSize + assert(lastModelSize < maxModelSize) + assert(lastModelSize >= maxModelSize.toDouble / 2.5) + } + } + test("MLCache offloading works") { val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) sessionHolder.session.conf - .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key, 
"true") + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "true") val memorySizeBytes = 1024 * 16 sessionHolder.session.conf.set( - Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key, + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_IN_MEMORY_SIZE.key, memorySizeBytes) val modelIdList = scala.collection.mutable.ListBuffer[String]() modelIdList.append(trainLogisticRegressionModel(sessionHolder)) assert(sessionHolder.mlCache.cachedModel.size() == 1) - assert(sessionHolder.mlCache.totalSizeBytes.get() > 0) - val modelSizeBytes = sessionHolder.mlCache.totalSizeBytes.get() + assert(sessionHolder.mlCache.totalMLCacheInMemorySizeBytes.get() > 0) + val modelSizeBytes = sessionHolder.mlCache.totalMLCacheInMemorySizeBytes.get() val maxNumModels = memorySizeBytes / modelSizeBytes.toInt // All models will be kept if the total size is less than the memory limit. for (i <- 1 until maxNumModels) { modelIdList.append(trainLogisticRegressionModel(sessionHolder)) assert(sessionHolder.mlCache.cachedModel.size() == i + 1) - assert(sessionHolder.mlCache.totalSizeBytes.get() > 0) - assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes) + assert(sessionHolder.mlCache.totalMLCacheInMemorySizeBytes.get() > 0) + assert(sessionHolder.mlCache.totalMLCacheInMemorySizeBytes.get() <= memorySizeBytes) } // Old models will be offloaded @@ -414,8 +449,8 @@ class MLSuite extends MLHelper { for (_ <- 0 until 3) { modelIdList.append(trainLogisticRegressionModel(sessionHolder)) assert(sessionHolder.mlCache.cachedModel.size() == maxNumModels) - assert(sessionHolder.mlCache.totalSizeBytes.get() > 0) - assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes) + assert(sessionHolder.mlCache.totalMLCacheInMemorySizeBytes.get() > 0) + assert(sessionHolder.mlCache.totalMLCacheInMemorySizeBytes.get() <= memorySizeBytes) } // Assert all models can be loaded back from disk after they are offloaded. @@ -423,4 +458,37 @@ class MLSuite extends MLHelper { assert(sessionHolder.mlCache.get(modelId) != null) } } + + test("Model size limit") { + val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) + sessionHolder.session.conf + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE.key, "4000") + intercept[MLModelSizeOverflowException] { + trainLogisticRegressionModel(sessionHolder) + } + sessionHolder.session.conf + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE.key, "8000") + sessionHolder.session.conf + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE.key, "10000") + trainLogisticRegressionModel(sessionHolder) + intercept[MLCacheSizeOverflowException] { + trainLogisticRegressionModel(sessionHolder) + } + } + + def trainTreeModel( + sessionHolder: SessionHolder, + estimator: proto.MlOperator.Builder): String = { + val fitCommand = proto.MlCommand + .newBuilder() + .setFit( + proto.MlCommand.Fit + .newBuilder() + .setDataset(createRelationProtoForTreeModel(128, 10000)) + .setEstimator(estimator) + .setParams(getMaxDepth(30))) + .build() + val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand) + fitResult.getOperatorInfo.getObjRef.getId + } }