Skip to content

Commit 43f78ae

Browse files
WeichenXu123 and yhuang-db
authored and committed
[SPARK-51974][CONNECT][ML] Limit model size and per-session model cache size
### What changes were proposed in this pull request? Limit model size and per-session model cache size. Configurations: ``` spark.connect.session.connectML.mlCache.memoryControl.enabled spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout spark.connect.session.connectML.mlCache.memoryControl.maxModelSize spark.connect.session.connectML.mlCache.memoryControl.maxSize ``` Only when spark.connect.session.connectML.mlCache.memoryControl.enabled is true do the rest of the configurations take effect. ### Why are the changes needed? Motivation: This is for ML cache management, to avoid a huge ML cache affecting Spark driver availability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#50751 from WeichenXu123/tree-early-stop. Authored-by: Weichen Xu <[email protected]> Signed-off-by: Weichen Xu <[email protected]>
1 parent 344e6c0 commit 43f78ae

File tree

28 files changed

+502
-166
lines changed

28 files changed

+502
-166
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,21 @@
811811
"Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
812812
]
813813
},
814+
"ML_CACHE_SIZE_OVERFLOW_EXCEPTION" : {
815+
"message" : [
816+
"The model cache size in current session is about to exceed",
817+
"<mlCacheMaxSize> bytes.",
818+
"Please delete existing cached model by executing 'del model' in python client before fitting new model or loading new model,",
819+
"or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxSize'."
820+
]
821+
},
822+
"MODEL_SIZE_OVERFLOW_EXCEPTION" : {
823+
"message" : [
824+
"The fitted or loaded model size is about <modelSize> bytes.",
825+
"Please fit or load a model smaller than <modelMaxSize> bytes,",
826+
"or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxModelSize'."
827+
]
828+
},
814829
"UNSUPPORTED_EXCEPTION" : {
815830
"message" : [
816831
"<message>"

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ class DecisionTreeClassificationModel private[ml] (
195195
// For ml connect only
196196
private[ml] def this() = this("", Node.dummyNode, -1, -1)
197197

198+
override def estimatedSize: Long = getEstimatedSize()
199+
198200
override def predict(features: Vector): Double = {
199201
rootNode.predictImpl(features).prediction
200202
}

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,15 +224,18 @@ class FMClassifier @Since("3.0.0") (
224224
val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
225225
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
226226

227-
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
228-
val summary = new FMClassificationTrainingSummaryImpl(
229-
summaryModel.transform(dataset),
230-
probabilityColName,
231-
predictionColName,
232-
$(labelCol),
233-
weightColName,
234-
objectiveHistory)
235-
model.setSummary(Some(summary))
227+
if (SummaryUtils.enableTrainingSummary) {
228+
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
229+
val summary = new FMClassificationTrainingSummaryImpl(
230+
summaryModel.transform(dataset),
231+
probabilityColName,
232+
predictionColName,
233+
$(labelCol),
234+
weightColName,
235+
objectiveHistory)
236+
model.setSummary(Some(summary))
237+
}
238+
model
236239
}
237240

238241
@Since("3.0.0")

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ class GBTClassificationModel private[ml](
276276
private[ml] def this() = this("",
277277
Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1, -1)
278278

279+
override def estimatedSize: Long = getEstimatedSize()
280+
279281
@Since("1.4.0")
280282
override def trees: Array[DecisionTreeRegressionModel] = _trees
281283

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,18 @@ class LinearSVC @Since("2.2.0") (
277277
val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
278278
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
279279

280-
val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
281-
val summary = new LinearSVCTrainingSummaryImpl(
282-
summaryModel.transform(dataset),
283-
rawPredictionColName,
284-
predictionColName,
285-
$(labelCol),
286-
weightColName,
287-
objectiveHistory)
288-
model.setSummary(Some(summary))
280+
if (SummaryUtils.enableTrainingSummary) {
281+
val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
282+
val summary = new LinearSVCTrainingSummaryImpl(
283+
summaryModel.transform(dataset),
284+
rawPredictionColName,
285+
predictionColName,
286+
$(labelCol),
287+
weightColName,
288+
objectiveHistory)
289+
model.setSummary(Some(summary))
290+
}
291+
model
289292
}
290293

291294
private def trainImpl(

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -711,27 +711,30 @@ class LogisticRegression @Since("1.2.0") (
711711
numClasses, checkMultinomial(numClasses)))
712712
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
713713

714-
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
715-
val logRegSummary = if (numClasses <= 2) {
716-
new BinaryLogisticRegressionTrainingSummaryImpl(
717-
summaryModel.transform(dataset),
718-
probabilityColName,
719-
predictionColName,
720-
$(labelCol),
721-
$(featuresCol),
722-
weightColName,
723-
objectiveHistory)
724-
} else {
725-
new LogisticRegressionTrainingSummaryImpl(
726-
summaryModel.transform(dataset),
727-
probabilityColName,
728-
predictionColName,
729-
$(labelCol),
730-
$(featuresCol),
731-
weightColName,
732-
objectiveHistory)
714+
if (SummaryUtils.enableTrainingSummary) {
715+
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
716+
val logRegSummary = if (numClasses <= 2) {
717+
new BinaryLogisticRegressionTrainingSummaryImpl(
718+
summaryModel.transform(dataset),
719+
probabilityColName,
720+
predictionColName,
721+
$(labelCol),
722+
$(featuresCol),
723+
weightColName,
724+
objectiveHistory)
725+
} else {
726+
new LogisticRegressionTrainingSummaryImpl(
727+
summaryModel.transform(dataset),
728+
probabilityColName,
729+
predictionColName,
730+
$(labelCol),
731+
$(featuresCol),
732+
weightColName,
733+
objectiveHistory)
734+
}
735+
model.setSummary(Some(logRegSummary))
733736
}
734-
model.setSummary(Some(logRegSummary))
737+
model
735738
}
736739

737740
private def createBounds(

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,17 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
249249
objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = {
250250
val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights))
251251

252-
val (summaryModel, _, predictionColName) = model.findSummaryModel()
253-
val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
254-
summaryModel.transform(dataset),
255-
predictionColName,
256-
$(labelCol),
257-
"",
258-
objectiveHistory)
259-
model.setSummary(Some(summary))
252+
if (SummaryUtils.enableTrainingSummary) {
253+
val (summaryModel, _, predictionColName) = model.findSummaryModel()
254+
val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
255+
summaryModel.transform(dataset),
256+
predictionColName,
257+
$(labelCol),
258+
"",
259+
objectiveHistory)
260+
model.setSummary(Some(summary))
261+
}
262+
model
260263
}
261264
}
262265

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,23 +185,26 @@ class RandomForestClassifier @Since("1.4.0") (
185185
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
186186

187187
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
188-
val rfSummary = if (numClasses <= 2) {
189-
new BinaryRandomForestClassificationTrainingSummaryImpl(
190-
summaryModel.transform(dataset),
191-
probabilityColName,
192-
predictionColName,
193-
$(labelCol),
194-
weightColName,
195-
Array(0.0))
196-
} else {
197-
new RandomForestClassificationTrainingSummaryImpl(
198-
summaryModel.transform(dataset),
199-
predictionColName,
200-
$(labelCol),
201-
weightColName,
202-
Array(0.0))
188+
if (SummaryUtils.enableTrainingSummary) {
189+
val rfSummary = if (numClasses <= 2) {
190+
new BinaryRandomForestClassificationTrainingSummaryImpl(
191+
summaryModel.transform(dataset),
192+
probabilityColName,
193+
predictionColName,
194+
$(labelCol),
195+
weightColName,
196+
Array(0.0))
197+
} else {
198+
new RandomForestClassificationTrainingSummaryImpl(
199+
summaryModel.transform(dataset),
200+
predictionColName,
201+
$(labelCol),
202+
weightColName,
203+
Array(0.0))
204+
}
205+
model.setSummary(Some(rfSummary))
203206
}
204-
model.setSummary(Some(rfSummary))
207+
model
205208
}
206209

207210
@Since("1.4.1")
@@ -258,6 +261,8 @@ class RandomForestClassificationModel private[ml] (
258261
// For ml connect only
259262
private[ml] def this() = this("", Array(new DecisionTreeClassificationModel), -1, -1)
260263

264+
override def estimatedSize: Long = getEstimatedSize()
265+
261266
@Since("1.4.0")
262267
override def trees: Array[DecisionTreeClassificationModel] = _trees
263268

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,16 +303,19 @@ class BisectingKMeans @Since("2.0.0") (
303303
val parentModel = bkm.runWithWeight(instances, handlePersistence, Some(instr))
304304
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
305305

306-
val summary = new BisectingKMeansSummary(
307-
model.transform(dataset),
308-
$(predictionCol),
309-
$(featuresCol),
310-
$(k),
311-
$(maxIter),
312-
parentModel.trainingCost)
313-
instr.logNamedValue("clusterSizes", summary.clusterSizes)
314-
instr.logNumFeatures(model.clusterCenters.head.size)
315-
model.setSummary(Some(summary))
306+
if (SummaryUtils.enableTrainingSummary) {
307+
val summary = new BisectingKMeansSummary(
308+
model.transform(dataset),
309+
$(predictionCol),
310+
$(featuresCol),
311+
$(k),
312+
$(maxIter),
313+
parentModel.trainingCost)
314+
instr.logNamedValue("clusterSizes", summary.clusterSizes)
315+
instr.logNumFeatures(model.clusterCenters.head.size)
316+
model.setSummary(Some(summary))
317+
}
318+
model
316319
}
317320

318321
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,11 +430,14 @@ class GaussianMixture @Since("2.0.0") (
430430

431431
val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists))
432432
.setParent(this)
433-
val summary = new GaussianMixtureSummary(model.transform(dataset),
434-
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
435-
instr.logNamedValue("logLikelihood", logLikelihood)
436-
instr.logNamedValue("clusterSizes", summary.clusterSizes)
437-
model.setSummary(Some(summary))
433+
if (SummaryUtils.enableTrainingSummary) {
434+
val summary = new GaussianMixtureSummary(model.transform(dataset),
435+
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
436+
instr.logNamedValue("logLikelihood", logLikelihood)
437+
instr.logNamedValue("clusterSizes", summary.clusterSizes)
438+
model.setSummary(Some(summary))
439+
}
440+
model
438441
}
439442

440443
private def trainImpl(

0 commit comments

Comments
 (0)