Skip to content

Commit 9c0efce

Browse files
imeyer2trivialfis
andauthored
Add enable_categorical to the apply method (#11550)
--------- Co-authored-by: Jiaming Yuan <[email protected]>
1 parent 5ff47c1 commit 9c0efce

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

python-package/xgboost/sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,6 +1416,7 @@ def apply(
14161416
missing=self.missing,
14171417
feature_types=self.feature_types,
14181418
nthread=self.n_jobs,
1419+
enable_categorical=self.enable_categorical,
14191420
)
14201421
return self.get_booster().predict(
14211422
test_dmatrix, pred_leaf=True, iteration_range=iteration_range

tests/python/test_with_sklearn.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,3 +1567,23 @@ def test_doc_link() -> None:
15671567
name = est.__class__.__name__
15681568
link = est._get_doc_link()
15691569
assert f"xgboost.{name}" in link
1570+
1571+
1572+
def test_apply_method():
1573+
import pandas as pd
1574+
1575+
X_num = np.random.rand(5, 5)
1576+
df = pd.DataFrame(X_num, columns=[f"f{i}" for i in range(X_num.shape[1])])
1577+
df["test"] = pd.Series(
1578+
["one", "two", "three", "four", "five"], dtype="category"
1579+
) # <- categorical column
1580+
y = np.arange(len(df))
1581+
1582+
model = xgb.XGBClassifier(enable_categorical=True)
1583+
model.fit(df, y)
1584+
1585+
model.apply(df) # this must not raise
1586+
1587+
model.set_params(enable_categorical=False)
1588+
with pytest.raises(ValueError, match="`enable_categorical`"):
1589+
model.apply(df)

0 commit comments

Comments
 (0)