
Commit 92e2c97

[enc] Add a cat accessor to the booster. (#11568)
1 parent 34ed70e commit 92e2c97

File tree: 8 files changed, +231 -117 lines
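
For orientation, a minimal usage sketch of the accessor this commit exposes: the booster reports the same per-feature category arrays as the DMatrix it was trained on. This assumes pandas and pyarrow are installed; the column name "c", the labels, and the training parameters are illustrative only.

import pandas as pd
import xgboost as xgb

# One categorical feature; enable_categorical keeps the category encoding in the DMatrix.
df = pd.DataFrame({"c": ["cdef", "abc", "abc"]}, dtype="category")
Xy = xgb.DMatrix(df, label=[0.0, 1.0, 1.0], enable_categorical=True)
bst = xgb.train({"booster": "gbtree"}, Xy, num_boost_round=1)

# Both containers expose a dict of feature name -> pyarrow array (None for numeric features).
cats_dm = Xy.get_categories()
cats_bst = bst.get_categories()
assert cats_bst is not None
for name, arr in cats_bst.items():
    assert arr == cats_dm[name]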

include/xgboost/gbm.h

Lines changed: 20 additions & 12 deletions

@@ -24,6 +24,7 @@ namespace xgboost {
 class Json;
 class FeatureMap;
 class ObjFunction;
+class CatContainer;
 
 struct Context;
 struct LearnerModelParam;
@@ -135,12 +136,12 @@ class GradientBooster : public Model, public Configurable {
                        bst_layer_t layer_begin, bst_layer_t layer_end,
                        bool approximate) = 0;
 
-  /*!
-   * \brief dump the model in the requested format
-   * \param fmap feature map that may help give interpretations of feature
-   * \param with_stats extra statistics while dumping model
-   * \param format the format to dump the model in
-   * \return a vector of dump for boosters.
+  /**
+   * @brief dump the model in the requested format
+   * @param fmap feature map that may help give interpretations of feature
+   * @param with_stats extra statistics while dumping model
+   * @param format the format to dump the model in
+   * @return a vector of dump for boosters.
    */
   [[nodiscard]] virtual std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
                                                            std::string format) const = 0;
@@ -149,12 +150,19 @@ class GradientBooster : public Model, public Configurable {
                                      common::Span<int32_t const> trees,
                                      std::vector<bst_feature_t>* features,
                                      std::vector<float>* scores) const = 0;
-  /*!
-   * \brief create a gradient booster from given name
-   * \param name name of gradient booster
-   * \param generic_param Pointer to runtime parameters
-   * \param learner_model_param pointer to global model parameters
-   * \return The created booster.
+  /**
+   * @brief Getter for categories.
+   */
+  [[nodiscard]] virtual CatContainer const* Cats() const {
+    LOG(FATAL) << "Retrieving categories is not supported by the current booster.";
+    return nullptr;
+  }
+  /**
+   * @brief create a gradient booster from given name
+   * @param name name of gradient booster
+   * @param generic_param Pointer to runtime parameters
+   * @param learner_model_param pointer to global model parameters
+   * @return The created booster.
    */
   static GradientBooster* Create(const std::string& name, Context const* ctx,
                                  LearnerModelParam const* learner_model_param);

include/xgboost/learner.h

Lines changed: 14 additions & 10 deletions

@@ -1,6 +1,6 @@
 /**
  * Copyright 2015-2025, XGBoost Contributors
- * \file learner.h
+ *
  * \brief Learner interface that integrates objective, gbm and evaluation together.
  *  This is the user facing XGBoost training module.
  * \author Tianqi Chen
@@ -35,6 +35,7 @@ class Json;
 struct XGBAPIThreadLocalEntry;
 template <typename T>
 class HostDeviceVector;
+class CatContainer;
 
 enum class PredictionType : std::uint8_t {  // NOLINT
   kValue = 0,
@@ -167,11 +168,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    */
   virtual void SetParam(const std::string& key, const std::string& value) = 0;
 
-  /*!
-   * \brief Get the number of features of the booster.
-   * \return number of features
+  /**
+   * @brief Get the number of features of the booster.
+   * @return The number of features
    */
-  virtual uint32_t GetNumFeature() const = 0;
+  virtual bst_feature_t GetNumFeature() const = 0;
 
   /*!
   * \brief Set additional attribute to the Booster.
@@ -221,16 +222,19 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \param fn Output feature types
   */
   virtual void GetFeatureTypes(std::vector<std::string>* ft) const = 0;
-
   /**
-   * \brief Slice the model.
+   * @brief Getter for categories.
+   */
+  [[nodiscard]] virtual CatContainer const* Cats() const = 0;
+  /**
+   * @brief Slice the model.
    *
    * See InplacePredict for layer parameters.
    *
-   * \param step step size between slice.
-   * \param out_of_bound Return true if end layer is out of bound.
+   * @param step step size between slice.
+   * @param out_of_bound Return true if end layer is out of bound.
    *
-   * \return a sliced model.
+   * @return a sliced model.
    */
   virtual Learner* Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step,
                          bool* out_of_bound) = 0;

python-package/xgboost/core.py

Lines changed: 80 additions & 52 deletions

@@ -779,6 +779,64 @@ def inner_f(*args: Any, **kwargs: Any) -> _T:
 _deprecate_positional_args = require_keyword_args(False)
 
 
+def _get_categories(
+    cfn: Callable[[ctypes.c_char_p], int],
+    feature_names: Optional[FeatureNames],
+    n_features: int,
+) -> Optional[Dict[str, "pa.DictionaryArray"]]:
+    if not is_pyarrow_available():
+        raise ImportError("`pyarrow` is required for exporting categories.")
+
+    if TYPE_CHECKING:
+        import pyarrow as pa
+    else:
+        pa = import_pyarrow()
+
+    fnames = feature_names
+    if fnames is None:
+        fnames = [str(i) for i in range(n_features)]
+
+    results: Dict[str, "pa.DictionaryArray"] = {}
+
+    ret = ctypes.c_char_p()
+    _check_call(cfn(ret))
+    if ret.value is None:
+        return None
+
+    retstr = ret.value.decode()  # pylint: disable=no-member
+    jcats = json.loads(retstr)
+    assert isinstance(jcats, list) and len(jcats) == n_features
+
+    for fidx in range(n_features):
+        f_jcats = jcats[fidx]
+        if f_jcats is None:
+            # Numeric data
+            results[fnames[fidx]] = None
+            continue
+
+        if "offsets" not in f_jcats:
+            values = from_array_interface(f_jcats)
+            pa_values = pa.Array.from_pandas(values)
+            results[fnames[fidx]] = pa_values
+            continue
+
+        joffsets = f_jcats["offsets"]
+        jvalues = f_jcats["values"]
+        offsets = from_array_interface(joffsets, True)
+        values = from_array_interface(jvalues, True)
+        pa_offsets = pa.array(offsets).buffers()
+        pa_values = pa.array(values).buffers()
+        assert (
+            pa_offsets[0] is None and pa_values[0] is None
+        ), "Should not have null mask."
+        pa_dict = pa.StringArray.from_buffers(
+            len(offsets) - 1, pa_offsets[1], pa_values[1]
+        )
+        results[fnames[fidx]] = pa_dict
+
+    return results
+
+
 @unique
 class DataSplitMode(IntEnum):
     """Supported data split mode for DMatrix."""
@@ -1299,58 +1357,11 @@ def get_categories(self) -> Optional[Dict[str, "pa.DictionaryArray"]]:
         .. versionadded:: 3.1.0
 
         """
-        if not is_pyarrow_available():
-            raise ImportError("`pyarrow` is required for exporting categories.")
-
-        if TYPE_CHECKING:
-            import pyarrow as pa
-        else:
-            pa = import_pyarrow()
-
-        n_features = self.num_col()
-        fnames = self.feature_names
-        if fnames is None:
-            fnames = [str(i) for i in range(n_features)]
-
-        results: Dict[str, "pa.DictionaryArray"] = {}
-
-        ret = ctypes.c_char_p()
-        _check_call(_LIB.XGBDMatrixGetCategories(self.handle, ctypes.byref(ret)))
-        if ret.value is None:
-            return None
-
-        retstr = ret.value.decode()  # pylint: disable=no-member
-        jcats = json.loads(retstr)
-        assert isinstance(jcats, list) and len(jcats) == n_features
-
-        for fidx in range(n_features):
-            f_jcats = jcats[fidx]
-            if f_jcats is None:
-                # Numeric data
-                results[fnames[fidx]] = None
-                continue
-
-            if "offsets" not in f_jcats:
-                values = from_array_interface(f_jcats)
-                pa_values = pa.Array.from_pandas(values)
-                results[fnames[fidx]] = pa_values
-                continue
-
-            joffsets = f_jcats["offsets"]
-            jvalues = f_jcats["values"]
-            offsets = from_array_interface(joffsets, True)
-            values = from_array_interface(jvalues, True)
-            pa_offsets = pa.array(offsets).buffers()
-            pa_values = pa.array(values).buffers()
-            assert (
-                pa_offsets[0] is None and pa_values[0] is None
-            ), "Should not have null mask."
-            pa_dict = pa.StringArray.from_buffers(
-                len(offsets) - 1, pa_offsets[1], pa_values[1]
-            )
-            results[fnames[fidx]] = pa_dict
-
-        return results
+        return _get_categories(
+            lambda ret: _LIB.XGBDMatrixGetCategories(self.handle, ctypes.byref(ret)),
+            self.feature_names,
+            self.num_col(),
+        )
 
     def num_row(self) -> int:
         """Get the number of rows in the DMatrix."""
@@ -2312,6 +2323,23 @@ def feature_names(self) -> Optional[FeatureNames]:
     def feature_names(self, features: Optional[FeatureNames]) -> None:
         self._set_feature_info(features, "feature_name")
 
+    def get_categories(self) -> Optional[Dict[str, "pa.DictionaryArray"]]:
+        """Get the categories in the dataset using `pyarrow`. Returns `None` if there
+        are no categorical features.
+
+        .. warning::
+
+            This function is still a work in progress.
+
+        .. versionadded:: 3.1.0
+
+        """
+        return _get_categories(
+            lambda ret: _LIB.XGBoosterGetCategories(self.handle, ctypes.byref(ret)),
+            self.feature_names,
+            self.num_features(),
+        )
+
     def set_param(
         self,
         params: Union[Dict, Iterable[Tuple[str, Any]], str],

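The `_get_categories` helper above rebuilds each feature's string categories from arrow-style buffers: a flat byte array of concatenated names plus an int32 offset array delimiting them. A self-contained sketch of that reconstruction, using made-up offsets and bytes in place of the arrays returned by `from_array_interface`:

import numpy as np
import pyarrow as pa

# Two strings, "abc" and "de", stored as one flat byte buffer plus offsets.
offsets = np.array([0, 3, 5], dtype=np.int32)
values = np.frombuffer(b"abcde", dtype=np.uint8)

# pa.array(...).buffers() returns [validity_bitmap, data]; the validity bitmap is
# None because there are no nulls, so only the data buffers are reused.
pa_offsets = pa.array(offsets).buffers()
pa_values = pa.array(values).buffers()
assert pa_offsets[0] is None and pa_values[0] is None

names = pa.StringArray.from_buffers(len(offsets) - 1, pa_offsets[1], pa_values[1])
assert names.to_pylist() == ["abc", "de"]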
python-package/xgboost/testing/ordinal.py

Lines changed: 40 additions & 2 deletions

@@ -14,7 +14,12 @@
 from ..core import DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix
 from ..data import _lazy_load_cudf_is_cat
 from ..training import train
-from .data import IteratorForTest, is_pd_cat_dtype, make_categorical
+from .data import (
+    IteratorForTest,
+    is_pd_cat_dtype,
+    make_batches,
+    make_categorical,
+)
 
 
 def get_df_impl(device: str) -> Tuple[Type, Type]:
@@ -50,10 +55,24 @@ def assert_allclose(device: str, a: Any, b: Any) -> None:
         cp.testing.assert_allclose(a, b)
 
 
+def comp_booster(device: Literal["cpu", "cuda"], Xy: DMatrix, booster: str) -> None:
+    """Compare the results from DMatrix and Booster."""
+    cats = Xy.get_categories()
+    assert cats is not None
+
+    rng = np.random.default_rng(2025)
+    Xy.set_label(rng.normal(size=Xy.num_row()))
+    bst = train({"booster": booster, "device": device}, Xy, 1)
+    cats_bst = bst.get_categories()
+    assert cats_bst is not None
+    for k, v in cats_bst.items():
+        assert v == cats[k]
+
+
 def run_cat_container(device: Literal["cpu", "cuda"]) -> None:
     """Basic tests for the container class used by the DMatrix."""
 
-    def run_dispatch(device: str, DMatrixT: Type) -> None:
+    def run_dispatch(device: Literal["cpu", "cuda"], DMatrixT: Type) -> None:
         Df, _ = get_df_impl(device)
         # Basic test with a single feature
         df = Df({"c": ["cdef", "abc"]}, dtype="category")
@@ -86,10 +105,16 @@ def run_dispatch(device: str, DMatrixT: Type) -> None:
         assert_allclose(device, csr.indptr, np.array([0, 1, 1, 2, 3]))
         assert_allclose(device, csr.indices, np.array([0, 0, 0]))
 
+        comp_booster(device, Xy, "gbtree")
+        comp_booster(device, Xy, "dart")
+
         # Test with explicit null-terminated strings.
         df = Df({"c": ["cdef", None, "abc", "abc\0"]}, dtype="category")
         Xy = DMatrixT(df, enable_categorical=True)
 
+        comp_booster(device, Xy, "gbtree")
+        comp_booster(device, Xy, "dart")
+
     for dm in (DMatrix, QuantileDMatrix):
         run_dispatch(device, dm)
 
@@ -129,6 +154,7 @@ def check(Xy: DMatrix, X: pd.DataFrame) -> None:
             assert cats[fname] is None
 
         if not hasattr(Xy, "ref"):  # not quantile DMatrix.
+            assert not isinstance(Xy, QuantileDMatrix)
            with tempfile.TemporaryDirectory() as tmpdir:
                fname = os.path.join(tmpdir, "DMatrix.binary")
                Xy.save_binary(fname)
@@ -144,6 +170,9 @@ def check(Xy: DMatrix, X: pd.DataFrame) -> None:
             else:
                 assert v_0.to_pylist() == v_1.to_pylist()
 
+        comp_booster(device, Xy, "gbtree")
+        comp_booster(device, Xy, "dart")
+
     def run_dispatch(DMatrixT: Type) -> None:
         # full str type
         X, y = make_categorical(
@@ -216,6 +245,15 @@ def run_dispatch(DMatrixT: Type) -> None:
     for dm in (DMatrix, QuantileDMatrix):
         run_dispatch(dm)
 
+    batches = make_batches(
+        n_samples_per_batch=128, n_features=4, n_batches=1, use_cupy=device == "cuda"
+    )
+    X, y, w = map(lambda x: x[0], batches)
+    Xy = DMatrix(X, y, weight=w)
+    assert Xy.get_categories() is None
+    Xy = QuantileDMatrix(X, y, weight=w)
+    assert Xy.get_categories() is None
+
 
 def run_cat_container_iter(device: Literal["cpu", "cuda"]) -> None:
     """Test the categories container for iterator-based inputs."""
