Skip to content

Commit 8c2a45e

Browse files
authored
Remove IntervalDtype/StructDtype inheritance (#21114)
Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: #21114
1 parent eb3a55a commit 8c2a45e

File tree

8 files changed

+228
-39
lines changed

8 files changed

+228
-39
lines changed

python/cudf/cudf/core/column/column.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2646,9 +2646,9 @@ def column_empty(
26462646
dtype : Dtype
26472647
Type of the column.
26482648
"""
2649-
if (is_struct := isinstance(dtype, StructDtype)) or isinstance(
2650-
dtype, ListDtype
2651-
):
2649+
if (
2650+
is_struct := isinstance(dtype, (StructDtype, IntervalDtype))
2651+
) or isinstance(dtype, ListDtype):
26522652
if is_struct:
26532653
children = tuple(
26542654
column_empty(row_count, field_dtype)
@@ -2844,7 +2844,15 @@ def as_column(
28442844
dtype=dtype,
28452845
length=length,
28462846
)
2847-
if (
2847+
if isinstance(arbitrary.dtype, pd.IntervalDtype):
2848+
# Wrap StructColumn as IntervalColumn with proper metadata
2849+
result = result._with_type_metadata(
2850+
IntervalDtype(
2851+
subtype=arbitrary.dtype.subtype,
2852+
closed=arbitrary.dtype.closed,
2853+
)
2854+
)
2855+
elif (
28482856
cudf.get_option("mode.pandas_compatible")
28492857
and isinstance(arbitrary.dtype, pd.CategoricalDtype)
28502858
and is_pandas_nullable_extension_dtype(
@@ -3266,12 +3274,14 @@ def as_column(
32663274
length=length,
32673275
)
32683276
elif (
3269-
isinstance(element, (pd.Timestamp, pd.Timedelta))
3277+
isinstance(element, (pd.Timestamp, pd.Timedelta, pd.Interval))
32703278
or element is pd.NaT
32713279
):
32723280
# TODO: Remove this after
32733281
# https://github.com/apache/arrow/issues/26492
32743282
# is fixed.
3283+
# Note: pd.Interval also requires pandas Series conversion
3284+
# because PyArrow cannot infer interval type from raw list
32753285
return as_column(
32763286
pd.Series(arbitrary),
32773287
dtype=dtype,

python/cudf/cudf/core/column/interval.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@
77

88
import pandas as pd
99
import pyarrow as pa
10+
from pandas.core.arrays.arrow.extension_types import ArrowIntervalType
1011
from typing_extensions import Self
1112

1213
import pylibcudf as plc
1314

1415
import cudf
1516
from cudf.core.column.column import ColumnBase, _handle_nulls, as_column
16-
from cudf.core.column.struct import StructColumn
1717
from cudf.core.dtypes import IntervalDtype, _dtype_to_metadata
1818
from cudf.utils.dtypes import is_dtype_obj_interval
19+
from cudf.utils.scalar import maybe_nested_pa_scalar_to_py
1920

2021
if TYPE_CHECKING:
2122
from cudf._typing import DtypeObj
23+
from cudf.core.buffer import Buffer
2224

2325

24-
class IntervalColumn(StructColumn):
26+
class IntervalColumn(ColumnBase):
27+
_VALID_PLC_TYPES = {plc.TypeId.STRUCT}
28+
2529
@classmethod
2630
def _validate_args( # type: ignore[override]
2731
cls, plc_column: plc.Column, dtype: IntervalDtype
@@ -48,6 +52,39 @@ def _validate_args( # type: ignore[override]
4852
raise ValueError("dtype must be a IntervalDtype.")
4953
return plc_column, dtype
5054

55+
def _with_type_metadata(self, dtype: DtypeObj) -> ColumnBase:
56+
"""
57+
Apply IntervalDtype metadata to this column.
58+
59+
Creates new children with the subtype metadata applied and
60+
reconstructs the plc.Column.
61+
"""
62+
if isinstance(dtype, IntervalDtype):
63+
new_children = tuple(
64+
ColumnBase.from_pylibcudf(child).astype(dtype.subtype)
65+
for child in self.plc_column.children()
66+
)
67+
new_plc_column = plc.Column(
68+
plc.DataType(plc.TypeId.STRUCT),
69+
self.plc_column.size(),
70+
self.plc_column.data(),
71+
self.plc_column.null_mask(),
72+
self.plc_column.null_count(),
73+
self.plc_column.offset(),
74+
[child.plc_column for child in new_children],
75+
)
76+
return type(self)._from_preprocessed(
77+
plc_column=new_plc_column,
78+
dtype=dtype,
79+
)
80+
# For pandas dtypes, store them directly in the column's dtype property
81+
elif isinstance(dtype, pd.ArrowDtype) and isinstance(
82+
dtype.pyarrow_dtype, ArrowIntervalType
83+
):
84+
self._dtype = dtype
85+
86+
return self
87+
5188
@classmethod
5289
def from_arrow(cls, array: pa.Array | pa.ChunkedArray) -> Self:
5390
if not isinstance(array, pa.ExtensionArray):
@@ -76,6 +113,36 @@ def to_arrow(self) -> pa.Array:
76113
struct_arrow = pa.array([], typ.storage_type)
77114
return pa.ExtensionArray.from_storage(typ, struct_arrow)
78115

116+
@classmethod
117+
def _deserialize_plc_column(
118+
cls,
119+
header: dict,
120+
dtype: DtypeObj,
121+
data: Buffer | None,
122+
mask: Buffer | None,
123+
children: list[plc.Column],
124+
) -> plc.Column:
125+
"""Construct plc.Column using STRUCT type for interval columns."""
126+
offset = header.get("offset", 0)
127+
if mask is None:
128+
null_count = 0
129+
else:
130+
null_count = plc.null_mask.null_count(
131+
mask, offset, header["size"] + offset
132+
)
133+
134+
plc_type = plc.DataType(plc.TypeId.STRUCT)
135+
return plc.Column(
136+
plc_type,
137+
header["size"],
138+
data,
139+
mask,
140+
null_count,
141+
offset,
142+
children,
143+
validate=False,
144+
)
145+
79146
def copy(self, deep: bool = True) -> Self:
80147
return super().copy(deep=deep)._with_type_metadata(self.dtype) # type: ignore[return-value]
81148

@@ -134,6 +201,12 @@ def right(self) -> ColumnBase:
134201
self.plc_column.children()[1]
135202
)._with_type_metadata(self.dtype.subtype) # type: ignore[union-attr]
136203

204+
@property
205+
def __cuda_array_interface__(self) -> dict[str, Any]:
206+
raise NotImplementedError(
207+
"Intervals are not yet supported via `__cuda_array_interface__`"
208+
)
209+
137210
def overlaps(other) -> ColumnBase:
138211
raise NotImplementedError("overlaps is not currently implemented.")
139212

@@ -176,6 +249,9 @@ def element_indexing(
176249
self, index: int
177250
) -> pd.Interval | dict[Any, Any] | None:
178251
result = super().element_indexing(index)
252+
if isinstance(result, pa.Scalar):
253+
py_element = maybe_nested_pa_scalar_to_py(result)
254+
result = self.dtype._recursively_replace_fields(py_element) # type: ignore[union-attr]
179255
if isinstance(result, dict) and cudf.get_option(
180256
"mode.pandas_compatible"
181257
):

python/cudf/cudf/core/column/struct.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,14 @@ def _validate_args( # type: ignore[override]
5353
cls, plc_column: plc.Column, dtype: StructDtype
5454
) -> tuple[plc.Column, StructDtype]:
5555
plc_column, dtype = super()._validate_args(plc_column, dtype) # type: ignore[assignment]
56-
# IntervalDtype is a subclass of StructDtype, so compare types exactly
5756
if (
5857
not cudf.get_option("mode.pandas_compatible")
59-
and type(dtype) is not StructDtype
58+
and not isinstance(dtype, StructDtype)
6059
) or (
6160
cudf.get_option("mode.pandas_compatible")
6261
and not is_dtype_obj_struct(dtype)
6362
):
64-
raise ValueError(
65-
f"{type(dtype).__name__} must be a StructDtype exactly."
66-
)
63+
raise ValueError(f"{type(dtype).__name__} must be a StructDtype.")
6764
return plc_column, dtype
6865

6966
def _get_sliced_child(self, idx: int) -> ColumnBase:
@@ -148,18 +145,38 @@ def __cuda_array_interface__(self) -> Mapping[str, Any]:
148145
"Structs are not yet supported via `__cuda_array_interface__`"
149146
)
150147

151-
def _with_type_metadata(
152-
self: StructColumn, dtype: DtypeObj
153-
) -> StructColumn:
154-
from cudf.core.column import IntervalColumn
148+
def _with_type_metadata(self: StructColumn, dtype: DtypeObj) -> ColumnBase:
155149
from cudf.core.dtypes import IntervalDtype
156150

157151
# Check IntervalDtype first so interval dtypes are dispatched to IntervalColumn
158152
if isinstance(dtype, IntervalDtype):
159-
# TODO: Rewrite this to avoid needing to round-trip via ColumnBase
153+
# Dispatch to IntervalColumn when given IntervalDtype
154+
from cudf.core.column.interval import IntervalColumn
155+
156+
# Determine the current subtype from the first child
157+
first_child = ColumnBase.from_pylibcudf(
158+
self.plc_column.children()[0]
159+
)
160+
current_dtype = IntervalDtype(
161+
subtype=first_child.dtype, closed=dtype.closed
162+
)
163+
164+
# Convert to IntervalColumn and apply target metadata
165+
interval_col = IntervalColumn._from_preprocessed(
166+
plc_column=self.plc_column,
167+
dtype=current_dtype,
168+
)
169+
return interval_col._with_type_metadata(dtype)
170+
elif isinstance(dtype, StructDtype):
160171
new_children = tuple(
161-
ColumnBase.from_pylibcudf(child).astype(dtype.subtype)
162-
for child in self.plc_column.children()
172+
ColumnBase.from_pylibcudf(child)._with_type_metadata(
173+
dtype.fields[f]
174+
)
175+
for child, f in zip(
176+
self.plc_column.children(),
177+
dtype.fields.keys(),
178+
strict=True,
179+
)
163180
)
164181
new_plc_column = plc.Column(
165182
plc.DataType(plc.TypeId.STRUCT),
@@ -170,12 +187,10 @@ def _with_type_metadata(
170187
self.plc_column.offset(),
171188
[child.plc_column for child in new_children],
172189
)
173-
return IntervalColumn._from_preprocessed(
190+
return StructColumn._from_preprocessed(
174191
plc_column=new_plc_column,
175192
dtype=dtype,
176193
)
177-
elif isinstance(dtype, StructDtype):
178-
self._dtype = dtype
179194
# For pandas dtypes, store them directly in the column's dtype property
180195
elif isinstance(dtype, pd.ArrowDtype) and isinstance(
181196
dtype.pyarrow_dtype, pa.StructType

0 commit comments

Comments
 (0)