-
Notifications
You must be signed in to change notification settings - Fork 1k
Remove IntervalDtype/StructDtype inheritance #21114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 18 commits
0d54326
0357100
ebdc5ea
243ae71
c76e792
f3840c6
7c8fe07
35fd074
017623b
c8e20b4
e47a7aa
dfd6a8c
9261a64
add9078
34ed654
e87fd75
3c2b2aa
5786fb2
6571d76
8da1713
a17ad0b
b39b174
825d8f4
fec7b1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,15 +13,18 @@ | |
|
|
||
| import cudf | ||
| from cudf.core.column.column import ColumnBase, _handle_nulls, as_column | ||
| from cudf.core.column.struct import StructColumn | ||
| from cudf.core.dtypes import IntervalDtype, _dtype_to_metadata | ||
| from cudf.utils.dtypes import is_dtype_obj_interval | ||
| from cudf.utils.scalar import maybe_nested_pa_scalar_to_py | ||
|
|
||
| if TYPE_CHECKING: | ||
| from cudf._typing import DtypeObj | ||
| from cudf.core.buffer import Buffer | ||
|
|
||
|
|
||
| class IntervalColumn(StructColumn): | ||
| class IntervalColumn(ColumnBase): | ||
| _VALID_PLC_TYPES = {plc.TypeId.STRUCT} | ||
|
|
||
| @classmethod | ||
| def _validate_args( # type: ignore[override] | ||
| cls, plc_column: plc.Column, dtype: IntervalDtype | ||
|
|
@@ -48,6 +51,39 @@ def _validate_args( # type: ignore[override] | |
| raise ValueError("dtype must be a IntervalDtype.") | ||
| return plc_column, dtype | ||
|
|
||
| def _with_type_metadata(self, dtype: DtypeObj) -> ColumnBase: | ||
| """ | ||
| Apply IntervalDtype metadata to this column. | ||
|
|
||
| Creates new children with the subtype metadata applied and | ||
| reconstructs the plc.Column. | ||
| """ | ||
| if isinstance(dtype, IntervalDtype): | ||
| new_children = tuple( | ||
| ColumnBase.from_pylibcudf(child).astype(dtype.subtype) | ||
| for child in self.plc_column.children() | ||
| ) | ||
| new_plc_column = plc.Column( | ||
| plc.DataType(plc.TypeId.STRUCT), | ||
| self.plc_column.size(), | ||
| self.plc_column.data(), | ||
| self.plc_column.null_mask(), | ||
| self.plc_column.null_count(), | ||
| self.plc_column.offset(), | ||
| [child.plc_column for child in new_children], | ||
| ) | ||
| return type(self)._from_preprocessed( | ||
| plc_column=new_plc_column, | ||
| dtype=dtype, | ||
| ) | ||
| # For pandas dtypes, store them directly in the column's dtype property | ||
| elif isinstance(dtype, pd.ArrowDtype) and isinstance( | ||
| dtype.pyarrow_dtype, pa.lib.StructType | ||
|
||
| ): | ||
| self._dtype = dtype | ||
|
|
||
| return self | ||
|
|
||
| @classmethod | ||
| def from_arrow(cls, array: pa.Array | pa.ChunkedArray) -> Self: | ||
| if not isinstance(array, pa.ExtensionArray): | ||
|
|
@@ -76,6 +112,36 @@ def to_arrow(self) -> pa.Array: | |
| struct_arrow = pa.array([], typ.storage_type) | ||
| return pa.ExtensionArray.from_storage(typ, struct_arrow) | ||
|
|
||
| @classmethod | ||
| def _deserialize_plc_column( | ||
| cls, | ||
| header: dict, | ||
| dtype: DtypeObj, | ||
| data: Buffer | None, | ||
| mask: Buffer | None, | ||
| children: list[plc.Column], | ||
| ) -> plc.Column: | ||
| """Construct plc.Column using STRUCT type for interval columns.""" | ||
| offset = header.get("offset", 0) | ||
| if mask is None: | ||
| null_count = 0 | ||
| else: | ||
| null_count = plc.null_mask.null_count( | ||
| mask, offset, header["size"] + offset | ||
| ) | ||
|
|
||
| plc_type = plc.DataType(plc.TypeId.STRUCT) | ||
| return plc.Column( | ||
| plc_type, | ||
| header["size"], | ||
| data, | ||
| mask, | ||
| null_count, | ||
| offset, | ||
| children, | ||
| validate=False, | ||
| ) | ||
|
|
||
| def copy(self, deep: bool = True) -> Self: | ||
| return super().copy(deep=deep)._with_type_metadata(self.dtype) # type: ignore[return-value] | ||
|
|
||
|
|
@@ -134,6 +200,12 @@ def right(self) -> ColumnBase: | |
| self.plc_column.children()[1] | ||
| )._with_type_metadata(self.dtype.subtype) # type: ignore[union-attr] | ||
|
|
||
| @property | ||
| def __cuda_array_interface__(self) -> dict[str, Any]: | ||
| raise NotImplementedError( | ||
| "Intervals are not yet supported via `__cuda_array_interface__`" | ||
| ) | ||
|
|
||
| def overlaps(other) -> ColumnBase: | ||
| raise NotImplementedError("overlaps is not currently implemented.") | ||
|
|
||
|
|
@@ -176,6 +248,9 @@ def element_indexing( | |
| self, index: int | ||
| ) -> pd.Interval | dict[Any, Any] | None: | ||
| result = super().element_indexing(index) | ||
| if isinstance(result, pa.Scalar): | ||
| py_element = maybe_nested_pa_scalar_to_py(result) | ||
| result = self.dtype._recursively_replace_fields(py_element) # type: ignore[union-attr] | ||
| if isinstance(result, dict) and cudf.get_option( | ||
| "mode.pandas_compatible" | ||
| ): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,17 +53,14 @@ def _validate_args( # type: ignore[override] | |
| cls, plc_column: plc.Column, dtype: StructDtype | ||
| ) -> tuple[plc.Column, StructDtype]: | ||
| plc_column, dtype = super()._validate_args(plc_column, dtype) # type: ignore[assignment] | ||
| # IntervalDtype is a subclass of StructDtype, so compare types exactly | ||
| if ( | ||
| not cudf.get_option("mode.pandas_compatible") | ||
| and type(dtype) is not StructDtype | ||
| and not isinstance(dtype, StructDtype) | ||
| ) or ( | ||
| cudf.get_option("mode.pandas_compatible") | ||
| and not is_dtype_obj_struct(dtype) | ||
| ): | ||
| raise ValueError( | ||
| f"{type(dtype).__name__} must be a StructDtype exactly." | ||
| ) | ||
| raise ValueError(f"{type(dtype).__name__} must be a StructDtype.") | ||
| return plc_column, dtype | ||
|
|
||
| def _get_sliced_child(self, idx: int) -> ColumnBase: | ||
|
|
@@ -148,15 +145,15 @@ def __cuda_array_interface__(self) -> Mapping[str, Any]: | |
| "Structs are not yet supported via `__cuda_array_interface__`" | ||
| ) | ||
|
|
||
| def _with_type_metadata( | ||
| self: StructColumn, dtype: DtypeObj | ||
| ) -> StructColumn: | ||
| from cudf.core.column import IntervalColumn | ||
| def _with_type_metadata(self: StructColumn, dtype: DtypeObj) -> ColumnBase: | ||
| from cudf.core.dtypes import IntervalDtype | ||
|
|
||
| # Check IntervalDtype first because it's a subclass of StructDtype | ||
| if isinstance(dtype, IntervalDtype): | ||
| # TODO: Rewrite this to avoid needing to round-trip via ColumnBase | ||
| # Dispatch to IntervalColumn when given IntervalDtype | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could this entire branch just call |
||
| from cudf.core.column.interval import IntervalColumn | ||
|
|
||
| # Apply subtype metadata to children and reconstruct as IntervalColumn | ||
| new_children = tuple( | ||
| ColumnBase.from_pylibcudf(child).astype(dtype.subtype) | ||
| for child in self.plc_column.children() | ||
|
|
@@ -175,7 +172,29 @@ def _with_type_metadata( | |
| dtype=dtype, | ||
| ) | ||
| elif isinstance(dtype, StructDtype): | ||
| self._dtype = dtype | ||
| new_children = tuple( | ||
| ColumnBase.from_pylibcudf(child)._with_type_metadata( | ||
| dtype.fields[f] | ||
| ) | ||
| for child, f in zip( | ||
| self.plc_column.children(), | ||
| dtype.fields.keys(), | ||
| strict=True, | ||
| ) | ||
| ) | ||
| new_plc_column = plc.Column( | ||
| plc.DataType(plc.TypeId.STRUCT), | ||
| self.plc_column.size(), | ||
| self.plc_column.data(), | ||
| self.plc_column.null_mask(), | ||
| self.plc_column.null_count(), | ||
| self.plc_column.offset(), | ||
| [child.plc_column for child in new_children], | ||
| ) | ||
| return StructColumn._from_preprocessed( | ||
| plc_column=new_plc_column, | ||
| dtype=dtype, | ||
| ) | ||
| # For pandas dtypes, store them directly in the column's dtype property | ||
| elif isinstance(dtype, pd.ArrowDtype) and isinstance( | ||
| dtype.pyarrow_dtype, pa.StructType | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
`cudf.dtype` will be called in the `IntervalDtype` constructor as it would be nice to have fewer places use `cudf.dtype`