Skip to content
2 changes: 2 additions & 0 deletions docs/cudf/source/user_guide/api_docs/index_objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ Conversion
Index.to_frame
Index.to_pandas
Index.to_dlpack
Index.to_pylibcudf
Index.from_pylibcudf
Index.from_pandas
Index.from_arrow

Expand Down
8 changes: 8 additions & 0 deletions python/cudf/benchmarks/API/bench_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,11 @@ def bench_nsmallest(benchmark, dataframe, num_cols_to_sort, n):
)
def bench_where(benchmark, dataframe, cond, other):
benchmark(dataframe.where, cond, other)


@benchmark_with_object(
cls="dataframe", dtype="float", nulls=False, cols=20, rows=20
)
@pytest.mark.pandas_incompatible
def bench_to_cupy(benchmark, dataframe):
"""Benchmark ``DataFrame.to_cupy`` on the decorator-supplied dataframe."""
# The bound method is passed directly so only the to_cupy call is timed.
benchmark(dataframe.to_cupy)
12 changes: 8 additions & 4 deletions python/cudf/benchmarks/API/bench_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ def bench_series_nsmallest(benchmark, series, n):
benchmark(series.nsmallest, n)


@benchmark_with_object(cls="series", dtype="int")
@benchmark_with_object(cls="series", dtype="int", nulls=False)
def bench_series_cp_asarray(benchmark, series):
series = series.dropna()
benchmark(cupy.asarray, series)


@benchmark_with_object(cls="series", dtype="int")
@benchmark_with_object(cls="series", dtype="int", nulls=False)
@pytest.mark.pandas_incompatible
def bench_to_cupy(benchmark, series):
    """Benchmark ``Series.to_cupy`` on the decorator-supplied series.

    The benchmark times ``to_cupy`` itself so that it measures what its
    name says and does not duplicate ``bench_series_values``, which
    already times ``series.values``.
    """
    # Pass the bound method directly so only the to_cupy call is timed.
    benchmark(series.to_cupy)


@benchmark_with_object(cls="series", dtype="int", nulls=False)
def bench_series_values(benchmark, series):
series = series.dropna()
benchmark(lambda: series.values)
6 changes: 3 additions & 3 deletions python/cudf/benchmarks/common/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2025, NVIDIA CORPORATION.

"""Module used for global configuration of benchmarks.

Expand Down Expand Up @@ -64,7 +64,7 @@ def pytest_sessionfinish(session, exitstatus):
# Constants used to define benchmarking standards.
if "CUDF_BENCHMARKS_DEBUG_ONLY" in os.environ:
NUM_ROWS = [10, 20]
NUM_COLS = [1, 6]
NUM_COLS = [1, 6, 20]
else:
NUM_ROWS = [100, 10_000, 1_000_000]
NUM_COLS = [1, 6]
NUM_COLS = [1, 6, 20]
54 changes: 54 additions & 0 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,15 @@ def to_array(
matrix[:, i] = to_array(col, dtype)
return matrix

@_performance_tracking
def to_pylibcudf(self) -> tuple[plc.Table, dict[str, Any]]:
    """
    Convert this Frame to a pylibcudf.Table (base-class placeholder).

    This method is not meant to be called directly on a Frame object;
    call it on subclasses such as DataFrame/Series instead.

    Raises
    ------
    NotImplementedError
        Always; the message names the type that must provide
        ``to_pylibcudf``.
    """
    message = f"{type(self)} must implement to_pylibcudf"
    raise NotImplementedError(message)

@_performance_tracking
def to_cupy(
self,
Expand All @@ -613,6 +622,51 @@ def to_cupy(
-------
cupy.ndarray
"""
if (
self._num_columns > 1
and na_value is None
and self._columns[0].dtype.kind in {"i", "u", "f", "b"}
and all(
not col.nullable and col.dtype == self._columns[0].dtype
for col in self._columns
)
):
if dtype is None:
dtype = self._columns[0].dtype

shape = (len(self), self._num_columns)
out = cupy.empty(shape, dtype=dtype, order="F")

table = plc.Table(
[col.to_pylibcudf(mode="read") for col in self._columns]
)
plc.reshape.table_to_array(
table,
out.data.ptr,
out.nbytes,
)
return out
elif self._num_columns == 1:
col = self._columns[0]
final_dtype = col.dtype if dtype is None else dtype

