Skip to content

Support categorical features from polars. #11565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 79 additions & 22 deletions python-package/xgboost/_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,58 @@ def npstr_to_arrow_strarr(strarr: np.ndarray) -> Tuple[np.ndarray, str]:
return offsets.astype(np.int32), values


@functools.cache
def _arrow_npdtype() -> Dict[Any, Type[np.number]]:
    """Return the mapping from arrow numeric types to numpy dtypes.

    The table is built lazily on first use (pyarrow is an optional dependency)
    and memoized by :func:`functools.cache`, so the pyarrow type objects are
    constructed only once per process.
    """
    import pyarrow as pa

    # (arrow type, numpy dtype) pairs; signed and unsigned integers of every
    # width, plus the three IEEE floating point widths.
    pairs = [
        (pa.int8(), np.int8),
        (pa.int16(), np.int16),
        (pa.int32(), np.int32),
        (pa.int64(), np.int64),
        (pa.uint8(), np.uint8),
        (pa.uint16(), np.uint16),
        (pa.uint32(), np.uint32),
        (pa.uint64(), np.uint64),
        (pa.float16(), np.float16),
        (pa.float32(), np.float32),
        (pa.float64(), np.float64),
    ]
    return dict(pairs)


def _arrow_mask_inf(mask: Optional["pa.Buffer"], size: int) -> Optional[ArrayInf]:
    """Make an array interface for the validity mask of an arrow array.

    Parameters
    ----------
    mask :
        The arrow validity bitmap buffer, or None when all values are valid.
    size :
        Number of elements in the array (not the buffer size in bytes).

    Returns
    -------
    The array interface for the mask, or None when there is no mask.
    """
    if mask is None:
        return None
    # The validity mask is a bit-packed boolean buffer ("<t1" typestr).  Reuse
    # the generic buffer helper instead of duplicating the dict construction,
    # so mask and data interfaces cannot drift apart.
    return _arrow_buf_inf(mask, "<t1", size)


def _arrow_buf_inf(buf: "pa.Buffer", typestr: str, size: int) -> ArrayInf:
jdata: ArrayInf = {
"data": (buf.address, True),
"typestr": typestr,
"version": 3,
"strides": None,
"shape": (size,),
"mask": None,
}
if not buf.is_cpu:
jdata["stream"] = STREAM_PER_THREAD # type: ignore
return jdata


