Skip to content

Commit 5802d34

Browse files
authored
Correctly accept a pandas.CategoricalDtype(pandas.IntervalDtype(...), ...) type (#17604)
From an offline discussion, a pandas object with a `category[interval[...]]` type would incorrectly be interpreted as a `category[struct[...]]` type. This can cause further problems with `cudf.pandas`, as a `category[struct[...]]` type cannot be properly interpreted by pandas. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Bradley Dice (https://github.com/bdice) URL: #17604
1 parent e9e34e6 commit 5802d34

File tree

3 files changed

+46
-23
lines changed

3 files changed

+46
-23
lines changed

python/cudf/cudf/core/column/categorical.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -1095,17 +1095,22 @@ def as_categorical_column(self, dtype: Dtype) -> Self:
10951095
raise ValueError("dtype must be CategoricalDtype")
10961096

10971097
if not isinstance(self.categories, type(dtype.categories._column)):
1098-
# If both categories are of different Column types,
1099-
# return a column full of Nulls.
1100-
codes = cast(
1101-
cudf.core.column.numerical.NumericalColumn,
1102-
column.as_column(
1103-
_DEFAULT_CATEGORICAL_VALUE,
1104-
length=self.size,
1105-
dtype=self.codes.dtype,
1106-
),
1107-
)
1108-
codes = as_unsigned_codes(len(dtype.categories), codes)
1098+
if isinstance(
1099+
self.categories.dtype, cudf.StructDtype
1100+
) and isinstance(dtype.categories.dtype, cudf.IntervalDtype):
1101+
codes = self.codes
1102+
else:
1103+
# Otherwise if both categories are of different Column types,
1104+
# return a column full of nulls.
1105+
codes = cast(
1106+
cudf.core.column.numerical.NumericalColumn,
1107+
column.as_column(
1108+
_DEFAULT_CATEGORICAL_VALUE,
1109+
length=self.size,
1110+
dtype=self.codes.dtype,
1111+
),
1112+
)
1113+
codes = as_unsigned_codes(len(dtype.categories), codes)
11091114
return type(self)(
11101115
data=self.data, # type: ignore[arg-type]
11111116
size=self.size,

python/cudf/cudf/core/column/column.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -2076,18 +2076,26 @@ def as_column(
20762076
if isinstance(arbitrary.dtype, pd.DatetimeTZDtype):
20772077
new_tz = get_compatible_timezone(arbitrary.dtype)
20782078
arbitrary = arbitrary.astype(new_tz)
2079-
if isinstance(arbitrary.dtype, pd.CategoricalDtype) and isinstance(
2080-
arbitrary.dtype.categories.dtype, pd.DatetimeTZDtype
2081-
):
2082-
new_tz = get_compatible_timezone(
2083-
arbitrary.dtype.categories.dtype
2084-
)
2085-
new_cats = arbitrary.dtype.categories.astype(new_tz)
2086-
new_dtype = pd.CategoricalDtype(
2087-
categories=new_cats, ordered=arbitrary.dtype.ordered
2088-
)
2089-
arbitrary = arbitrary.astype(new_dtype)
2090-
2079+
if isinstance(arbitrary.dtype, pd.CategoricalDtype):
2080+
if isinstance(
2081+
arbitrary.dtype.categories.dtype, pd.DatetimeTZDtype
2082+
):
2083+
new_tz = get_compatible_timezone(
2084+
arbitrary.dtype.categories.dtype
2085+
)
2086+
new_cats = arbitrary.dtype.categories.astype(new_tz)
2087+
new_dtype = pd.CategoricalDtype(
2088+
categories=new_cats, ordered=arbitrary.dtype.ordered
2089+
)
2090+
arbitrary = arbitrary.astype(new_dtype)
2091+
elif (
2092+
isinstance(
2093+
arbitrary.dtype.categories.dtype, pd.IntervalDtype
2094+
)
2095+
and dtype is None
2096+
):
2097+
# Conversion to arrow converts IntervalDtype to StructDtype
2098+
dtype = cudf.CategoricalDtype.from_pandas(arbitrary.dtype)
20912099
return as_column(
20922100
pa.array(arbitrary, from_pandas=True),
20932101
nan_as_null=nan_as_null,

python/cudf/cudf/tests/test_categorical.py

+10
Original file line numberDiff line numberDiff line change
@@ -950,3 +950,13 @@ def test_index_set_categories(ordered):
950950
expected = pd_ci.set_categories([1, 2, 3, 4], ordered=ordered)
951951
result = cudf_ci.set_categories([1, 2, 3, 4], ordered=ordered)
952952
assert_eq(result, expected)
953+
954+
955+
def test_categorical_interval_pandas_roundtrip():
956+
expected = cudf.Series(cudf.interval_range(0, 5)).astype("category")
957+
result = cudf.Series.from_pandas(expected.to_pandas())
958+
assert_eq(result, expected)
959+
960+
expected = pd.Series(pd.interval_range(0, 5)).astype("category")
961+
result = cudf.Series.from_pandas(expected).to_pandas()
962+
assert_eq(result, expected)

0 commit comments

Comments
 (0)