if (
not copy
and col.dtype.kind in {"i", "u", "f", "b"}
and cupy.can_cast(col.dtype, final_dtype)
):
if col.has_nulls():
if na_value is not None:
col = col.fillna(na_value)
else:
return self._to_array(
lambda col: col.values,
cupy,
copy,
dtype,
na_value,
)
return cupy.asarray(col, dtype=final_dtype).reshape((-1, 1))
return self._to_array(
lambda col: col.values,
cupy,
Expand Down
68 changes: 68 additions & 0 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,74 @@ def _from_data(cls, data: MutableMapping, name: Any = no_default) -> Self:
def _from_data_like_self(self, data: MutableMapping) -> Self:
return _index_from_data(data, self.name)

@_performance_tracking
def to_pylibcudf(self, copy=False) -> tuple[plc.Column, dict]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs failed saying Index.to_pylibcudf was missing. I went ahead and added both to_plc and from_plc in this PR. They are almost identical to Series.from/to_plc. In a follow up PR, I can centralize these methods since both Index and Series are SingleColumnFrames

"""
Convert this Index to a pylibcudf.Column.

Parameters
----------
copy : bool
Whether or not to generate a new copy of the underlying device data.
Only ``copy=False`` (the default) is currently supported.

Returns
-------
pylibcudf.Column
A new pylibcudf.Column referencing the same data.
dict
Dict of metadata (includes name)

Raises
------
NotImplementedError
If ``copy`` is True.

Notes
-----
User requests to convert to pylibcudf must assume that the
data may be modified afterwards.
"""
if copy:
raise NotImplementedError("copy=True is not supported")
metadata = {"name": self.name}
# The column is requested in "write" mode because, per the Notes above,
# the caller may modify the data afterwards.
return self._column.to_pylibcudf(mode="write"), metadata

@classmethod
@_performance_tracking
def from_pylibcudf(
    cls, col: plc.Column, metadata: dict | None = None
) -> Self:
    """
    Create an Index from a pylibcudf.Column.

    Parameters
    ----------
    col : pylibcudf.Column
        The input Column.
    metadata : dict | None
        The Index metadata. When given, it must be a dict whose only
        key is ``"name"``.

    Returns
    -------
    Index
        A new Index whose column is built from ``col``.

    Raises
    ------
    ValueError
        If ``metadata`` is not None and is not a dict containing
        exactly the key ``"name"``.

    Notes
    -----
    This function will generate an Index which contains a Column
    pointing to the provided pylibcudf Column. It will directly access
    the data and mask buffers of the pylibcudf Column, so the newly created
    object is not tied to the lifetime of the original pylibcudf.Column.
    """
    name = None
    if metadata is not None:
        # A dict's key set equals {"name"} only when "name" is its sole
        # key, so no separate length check is needed.
        if not (isinstance(metadata, dict) and set(metadata) == {"name"}):
            raise ValueError("Metadata dict must only contain a name")
        name = metadata["name"]
    return cls._from_column(
        ColumnBase.from_pylibcudf(col, data_ptr_exposed=True),
        name=name,
    )

@classmethod
@_performance_tracking
def from_arrow(cls, obj: pa.Array) -> Index | cudf.MultiIndex:
Expand Down
2 changes: 2 additions & 0 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3820,6 +3820,8 @@ def from_pylibcudf(
----------
col : pylibcudf.Column
The input Column.
metadata : dict | None
The Series metadata.

Returns
-------
Expand Down
25 changes: 5 additions & 20 deletions python/cudf/cudf/core/single_column_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,26 +139,11 @@ def to_cupy(
-------
cupy.ndarray
"""
col = self._column
final_dtype = (
col.dtype if dtype is None else dtype
) # some types do not support | operator
if (
not copy
and col.dtype.kind in {"i", "u", "f", "b"}
and cp.can_cast(col.dtype, final_dtype)
and not col.has_nulls()
):
if col.has_nulls():
if na_value is not None:
col = col.fillna(na_value)
else:
return super().to_cupy(
dtype=dtype, copy=copy, na_value=na_value
)
return cp.asarray(col, dtype=final_dtype)

return super().to_cupy(dtype=dtype, copy=copy, na_value=na_value)
return (
super()
.to_cupy(dtype=dtype, copy=copy, na_value=na_value)
.reshape(len(self), order="F")
)

@property # type: ignore
@_performance_tracking
Expand Down
28 changes: 28 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,34 @@ def test_dataframe_to_cupy():
np.testing.assert_array_equal(df[k].to_numpy(), mat[:, i])


@pytest.mark.parametrize("has_nulls", [False, True])
@pytest.mark.parametrize("use_na_value", [False, True])
def test_dataframe_to_cupy_single_column(has_nulls, use_na_value):
    """Check ``to_cupy`` on a one-column DataFrame, with and without nulls.

    With nulls and no ``na_value`` the conversion must raise; otherwise the
    result must have shape ``(nelem, 1)`` and match the (filled) column.
    """
    n_rows = 10
    values = np.arange(n_rows, dtype=np.float64)
    if has_nulls:
        # Switch to object dtype so every other entry can be set to None.
        values = values.astype("object")
        values[::2] = None
    frame = cudf.DataFrame({"a": values})

    if has_nulls and not use_na_value:
        with pytest.raises(ValueError, match="Column must have no nulls"):
            frame.to_cupy()
        return

    fill = 0.0 if use_na_value else None
    column = frame["a"]
    if has_nulls:
        column = column.fillna(fill)
    expected = cupy.asarray(column)

    actual = frame.to_cupy(na_value=fill)
    assert actual.shape == (n_rows, 1)
    assert_eq(actual.ravel(), expected)


def test_dataframe_to_cupy_null_values():
df = cudf.DataFrame()

Expand Down
7 changes: 7 additions & 0 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3343,3 +3343,10 @@ def test_categoricalindex_from_codes(ordered, name):
name=name,
)
assert_eq(result, expected)


def test_roundtrip_index_plc_column():
    """Round-trip an Index through to_pylibcudf/from_pylibcudf unchanged."""
    source = cudf.Index([1])
    expect = cudf.Index(source)
    plc_column, metadata = expect.to_pylibcudf()
    actual = cudf.Index.from_pylibcudf(plc_column, metadata)
    assert_eq(expect, actual)
1 change: 1 addition & 0 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3091,6 +3091,7 @@ def test_series_to_cupy(dtype, has_nulls, use_na_value):

if not has_nulls:
assert_eq(sr.values, cp.asarray(sr))
return

if has_nulls and not use_na_value:
with pytest.raises(ValueError, match="Column must have no nulls"):
Expand Down
16 changes: 14 additions & 2 deletions python/pylibcudf/pylibcudf/libcudf/reshape.pxd
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
from libcpp.memory cimport unique_ptr
from pylibcudf.exception_handler cimport libcudf_exception_handler
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.table.table cimport table
from pylibcudf.libcudf.table.table_view cimport table_view
from pylibcudf.libcudf.types cimport size_type
from pylibcudf.libcudf.types cimport size_type, data_type
from pylibcudf.libcudf.utilities.span cimport device_span

from rmm.librmm.cuda_stream_view cimport cuda_stream_view

# Opaque declaration of ``cuda::std::byte``; it is the element type of the
# ``device_span`` output parameter of ``table_to_array`` declared below.
cdef extern from "cuda/functional" namespace "cuda::std":
cdef cppclass byte:
pass


cdef extern from "cudf/reshape.hpp" namespace "cudf" nogil:
Expand All @@ -14,3 +21,8 @@ cdef extern from "cudf/reshape.hpp" namespace "cudf" nogil:
cdef unique_ptr[table] tile(
table_view source_table, size_type count
) except +libcudf_exception_handler
# Binding for ``cudf::table_to_array`` from "cudf/reshape.hpp": takes a table
# view, an output ``device_span`` of bytes, and a CUDA stream view.
# C++ exceptions are translated through ``libcudf_exception_handler``.
cdef void table_to_array(
table_view input_table,
device_span[byte] output,
cuda_stream_view stream
) except +libcudf_exception_handler
15 changes: 14 additions & 1 deletion python/pylibcudf/pylibcudf/reshape.pxd
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.

from libc.stddef cimport size_t
from libc.stdint cimport uintptr_t

from pylibcudf.libcudf.types cimport size_type

from rmm.pylibrmm.stream cimport Stream
from rmm.pylibrmm.device_buffer cimport DeviceBuffer

from .column cimport Column
from .scalar cimport Scalar
from .table cimport Table
from .types cimport DataType


cpdef Column interleave_columns(Table source_table)
cpdef Table tile(Table source_table, size_type count)
# Wrapper declaration for table_to_array: the output buffer is described by
# a raw pointer value (``uintptr_t``) and a size; ``stream`` is optional
# (``=*`` marks an argument whose default is given in the implementation).
cpdef void table_to_array(
Table input_table,
uintptr_t ptr,
size_t size,
Stream stream=*
)
8 changes: 8 additions & 0 deletions python/pylibcudf/pylibcudf/reshape.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from rmm.pylibrmm.stream import Stream

from pylibcudf.column import Column
from pylibcudf.table import Table

def interleave_columns(source_table: Table) -> Column: ...
def tile(source_table: Table, count: int) -> Table: ...
# ``stream`` carries a default because the Cython declaration marks it as
# optional (``Stream stream=*``) and callers may omit it.
def table_to_array(
    input_table: Table,
    ptr: int,
    size: int,
    stream: Stream = ...,
) -> None: ...
Loading
Loading