def _arrow_cat_inf( # pylint: disable=too-many-locals
cats: "pa.StringArray",
codes: Union[_ArrayLikeArg, _CudaArrayLikeArg, "pa.IntegerArray"],
Expand All @@ -254,29 +306,24 @@ def _arrow_cat_inf( # pylint: disable=too-many-locals
assert offset.is_cpu

off_len = len(cats) + 1
if offset.size != off_len * (np.iinfo(np.int32).bits / 8):
raise TypeError("Arrow dictionary type offsets is required to be 32 bit.")

joffset: ArrayInf = {
"data": (offset.address, True),
"typestr": "<i4",
"version": 3,
"strides": None,
"shape": (off_len,),
"mask": None,
}
def get_n_bytes(typ: Type) -> int:
return off_len * (np.iinfo(typ).bits // 8)

def make_buf_inf(buf: pa.Buffer, typestr: str) -> ArrayInf:
return {
"data": (buf.address, True),
"typestr": typestr,
"version": 3,
"strides": None,
"shape": (buf.size,),
"mask": None,
}
if offset.size == get_n_bytes(np.int64):
# Convert to 32bit integer, arrow recommends against the use of i64. Also,
# XGBoost cannot handle large number of categories (> 2**31).
assert isinstance(cats, pa.LargeStringArray), type(cats)
i32cats = pa.Array.from_pandas(cats.to_numpy(zero_copy_only=False))
mask, offset, data = i32cats.buffers()

if offset.size != get_n_bytes(np.int32):
raise TypeError(
"Arrow dictionary type offsets is required to be 32-bit integer."
)

jdata = make_buf_inf(data, "<i1")
joffset = _arrow_buf_inf(offset, "<i4", off_len)
jdata = _arrow_buf_inf(data, "|i1", data.size)
# Categories should not have missing values.
assert mask is None

Expand All @@ -290,9 +337,19 @@ def make_array_inf(
if hasattr(array, "__cuda_array_interface__"):
inf = cuda_array_interface_dict(array)
return inf, None
if isinstance(array, pa.Array):
mask, data = array.buffers()
jdata = make_array_interface(
data.address,
shape=(len(array),),
dtype=_arrow_npdtype()[array.type],
is_cuda=not data.is_cpu,
)
jdata["mask"] = _arrow_mask_inf(mask, len(array))
return jdata, None

# Other types (like arrow itself) are not yet supported.
raise TypeError("Invalid input type.")
# Other types are not yet supported.
raise TypeError(f"Invalid input type: {type(array)}")

cats_tmp = (mask, offset, data)
jcodes, codes_tmp = make_array_inf(codes)
Expand Down
118 changes: 54 additions & 64 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Optional,
Sequence,
Tuple,
Type,
TypeAlias,
TypeGuard,
Union,
)
Expand All @@ -27,6 +27,9 @@
DfCatAccessor,
StringArray,
TransformedDf,
_arrow_cat_inf,
_arrow_mask_inf,
_arrow_npdtype,
_ensure_np_dtype,
_is_df_cat,
array_hasobject,
Expand Down Expand Up @@ -768,25 +771,16 @@ def _from_pandas_series(
)


@functools.cache
def _arrow_npdtype() -> Dict[Any, Type[np.number]]:
import pyarrow as pa

mapping: Dict[Any, Type[np.number]] = {
pa.int8(): np.int8,
pa.int16(): np.int16,
pa.int32(): np.int32,
pa.int64(): np.int64,
pa.uint8(): np.uint8,
pa.uint16(): np.uint16,
pa.uint32(): np.uint32,
pa.uint64(): np.uint64,
pa.float16(): np.float16,
pa.float32(): np.float32,
pa.float64(): np.float64,
}

return mapping
# Type for storing JSON-encoded array interface
AifType: TypeAlias = List[
Union[
ArrayInf, # numeric column
Tuple[ # categorical column
Union[ArrayInf, StringArray], # string index, numeric index
ArrayInf, # codes
],
]
]


class ArrowTransformed(TransformedDf):
Expand All @@ -797,45 +791,49 @@ def __init__(
) -> None:
self.columns = columns

def array_interface(self) -> bytes:
"""Return a byte string for JSON encoded array interface."""
self.temporary_buffers: List[Tuple] = []

if TYPE_CHECKING:
import pyarrow as pa
else:
pa = import_pyarrow()

def array_inf(col: Union["pa.NumericArray", "pa.DictionaryArray"]) -> ArrayInf:
buffers = col.buffers()
aitfs: AifType = []

def push_series(col: Union["pa.NumericArray", "pa.DictionaryArray"]) -> None:
if isinstance(col, pa.DictionaryArray):
mask, _, data = col.buffers()
cats = col.dictionary
codes = col.indices
if not isinstance(cats, (pa.StringArray, pa.LargeStringArray)):
raise TypeError(
"Only string-based categorical index is supported for arrow."
)
jnames, jcodes, buf = _arrow_cat_inf(cats, codes)
self.temporary_buffers.append(buf)
aitfs.append((jnames, jcodes))
else:
mask, data = buffers
mask, data = col.buffers()

assert data.is_cpu
assert col.offset == 0
assert data.is_cpu
assert col.offset == 0

jdata = make_array_interface(
data.address,
shape=(len(col),),
dtype=_arrow_npdtype()[col.type],
is_cuda=not data.is_cpu,
)
if mask is not None:
jmask: ArrayInf = {
"data": (mask.address, True),
"typestr": "<t1",
"version": 3,
"strides": None,
"shape": (len(col),),
"mask": None,
}
if not data.is_cpu:
jmask["stream"] = 2 # type: ignore
jdata["mask"] = jmask
return jdata

arrays = list(map(array_inf, self.columns))
sarrays = bytes(json.dumps(arrays), "utf-8")
jdata = make_array_interface(
data.address,
shape=(len(col),),
dtype=_arrow_npdtype()[col.type],
is_cuda=not data.is_cpu,
)
jdata["mask"] = _arrow_mask_inf(mask, len(col))
aitfs.append(jdata)

for col in self.columns:
push_series(col)

self.aitfs = aitfs

def array_interface(self) -> bytes:
"""Return a byte string for JSON encoded array interface."""
sarrays = bytes(json.dumps(self.aitfs), "utf-8")
return sarrays

@property
Expand All @@ -850,7 +848,7 @@ def _is_arrow(data: DataType) -> bool:

def _transform_arrow_table(
data: "pa.Table",
_: bool, # not used yet, enable_categorical
enable_categorical: bool,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
) -> Tuple[ArrowTransformed, Optional[FeatureNames], Optional[FeatureTypes]]:
Expand All @@ -872,6 +870,10 @@ def _transform_arrow_table(
col: Union["pa.NumericArray", "pa.DictionaryArray"] = col0.combine_chunks()
if isinstance(col, pa.BooleanArray):
col = col.cast(pa.int8()) # bit-compressed array, not supported.
if is_arrow_dict(col) and not enable_categorical:
# None because the function doesn't know how to get the type info from arrow
# table.
_invalid_dataframe_dtype(None)
columns.append(col)

df_t = ArrowTransformed(columns)
Expand Down Expand Up @@ -937,10 +939,6 @@ def _arrow_feature_info(data: DataType) -> Tuple[List[str], List]:
def map_type(name: str) -> str:
col = table.column(name)
if isinstance(col.type, pa.DictionaryType):
raise NotImplementedError(
"Categorical feature is not yet supported with the current input data "
"type."
)
return CAT_T # pylint: disable=unreachable

return _arrow_dtype()[col.type]
Expand Down Expand Up @@ -1073,15 +1071,7 @@ def __init__(self, columns: List[Union["PdSeries", DfCatAccessor]]) -> None:
# the DMatrix or the booster.
self.temporary_buffers: List[Tuple] = []

aitfs: List[
Union[
ArrayInf, # numeric column
Tuple[ # categorical column
Union[ArrayInf, StringArray], # string index, numeric index
ArrayInf, # codes
],
]
] = []
aitfs: AifType = []

def push_series(ser: Any) -> None:
if _is_df_cat(ser):
Expand Down
2 changes: 1 addition & 1 deletion src/data/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ template <bool allow_mask, typename CategoricalIndex>
auto const& jstr = get<Object const>(jnames.at("values"));
auto strbuf = ArrayInterface<1>(jstr);
CHECK_EQ(strbuf.type, ArrayInterfaceHandler::kI1);

CHECK_EQ(offset.type, ArrayInterfaceHandler::kI4);
auto names = enc::CatStrArrayView{
common::Span{static_cast<std::int32_t const*>(offset.data), offset.Shape<0>()},
common::Span<std::int8_t const>{reinterpret_cast<std::int8_t const*>(strbuf.data), strbuf.n}};
Expand Down
42 changes: 38 additions & 4 deletions tests/python/test_with_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,43 @@ def test_regressor() -> None:
def test_categorical() -> None:
import polars as pl

cats = ["aa", "cc", "bb", "ee", "ee"]
df = pl.DataFrame(
{"f0": [1, 2, 3], "b": ["a", "b", "c"]},
schema=[("a", pl.Int64()), ("b", pl.Categorical())],
{"f0": [1, 3, 2, 4, 4], "f1": cats},
schema=[("f0", pl.Int64()), ("f1", pl.Categorical(ordering="lexical"))],
)
with pytest.raises(NotImplementedError, match="Categorical feature"):
xgb.DMatrix(df, enable_categorical=True)
with pytest.raises(ValueError, match="enable_categorical"):
xgb.DMatrix(df)

data = xgb.DMatrix(df, enable_categorical=True)
categories = data.get_categories()
assert categories is not None
assert categories["f0"] is None
assert categories["f1"].to_pylist() == cats[:4]

df = pl.DataFrame(
{"f0": [1, 3, 2, 4, 4], "f1": cats},
schema=[("f0", pl.Int64()), ("f1", pl.Enum(cats[:4]))],
)
data = xgb.DMatrix(df, enable_categorical=True)
categories = data.get_categories()
assert categories is not None
assert categories["f0"] is None
assert categories["f1"].to_pylist() == cats[:4]

rng = np.random.default_rng(2025)
y = rng.normal(size=(df.shape[0]))
Xy = xgb.QuantileDMatrix(df, y, enable_categorical=True)
booster = xgb.train({}, Xy, num_boost_round=8)
predt_0 = booster.inplace_predict(df)

df_rev = pl.DataFrame(
{"f0": [1, 3, 2, 4, 4], "f1": cats},
schema=[("f0", pl.Int64()), ("f1", pl.Enum(cats[:4][::-1]))],
)
predt_1 = booster.inplace_predict(df_rev)
assert (
df["f1"].cat.get_categories().to_list()
!= df_rev["f1"].cat.get_categories().to_list()
)
np.testing.assert_allclose(predt_0, predt_1)
Loading