Skip to content

Commit 804c5e8

Browse files
authored
[jvm-packages] Add setNumEarlyStoppingRounds (#11571)
1 parent 92e2c97 commit 804c5e8

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
122122
final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to train xgboost",
123123
ParamValidators.gtEq(1))
124124

125-
final def getNumRound: Int = $(numRound)
125+
final def getNumWorkers: Int = $(numWorkers)
126126

127127
final val forceRepartition = new BooleanParam(this, "forceRepartition", "If the partition " +
128128
"is equal to numWorkers, xgboost won't repartition the dataset. Set forceRepartition to " +
@@ -133,6 +133,8 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
133133
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
134134
ParamValidators.gtEq(1))
135135

136+
final def getNumRound: Int = $(numRound)
137+
136138
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", "Stop training " +
137139
"Number of rounds of decreasing eval metric to tolerate before stopping training",
138140
ParamValidators.gtEq(0))
@@ -216,14 +218,14 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
216218
labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
217219
forceRepartition, featuresCols, customEval, customObj, featureTypes, featureNames)
218220

219-
final def getNumWorkers: Int = $(numWorkers)
220-
221221
def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]
222222

223223
def setForceRepartition(value: Boolean): T = set(forceRepartition, value).asInstanceOf[T]
224224

225225
def setNumRound(value: Int): T = set(numRound, value).asInstanceOf[T]
226226

227+
def setNumEarlyStoppingRounds(value: Int): T = set(numEarlyStoppingRounds, value).asInstanceOf[T]
228+
227229
def setFeaturesCol(value: Array[String]): T = set(featuresCols, value).asInstanceOf[T]
228230

229231
def setBaseMarginCol(value: String): T = set(baseMarginCol, value).asInstanceOf[T]

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostParamsSuite.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import org.scalatest.funsuite.AnyFunSuite
2424
class XGBoostParamsSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
2525

2626
test("invalid parameters") {
27-
val df = smallBinaryClassificationVector
2827
val estimator = new XGBoostClassifier()
2928

3029
// We didn't set it by default
@@ -52,4 +51,11 @@ class XGBoostParamsSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite
5251
assert(v1 == 0.66f)
5352
}
5453

54+
test("setNumEarlyStoppingRounds") {
55+
val estimator = new XGBoostClassifier()
56+
assert(estimator.getNumEarlyStoppingRounds == 0)
57+
estimator.setNumEarlyStoppingRounds(10)
58+
assert(estimator.getNumEarlyStoppingRounds == 10)
59+
}
60+
5561
}

0 commit comments

Comments
 (0)