
Commit 45def5c

Support mixed metrics for the skl interface. (#11536)
1 parent eabb5ed commit 45def5c

File tree: 3 files changed, +83 -7 lines changed

python-package/xgboost/sklearn.py

Lines changed: 36 additions & 7 deletions
@@ -1,5 +1,6 @@
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines
 """Scikit-Learn Wrapper interface for XGBoost."""
+import collections
 import copy
 import json
 import os
@@ -432,7 +433,7 @@ def task(i: int) -> float:
         - ``one_output_per_tree``: One model for each target.
         - ``multi_output_tree``: Use multi-target trees.

-    eval_metric : {Optional[Union[str, List[str], Callable]]}
+    eval_metric : {Optional[Union[str, List[Union[str, Callable]], Callable]]}

         .. versionadded:: 1.6.0

@@ -763,7 +764,7 @@ def __init__(
         max_cat_to_onehot: Optional[int] = None,
         max_cat_threshold: Optional[int] = None,
         multi_strategy: Optional[str] = None,
-        eval_metric: Optional[Union[str, List[str], Callable]] = None,
+        eval_metric: Optional[Union[str, List[Union[str, Callable]], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
         callbacks: Optional[List[TrainingCallback]] = None,
         **kwargs: Any,
@@ -1103,14 +1104,42 @@ def _duplicated(parameter: str) -> None:

         # - configure callable evaluation metric
         metric: Optional[Metric] = None
+
+        def custom_metric(m: Callable) -> Metric:
+            if self._get_type() == "ranker":
+                wrapped = ltr_metric_decorator(m, self.n_jobs)
+            else:
+                wrapped = _metric_decorator(m)
+            return wrapped
+
+        def invalid_type(m: Any) -> None:
+            msg = f"Invalid type for the `eval_metric`: {type(m)}"
+            raise TypeError(msg)
+
         if self.eval_metric is not None:
             if callable(self.eval_metric):
-                if self._get_type() == "ranker":
-                    metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
-                else:
-                    metric = _metric_decorator(self.eval_metric)
-            else:
+                metric = custom_metric(self.eval_metric)
+            elif isinstance(self.eval_metric, str):
                 params.update({"eval_metric": self.eval_metric})
+            else:
+                # A sequence of metrics
+                if not isinstance(self.eval_metric, collections.abc.Sequence):
+                    invalid_type(self.eval_metric)
+                # Could be a list of strings or callables
+                builtin_metrics: List[str] = []
+                for m in self.eval_metric:
+                    if callable(m):
+                        if metric is not None:
+                            raise NotImplementedError(
+                                "Using multiple custom metrics is not yet supported."
+                            )
+                        metric = custom_metric(m)
+                    elif isinstance(m, str):
+                        builtin_metrics.append(m)
+                    else:
+                        invalid_type(m)
+                if builtin_metrics:
+                    params.update({"eval_metric": builtin_metrics})

         if feature_weights is not None:
             _deprecated("feature_weights")
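For context, a minimal usage sketch of what the change above enables on the scikit-learn estimator, mirroring the new test added in this commit; the dataset and the choice of hinge_loss are illustrative only:

import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.metrics import hinge_loss

X, y = make_classification(random_state=2025)

# A builtin metric name and an sklearn-style callable can now be mixed in one
# list: the string is forwarded to the Booster's eval_metric parameter, while
# the callable is wrapped and handed over as a custom metric.
clf = xgb.XGBClassifier(eval_metric=["logloss", hinge_loss], n_estimators=2)
clf.fit(X, y, eval_set=[(X, y)])

# Both metrics are recorded in the evaluation history for the first eval set,
# so the keys should include 'logloss' and 'hinge_loss'.
print(clf.evals_result()["validation_0"].keys())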

tests/python/test_with_sklearn.py

Lines changed: 23 additions & 0 deletions
@@ -1397,6 +1397,29 @@ def merror(y_true: np.ndarray, predt: np.ndarray):
     clf.fit(X, y, eval_set=[(X, y)])


+def test_mixed_metrics() -> None:
+    from sklearn.datasets import make_classification
+    from sklearn.metrics import hamming_loss, hinge_loss, log_loss
+
+    X, y = make_classification(random_state=2025)
+
+    clf = xgb.XGBClassifier(eval_metric=["logloss", hinge_loss], n_estimators=2)
+    clf.fit(X, y, eval_set=[(X, y)])
+    results = clf.evals_result()["validation_0"]
+    assert "logloss" in results
+    assert "hinge_loss" in results
+
+    clf = xgb.XGBClassifier(eval_metric=[hamming_loss, log_loss], n_estimators=2)
+    with pytest.raises(
+        NotImplementedError, match="multiple custom metrics is not yet supported."
+    ):
+        clf.fit(X, y, eval_set=[(X, y)])
+
+    clf = xgb.XGBClassifier(eval_metric=[123, log_loss], n_estimators=2)
+    with pytest.raises(TypeError, match="Invalid type for the `eval_metric`"):
+        clf.fit(X, y, eval_set=[(X, y)])
+
+
 def test_weighted_evaluation_metric():
     from sklearn.datasets import make_hastie_10_2
     from sklearn.metrics import log_loss

tests/test_distributed/test_with_dask/test_with_dask.py

Lines changed: 24 additions & 0 deletions
@@ -1677,6 +1677,30 @@ def sqr(
         results_custom = reg.evals_result()
         tm.non_increasing(results_custom["validation_0"]["rmse"])

+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_custom_metrics(self, client: "Client") -> None:
+        from sklearn.datasets import make_classification
+        from sklearn.metrics import hamming_loss, hinge_loss, log_loss
+
+        Xn, yn = make_classification(random_state=2025)
+        X, y = da.array(Xn), da.array(yn)
+
+        clf = dxgb.DaskXGBClassifier(
+            eval_metric=["logloss", hinge_loss], n_estimators=2
+        )
+        clf.fit(X, y, eval_set=[(X, y)])
+        results = clf.evals_result()["validation_0"]
+        assert "logloss" in results
+        assert "hinge_loss" in results
+
+        clf = dxgb.DaskXGBClassifier(
+            eval_metric=[hamming_loss, log_loss], n_estimators=2
+        )
+        with pytest.raises(
+            NotImplementedError, match="multiple custom metrics is not yet supported."
+        ):
+            clf.fit(X, y, eval_set=[(X, y)])
+
     def test_no_duplicated_partition(self) -> None:
         """Assert each worker has the correct amount of data, and DMatrix initialization
         doesn't generate unnecessary copies of data.
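For completeness, a minimal sketch of exercising the same Dask code path outside the test suite. The LocalCluster setup, worker count, chunk sizes, and explicit client attachment are assumptions for illustration and are not part of this commit:

import dask.array as da
import xgboost.dask as dxgb
from dask.distributed import Client, LocalCluster
from sklearn.datasets import make_classification
from sklearn.metrics import hinge_loss

if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            Xn, yn = make_classification(random_state=2025)
            X = da.from_array(Xn, chunks=(50, Xn.shape[1]))
            y = da.from_array(yn, chunks=50)

            # The Dask estimator accepts the same mixed list of builtin metric
            # names and sklearn-style callables.
            clf = dxgb.DaskXGBClassifier(
                eval_metric=["logloss", hinge_loss], n_estimators=2
            )
            clf.client = client  # attach the client explicitly
            clf.fit(X, y, eval_set=[(X, y)])
            print(clf.evals_result()["validation_0"].keys())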
