Skip to content

Commit 1dc3bea

Browse files
committed
Fixes.
1 parent 1a0ad29 commit 1dc3bea

File tree

3 files changed

+11
-21
lines changed

3 files changed

+11
-21
lines changed

python-package/xgboost/spark/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,8 @@ def _validate_params(self) -> None:
503503
def _run_on_gpu(self) -> bool:
504504
"""If train or transform on the gpu according to the parameters"""
505505

506-
return (
507-
use_cuda(self.getOrDefault(self.device))
508-
or self.getOrDefault(self.use_gpu)
506+
return use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
507+
self.use_gpu
509508
)
510509

511510
def _col_is_defined_not_empty(self, param: "Param[str]") -> bool:

tests/cpp/gbm/test_gbtree.cc

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -218,26 +218,17 @@ TEST(GBTree, ChooseTreeMethod) {
218218
return updater;
219219
};
220220

221-
// | | hist | gpu_hist | exact | NA |
222-
// |--------+---------+----------+-------+-----|
223-
// | CUDA:0 | GPU | GPU (w) | Err | GPU |
224-
// | CPU | CPU | GPU (w) | CPU | CPU |
225-
// |--------+---------+----------+-------+-----|
226-
// | -1 | CPU | GPU (w) | CPU | CPU |
227-
// | 0 | GPU | GPU (w) | Err | GPU |
228-
// |--------+---------+----------+-------+-----|
229-
// | NA | CPU | GPU (w) | CPU | CPU |
221+
// | | hist | approx | exact | NA |
222+
// |--------+---------+--------+-------+-----|
223+
// | CUDA:0 | GPU | GPU | Err | GPU |
224+
// | CPU | CPU | GPU | CPU | CPU |
225+
// |--------+---------+--------+-------+-----|
226+
// | NA | CPU | CPU | CPU | CPU |
230227
//
231-
// - (w): warning
232228
// - CPU: Run on CPU.
233229
// - GPU: Run on CUDA.
234230
// - Err: Not feasible.
235231
// - NA: Parameter is not specified.
236-
237-
// When GPU hist is specified with a CPU context, we should emit an error. However, it's
238-
// quite difficult to detect whether the CPU context is being used because it's the
239-
// default or because it's specified by the user.
240-
241232
std::map<std::pair<std::optional<std::string>, std::optional<std::string>>, std::string>
242233
expectation{
243234
// hist
@@ -246,10 +237,10 @@ TEST(GBTree, ChooseTreeMethod) {
246237
{{"hist", "cuda:0"}, "grow_gpu_hist"},
247238
{{"hist", std::nullopt}, "grow_quantile_histmaker"},
248239
// approx
249-
{{"approx", "cpu"}, "grow_gpu_approx"},
240+
{{"approx", "cpu"}, "grow_histmaker"},
250241
{{"approx", "cuda"}, "grow_gpu_approx"},
251242
{{"approx", "cuda:0"}, "grow_gpu_approx"},
252-
{{"approx", std::nullopt}, "grow_gpu_approx"},
243+
{{"approx", std::nullopt}, "grow_histmaker"},
253244
// exact
254245
{{"exact", "cpu"}, "grow_colmaker,prune"},
255246
{{"exact", "cuda"}, "err"},

tests/test_distributed/test_with_spark/test_spark_local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ def test_gpu_params(self) -> None:
942942
assert clf._run_on_gpu()
943943

944944
clf = SparkXGBClassifier(tree_method="hist")
945-
assert clf._run_on_gpu()
945+
assert not clf._run_on_gpu()
946946

947947
clf = SparkXGBClassifier(use_gpu=True)
948948
assert clf._run_on_gpu()

0 commit comments

Comments
 (0)