16 changes: 16 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -811,6 +811,22 @@
"Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
]
},
"MODEL_SIZE_OVERFLOW_EXCEPTION" : {
"message" : [
"The fitted or loaded model size exceeds <modelMaxSize> bytes. ",
"Please fit or load a model smaller than <modelMaxSize> bytes, ",
"or increase the Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxModelSize'."
]
},
"ML_CACHE_SIZE_OVERFLOW_EXCEPTION" : {
"message" : [
"The model cache size in the current session is about to exceed ",
"<mlCacheMaxSize> bytes. ",
"Please delete existing cached models by executing 'del model' in the Python client ",
"before fitting or loading a new model, or increase the ",
"Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxSize'."
]
},
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
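The two new error conditions sit alongside the existing ML cache entries shown above. As a rough sketch of how a server-side check might raise the model-size error (the parent error class `CONNECT_ML`, the helper name, and the call site are assumptions for illustration, not code from this PR):

```scala
import org.apache.spark.SparkException

// Hypothetical guard (not from this PR): reject a model whose estimated
// size exceeds the configured maximum, using the new error condition.
def checkModelSize(estimatedSizeInBytes: Long, maxModelSizeInBytes: Long): Unit = {
  if (maxModelSizeInBytes > 0 && estimatedSizeInBytes > maxModelSizeInBytes) {
    throw new SparkException(
      errorClass = "CONNECT_ML.MODEL_SIZE_OVERFLOW_EXCEPTION", // parent class assumed
      messageParameters = Map("modelMaxSize" -> maxModelSizeInBytes.toString),
      cause = null)
  }
}
```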
mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.SizeEstimator


private[spark] object GradientBoostedTrees extends Logging {
@@ -292,6 +293,7 @@ private[spark] object GradientBoostedTrees extends Logging {
featureSubsetStrategy: String,
instr: Option[Instrumentation] = None):
(Array[DecisionTreeRegressionModel], Array[Double]) = {
val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -365,7 +367,8 @@
val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
- parentUID = None)
+ parentUID = None,
+ earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)
.head.asInstanceOf[DecisionTreeRegressionModel]

firstCounts.unpersist()
@@ -400,11 +403,18 @@
timer.stop("init validation")
}

- var bestM = 1
+ var accTreeSize = SizeEstimator.estimate(firstTreeModel)
+ var validM = 1

var m = 1
- var doneLearning = false
- while (m < numIterations && !doneLearning) {
+ var earlyStop = false
+ if (
+ earlyStopModelSizeThresholdInBytes > 0
+ && accTreeSize > earlyStopModelSizeThresholdInBytes
+ ) {
+ earlyStop = true
+ }
+ while (m < numIterations && !earlyStop) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
@@ -434,14 +444,16 @@
val model = RandomForest.runBagged(baggedInput = bagged,
metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
- seed = seed + m, instr = None, parentUID = None)
+ seed = seed + m, instr = None, parentUID = None,
+ earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes - accTreeSize)
.head.asInstanceOf[DecisionTreeRegressionModel]

labelWithCounts.unpersist()

timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
accTreeSize += SizeEstimator.estimate(model)
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
@@ -466,10 +478,19 @@
val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
if (bestValidateError - currentValidateError < validationTol * Math.max(
currentValidateError, 0.01)) {
- doneLearning = true
+ earlyStop = true
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
- bestM = m + 1
+ validM = m + 1
}
}
if (!earlyStop) {
if (
earlyStopModelSizeThresholdInBytes > 0
&& accTreeSize > earlyStopModelSizeThresholdInBytes
) {
earlyStop = true
validM = m + 1
}
}
m += 1
@@ -490,8 +511,16 @@
validatePredErrorCheckpointer.deleteAllCheckpoints()
}

- if (validate) {
- (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
+ if (earlyStop) {
+ // Early stopping occurs in one of two cases:
+ // - the validation error increases
+ // - the accumulated size of the trees exceeds `earlyStopModelSizeThresholdInBytes`
+ if (accTreeSize > earlyStopModelSizeThresholdInBytes) {
+ logWarning(
+ "The boosted tree training stopped early because the model size exceeds the threshold."
+ )
+ }
+ (baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM))
} else {
(baseLearners, baseLearnerWeights)
}
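Stripped of the bagging, checkpointing, and validation machinery, the size-based early stop added above reduces to the following pattern. This is a condensed sketch, not the PR code; `fitTree` is a hypothetical stand-in for the `RandomForest.runBagged` call that fits one boosting stage.

```scala
import org.apache.spark.util.SizeEstimator

// Accumulate the estimated size of each fitted tree and stop boosting
// once the configured threshold is crossed.
def boostWithSizeLimit[M <: AnyRef](
    numIterations: Int,
    thresholdInBytes: Long)(fitTree: Int => M): Seq[M] = {
  val trees = Seq.newBuilder[M]
  var accTreeSize = 0L
  var m = 0
  var earlyStop = false
  while (m < numIterations && !earlyStop) {
    val tree = fitTree(m)
    trees += tree
    accTreeSize += SizeEstimator.estimate(tree)
    // A threshold of 0 (the default) disables the size check entirely.
    earlyStop = thresholdInBytes > 0 && accTreeSize > thresholdInBytes
    m += 1
  }
  trees.result()
}
```

Passing `earlyStopModelSizeThresholdInBytes - accTreeSize` down to each stage, as the diff does, additionally lets a single oversized stage cut its own node-splitting loop short.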
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.SizeEstimator
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}

@@ -109,6 +110,7 @@ private[spark] object RandomForest extends Logging with Serializable {
* @param metadata Learning and dataset metadata for DecisionTree.
* @return an unweighted set of trees
*/
// scalastyle:off
def runBagged(
baggedInput: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
@@ -119,7 +121,8 @@ private[spark] object RandomForest extends Logging with Serializable {
seed: Long,
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
- parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+ parentUID: Option[String] = None,
+ earlyStopModelSizeThresholdInBytes: Long = 0): Array[DecisionTreeModel] = {
val timer = new TimeTracker()
timer.start("total")

@@ -190,7 +193,8 @@

timer.stop("init")

- while (nodeStack.nonEmpty) {
+ var earlyStop = false
+ while (nodeStack.nonEmpty && !earlyStop) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
@@ -214,6 +218,16 @@
}

timer.stop("findBestSplits")

if (
earlyStopModelSizeThresholdInBytes > 0
&& SizeEstimator.estimate(topNodes) > earlyStopModelSizeThresholdInBytes
) {
earlyStop = true
logWarning(
"The random forest training stops early because the model size exceeds the threshold."
)
}
}

timer.stop("total")
@@ -253,6 +267,7 @@
}
}
}
// scalastyle:on

/**
* Train a random forest.
@@ -269,6 +284,7 @@
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
val timer = new TimeTracker()

timer.start("build metadata")
@@ -301,7 +317,8 @@

val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits,
strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy,
- seed = seed, instr = instr, prune = prune, parentUID = parentUID)
+ seed = seed, instr = instr, prune = prune, parentUID = parentUID,
+ earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)

baggedInput.unpersist()
bcSplits.destroy()
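The guard added to `runBagged` re-estimates the in-progress trees after each group of node splits rather than after each node. A minimal sketch of that check, with `topNodes` standing in for the array of partially built tree roots:

```scala
import org.apache.spark.util.SizeEstimator

// True when the estimated size of the partially built trees exceeds the
// threshold; a threshold of 0 (the default) disables the check.
def shouldStopEarly(topNodes: Array[_ <: AnyRef], thresholdInBytes: Long): Boolean =
  thresholdInBytes > 0 && SizeEstimator.estimate(topNodes) > thresholdInBytes
```

Since `SizeEstimator.estimate` walks the whole object graph, running it once per node group keeps the overhead modest.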
mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -583,3 +583,7 @@ private[ml] object EnsembleModelReadWrite {
}
}
}

private[spark] object TreeConfig {
private[spark] var trainingEarlyStopModelSizeThresholdInBytes: Long = 0L
}
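`TreeConfig.trainingEarlyStopModelSizeThresholdInBytes` is a mutable, process-global knob, so code outside this file has to populate it. A hypothetical sketch of that wiring, assuming the Connect ML handler copies the session's setting over before dispatching a tree-based fit (`syncTreeConfig` is not from this PR):

```scala
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.SparkSession

// Copy the session-level max-model-size setting into the global knob
// that the tree trainers read. "1g" mirrors the config's default.
def syncTreeConfig(spark: SparkSession): Unit = {
  val maxModelSize = spark.conf.get(
    "spark.connect.session.connectML.mlCache.memoryControl.maxModelSize", "1g")
  TreeConfig.trainingEarlyStopModelSizeThresholdInBytes =
    JavaUtils.byteStringAsBytes(maxModelSize)
}
```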
6 changes: 3 additions & 3 deletions python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -29,9 +29,9 @@ class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConn
@classmethod
def conf(cls):
conf = super().conf()
- conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "true")
- conf.set("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize", "1024")
- conf.set("spark.connect.session.connectML.mlCache.offloading.timeout", "1")
+ conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "true")
+ conf.set("spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize", "1024")
+ conf.set("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout", "1")
return conf


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -334,36 +334,51 @@ object Connect {
}
}

- val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE =
- buildConf("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize")
+ val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED =
+ buildConf("spark.connect.session.connectML.mlCache.memoryControl.enabled")
+ .doc("Enables ML cache memory control, which includes offloading models to disk, " +
+ "limiting model size, and limiting per-session ML cache size.")
+ .version("4.1.0")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
+
+ val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_IN_MEMORY_SIZE =
+ buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize")
.doc(
- "In-memory maximum size of the MLCache per session. The cache will offload the least " +
+ "Maximum in-memory size of the MLCache per session. The cache will offload the least " +
"recently used models to Spark driver local disk if the size exceeds this limit. " +
- "The size is in bytes. This configuration only works when " +
- "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
+ "The size is in bytes.")
.version("4.1.0")
.internal()
.bytesConf(ByteUnit.BYTE)
// By default, 1/3 of total designated memory (the configured -Xmx).
.createWithDefault(Runtime.getRuntime.maxMemory() / 3)

- val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT =
- buildConf("spark.connect.session.connectML.mlCache.offloading.timeout")
+ val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_OFFLOADING_TIMEOUT =
+ buildConf("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout")
.doc(
"Timeout of model offloading in MLCache. Models will be offloaded to Spark driver local " +
- "disk if they are not used for this amount of time. The timeout is in minutes. " +
- "This configuration only works when " +
- "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
+ "disk if they are not used for this amount of time. The timeout is in minutes.")
.version("4.1.0")
.internal()
.timeConf(TimeUnit.MINUTES)
.createWithDefault(5)

- val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED =
- buildConf("spark.connect.session.connectML.mlCache.offloading.enabled")
- .doc("Enables ML cache offloading.")
+ val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE =
+ buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxModelSize")
+ .doc("Maximum size of a single SparkML model. The size is in bytes.")
.version("4.1.0")
.internal()
- .booleanConf
- .createWithDefault(true)
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("1g")

val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE =
buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxSize")
.doc("Maximum total size (including in-memory and offloaded data) of the ML cache. " +
"The size is in bytes.")
.version("4.1.0")
.internal()
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("10g")
}
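For reference, a sketch of overriding the new memory-control settings when building a server-side `SparkConf`; the values here are illustrative, not recommendations:

```scala
import org.apache.spark.SparkConf

// Example overrides for the new memory-control settings.
val conf = new SparkConf()
  .set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "true")
  .set("spark.connect.session.connectML.mlCache.memoryControl.maxModelSize", "2g")
  .set("spark.connect.session.connectML.mlCache.memoryControl.maxSize", "20g")
```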