Skip to content

Commit 0826627 — "Fixes" (parent: 40b2f0a)

File tree

2 files changed: +19 additions, −12 deletions

python-package/xgboost/spark/core.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,13 @@ def _validate_gpu_params(
382382
self.getOrDefault(self.num_workers),
383383
)
384384
else:
385-
executor_gpus = conf.get("spark.executor.resource.gpu.amount")
385+
executor_gpus = conf.get("spark.executor.resource.gpu.amount", None)
386386
if executor_gpus is None:
387387
raise ValueError(
388388
"The `spark.executor.resource.gpu.amount` is required for training"
389389
" on GPU."
390390
)
391-
gpu_per_task = conf.get("spark.task.resource.gpu.amount")
391+
gpu_per_task = conf.get("spark.task.resource.gpu.amount", None)
392392
if gpu_per_task is not None and float(gpu_per_task) > 1.0:
393393
get_logger(self.__class__.__name__).warning(
394394
"The configuration assigns %s GPUs to each Spark task, but each "
@@ -546,7 +546,7 @@ def _validate_and_convert_feature_col_as_array_col(
546546
else:
547547
raise ValueError(
548548
"feature column must be array type or `pyspark.ml.linalg.Vector` type, "
549-
"if you want to use multiple numetric columns as features, please use "
549+
"if you want to use multiple numeric columns as features, please use "
550550
"`pyspark.ml.transform.VectorAssembler` to assemble them into a vector "
551551
"type column first."
552552
)
@@ -885,7 +885,9 @@ def _get_xgb_parameters(
885885
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
886886
train_params
887887
)
888-
cpu_per_task = int(_get_spark_session().conf.get("spark.task.cpus") or "1")
888+
cpu_per_task = int(
889+
_get_spark_session().conf.get("spark.task.cpus", None) or "1"
890+
)
889891

890892
dmatrix_kwargs = {
891893
"nthread": cpu_per_task,
@@ -932,8 +934,8 @@ def _skip_stage_level_scheduling(
932934
)
933935
return True
934936

935-
executor_cores = conf.get("spark.executor.cores")
936-
executor_gpus = conf.get("spark.executor.resource.gpu.amount")
937+
executor_cores = conf.get("spark.executor.cores", None)
938+
executor_gpus = conf.get("spark.executor.resource.gpu.amount", None)
937939
if executor_cores is None or executor_gpus is None:
938940
self.logger.info(
939941
"Stage-level scheduling in xgboost requires spark.executor.cores, "
@@ -958,7 +960,7 @@ def _skip_stage_level_scheduling(
958960
)
959961
return True
960962

961-
task_gpu_amount = conf.get("spark.task.resource.gpu.amount")
963+
task_gpu_amount = conf.get("spark.task.resource.gpu.amount", None)
962964

963965
if task_gpu_amount is None:
964966
# The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
@@ -986,7 +988,7 @@ def _get_resource_profile_for_stage_level_scheduling(
986988
return None
987989

988990
# executor_cores will not be None
989-
executor_cores = conf.get("spark.executor.cores")
991+
executor_cores = conf.get("spark.executor.cores", None)
990992
assert executor_cores is not None
991993

992994
# Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
@@ -1033,7 +1035,7 @@ def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]:
10331035

10341036
if conf.tracker_host_ip is None:
10351037
conf.tracker_host_ip = _get_spark_session().conf.get(
1036-
"spark.driver.host"
1038+
"spark.driver.host", None
10371039
)
10381040
num_workers = self.getOrDefault(self.num_workers)
10391041
rabit_args.update(_get_rabit_args(conf, num_workers))
@@ -1390,7 +1392,9 @@ def _run_on_gpu(self) -> bool:
13901392
# if it's local model, no need to check the spark configurations
13911393
return use_gpu_by_params
13921394

1393-
gpu_per_task = _get_spark_session().conf.get("spark.task.resource.gpu.amount")
1395+
gpu_per_task = _get_spark_session().conf.get(
1396+
"spark.task.resource.gpu.amount", None
1397+
)
13941398

13951399
# User don't set gpu configurations, just use cpu
13961400
if gpu_per_task is None:

tests/test_distributed/test_with_spark/test_spark.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,16 +1236,19 @@ def test_collective_conf(self, spark: SparkSession, tmp_path: Path) -> None:
12361236
):
12371237
classifier._get_tracker_args()
12381238

1239+
avail_tracker_port = get_avail_port()
12391240
classifier = SparkXGBClassifier(
12401241
launch_tracker_on_driver=True,
1241-
coll_cfg=Config(tracker_host_ip="127.0.0.1", tracker_port=58893),
1242+
coll_cfg=Config(
1243+
tracker_host_ip="127.0.0.1", tracker_port=avail_tracker_port
1244+
),
12421245
num_workers=2,
12431246
)
12441247
launch_tracker_on_driver, rabit_envs = classifier._get_tracker_args()
12451248
assert launch_tracker_on_driver is True
12461249
assert rabit_envs["n_workers"] == 2
12471250
assert rabit_envs["dmlc_tracker_uri"] == "127.0.0.1"
1248-
assert rabit_envs["dmlc_tracker_port"] == 58893
1251+
assert rabit_envs["dmlc_tracker_port"] == avail_tracker_port
12491252

12501253
path = "file:" + str(tmp_path)
12511254
port = get_avail_port()

Comments (0)