diff --git a/docs/cudf/source/user_guide/api_docs/index_objects.rst b/docs/cudf/source/user_guide/api_docs/index_objects.rst index 9c84f206010..ddf6c69af7d 100644 --- a/docs/cudf/source/user_guide/api_docs/index_objects.rst +++ b/docs/cudf/source/user_guide/api_docs/index_objects.rst @@ -110,6 +110,8 @@ Conversion Index.to_frame Index.to_pandas Index.to_dlpack + Index.to_pylibcudf + Index.from_pylibcudf Index.from_pandas Index.from_arrow diff --git a/python/cudf/benchmarks/API/bench_dataframe.py b/python/cudf/benchmarks/API/bench_dataframe.py index 59bac871057..c39da43cbfd 100644 --- a/python/cudf/benchmarks/API/bench_dataframe.py +++ b/python/cudf/benchmarks/API/bench_dataframe.py @@ -349,3 +349,11 @@ def bench_nsmallest(benchmark, dataframe, num_cols_to_sort, n): ) def bench_where(benchmark, dataframe, cond, other): benchmark(dataframe.where, cond, other) + + +@benchmark_with_object( + cls="dataframe", dtype="float", nulls=False, cols=20, rows=20 +) +@pytest.mark.pandas_incompatible +def bench_to_cupy(benchmark, dataframe): + benchmark(dataframe.to_cupy) diff --git a/python/cudf/benchmarks/API/bench_series.py b/python/cudf/benchmarks/API/bench_series.py index 3516888838e..5b82c38d0d0 100644 --- a/python/cudf/benchmarks/API/bench_series.py +++ b/python/cudf/benchmarks/API/bench_series.py @@ -23,13 +23,17 @@ def bench_series_nsmallest(benchmark, series, n): benchmark(series.nsmallest, n) -@benchmark_with_object(cls="series", dtype="int") +@benchmark_with_object(cls="series", dtype="int", nulls=False) def bench_series_cp_asarray(benchmark, series): - series = series.dropna() benchmark(cupy.asarray, series) -@benchmark_with_object(cls="series", dtype="int") +@benchmark_with_object(cls="series", dtype="int", nulls=False) +@pytest.mark.pandas_incompatible +def bench_to_cupy(benchmark, series): + benchmark(lambda: series.values) + + +@benchmark_with_object(cls="series", dtype="int", nulls=False) def bench_series_values(benchmark, series): - series = series.dropna() benchmark(lambda: series.values) diff --git a/python/cudf/benchmarks/common/config.py b/python/cudf/benchmarks/common/config.py index 872ba424d20..0f4e122839c 100644 --- a/python/cudf/benchmarks/common/config.py +++ b/python/cudf/benchmarks/common/config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. """Module used for global configuration of benchmarks. @@ -64,7 +64,7 @@ def pytest_sessionfinish(session, exitstatus): # Constants used to define benchmarking standards. if "CUDF_BENCHMARKS_DEBUG_ONLY" in os.environ: NUM_ROWS = [10, 20] - NUM_COLS = [1, 6] + NUM_COLS = [1, 6, 20] else: NUM_ROWS = [100, 10_000, 1_000_000] - NUM_COLS = [1, 6] + NUM_COLS = [1, 6, 20] diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index bd1bd628dee..4e3434f433e 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -587,6 +587,15 @@ def to_array( matrix[:, i] = to_array(col, dtype) return matrix + @_performance_tracking + def to_pylibcudf(self) -> tuple[plc.Table, dict[str, Any]]: + """ + Converts Frame to a pylibcudf.Table. + Note: This method should not be called directly on a Frame object + Instead, it should be called on subclasses like DataFrame/Series. + """ + raise NotImplementedError(f"{type(self)} must implement to_pylibcudf") + @_performance_tracking def to_cupy( self, @@ -613,6 +622,51 @@ def to_cupy( ------- cupy.ndarray """ + if ( + self._num_columns > 1 + and na_value is None + and self._columns[0].dtype.kind in {"i", "u", "f", "b"} + and all( + not col.nullable and col.dtype == self._columns[0].dtype + for col in self._columns + ) + ): + if dtype is None: + dtype = self._columns[0].dtype + + shape = (len(self), self._num_columns) + out = cupy.empty(shape, dtype=dtype, order="F") + + table = plc.Table( + [col.to_pylibcudf(mode="read") for col in self._columns] + ) + plc.reshape.table_to_array( + table, + out.data.ptr, + out.nbytes, + ) + return out + elif self._num_columns == 1: + col = self._columns[0] + final_dtype = col.dtype if dtype is None else dtype + + if ( + not copy + and col.dtype.kind in {"i", "u", "f", "b"} + and cupy.can_cast(col.dtype, final_dtype) + ): + if col.has_nulls(): + if na_value is not None: + col = col.fillna(na_value) + else: + return self._to_array( + lambda col: col.values, + cupy, + copy, + dtype, + na_value, + ) + return cupy.asarray(col, dtype=final_dtype).reshape((-1, 1)) return self._to_array( lambda col: col.values, cupy, diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index cb93f7f1328..a037c790911 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -323,6 +323,74 @@ def _from_data(cls, data: MutableMapping, name: Any = no_default) -> Self: def _from_data_like_self(self, data: MutableMapping) -> Self: return _index_from_data(data, self.name) + @_performance_tracking + def to_pylibcudf(self, copy=False) -> tuple[plc.Column, dict]: + """ + Convert this Index to a pylibcudf.Column. + + Parameters + ---------- + copy : bool + Whether or not to generate a new copy of the underlying device data + + Returns + ------- + pylibcudf.Column + A new pylibcudf.Column referencing the same data. + dict + Dict of metadata (includes name) + + Notes + ----- + User requests to convert to pylibcudf must assume that the + data may be modified afterwards. + """ + if copy: + raise NotImplementedError("copy=True is not supported") + metadata = {"name": self.name} + return self._column.to_pylibcudf(mode="write"), metadata + + @classmethod + @_performance_tracking + def from_pylibcudf( + cls, col: plc.Column, metadata: dict | None = None + ) -> Self: + """ + Create a Index from a pylibcudf.Column. + + Parameters + ---------- + col : pylibcudf.Column + The input Column. + + Returns + ------- + pylibcudf.Column + A new pylibcudf.Column referencing the same data. + metadata : dict | None + The Index metadata. + + Notes + ----- + This function will generate an Index which contains a Column + pointing to the provided pylibcudf Column. It will directly access + the data and mask buffers of the pylibcudf Column, so the newly created + object is not tied to the lifetime of the original pylibcudf.Column. + """ + name = None + if metadata is not None: + if not ( + isinstance(metadata, dict) + and len(metadata) == 1 + and set(metadata) == {"name"} + ): + raise ValueError("Metadata dict must only contain a name") + name = metadata.get("name") + return cls._from_column( + ColumnBase.from_pylibcudf(col, data_ptr_exposed=True), + name=name, + ) + @classmethod @_performance_tracking def from_arrow(cls, obj: pa.Array) -> Index | cudf.MultiIndex: diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index a098840346a..44aa54175b5 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -3820,6 +3820,8 @@ def from_pylibcudf( ---------- col : pylibcudf.Column The input Column. + metadata : dict | None + The Series metadata. Returns ------- diff --git a/python/cudf/cudf/core/single_column_frame.py b/python/cudf/cudf/core/single_column_frame.py index cc7a5fc2e8f..3cb1d38fd01 100644 --- a/python/cudf/cudf/core/single_column_frame.py +++ b/python/cudf/cudf/core/single_column_frame.py @@ -139,26 +139,11 @@ def to_cupy( ------- cupy.ndarray """ - col = self._column - final_dtype = ( - col.dtype if dtype is None else dtype - ) # some types do not support | operator - if ( - not copy - and col.dtype.kind in {"i", "u", "f", "b"} - and cp.can_cast(col.dtype, final_dtype) - and not col.has_nulls() - ): - if col.has_nulls(): - if na_value is not None: - col = col.fillna(na_value) - else: - return super().to_cupy( - dtype=dtype, copy=copy, na_value=na_value - ) - return cp.asarray(col, dtype=final_dtype) - - return super().to_cupy(dtype=dtype, copy=copy, na_value=na_value) + return ( + super() + .to_cupy(dtype=dtype, copy=copy, na_value=na_value) + .reshape(len(self), order="F") + ) @property # type: ignore @_performance_tracking diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index c434d0b1f78..7120736fce2 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1270,6 +1270,34 @@ def test_dataframe_to_cupy(): np.testing.assert_array_equal(df[k].to_numpy(), mat[:, i]) +@pytest.mark.parametrize("has_nulls", [False, True]) +@pytest.mark.parametrize("use_na_value", [False, True]) +def test_dataframe_to_cupy_single_column(has_nulls, use_na_value): + nelem = 10 + data = np.arange(nelem, dtype=np.float64) + + if has_nulls: + data = data.astype("object") + data[::2] = None + + df = cudf.DataFrame({"a": data}) + + if has_nulls and not use_na_value: + with pytest.raises(ValueError, match="Column must have no nulls"): + df.to_cupy() + return + + na_value = 0.0 if use_na_value else None + expected = ( + cupy.asarray(df["a"].fillna(na_value)) + if has_nulls + else cupy.asarray(df["a"]) + ) + result = df.to_cupy(na_value=na_value) + assert result.shape == (nelem, 1) + assert_eq(result.ravel(), expected) + + def test_dataframe_to_cupy_null_values(): df = cudf.DataFrame() diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index 34fd20f71c8..ddc077fba86 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -3343,3 +3343,10 @@ def test_categoricalindex_from_codes(ordered, name): name=name, ) assert_eq(result, expected) + + +def test_roundtrip_index_plc_column(): + index = cudf.Index([1]) + expect = cudf.Index(index) + actual = cudf.Index.from_pylibcudf(*expect.to_pylibcudf()) + assert_eq(expect, actual) diff --git a/python/cudf/cudf/tests/test_series.py b/python/cudf/cudf/tests/test_series.py index d0edd3d9646..c23f65d8dd4 100644 --- a/python/cudf/cudf/tests/test_series.py +++ b/python/cudf/cudf/tests/test_series.py @@ -3091,6 +3091,7 @@ def test_series_to_cupy(dtype, has_nulls, use_na_value): if not has_nulls: assert_eq(sr.values, cp.asarray(sr)) + return if has_nulls and not use_na_value: with pytest.raises(ValueError, match="Column must have no nulls"): diff --git a/python/pylibcudf/pylibcudf/libcudf/reshape.pxd b/python/pylibcudf/pylibcudf/libcudf/reshape.pxd index 92ab4773940..84fcd7cdaa8 100644 --- a/python/pylibcudf/pylibcudf/libcudf/reshape.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/reshape.pxd @@ -1,10 +1,17 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from pylibcudf.exception_handler cimport libcudf_exception_handler from pylibcudf.libcudf.column.column cimport column from pylibcudf.libcudf.table.table cimport table from pylibcudf.libcudf.table.table_view cimport table_view -from pylibcudf.libcudf.types cimport size_type +from pylibcudf.libcudf.types cimport size_type, data_type +from pylibcudf.libcudf.utilities.span cimport device_span + +from rmm.librmm.cuda_stream_view cimport cuda_stream_view + +cdef extern from "cuda/functional" namespace "cuda::std": + cdef cppclass byte: + pass cdef extern from "cudf/reshape.hpp" namespace "cudf" nogil: @@ -14,3 +21,8 @@ cdef extern from "cudf/reshape.hpp" namespace "cudf" nogil: cdef unique_ptr[table] tile( table_view source_table, size_type count ) except +libcudf_exception_handler + cdef void table_to_array( + table_view input_table, + device_span[byte] output, + cuda_stream_view stream + ) except +libcudf_exception_handler diff --git a/python/pylibcudf/pylibcudf/reshape.pxd b/python/pylibcudf/pylibcudf/reshape.pxd index c4d3d375f7a..efb7217f0e8 100644 --- a/python/pylibcudf/pylibcudf/reshape.pxd +++ b/python/pylibcudf/pylibcudf/reshape.pxd @@ -1,11 +1,24 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. + +from libc.stddef cimport size_t +from libc.stdint cimport uintptr_t from pylibcudf.libcudf.types cimport size_type +from rmm.pylibrmm.stream cimport Stream +from rmm.pylibrmm.device_buffer cimport DeviceBuffer + from .column cimport Column from .scalar cimport Scalar from .table cimport Table +from .types cimport DataType cpdef Column interleave_columns(Table source_table) cpdef Table tile(Table source_table, size_type count) +cpdef void table_to_array( + Table input_table, + uintptr_t ptr, + size_t size, + Stream stream=* +) diff --git a/python/pylibcudf/pylibcudf/reshape.pyi b/python/pylibcudf/pylibcudf/reshape.pyi index d8d0ffcc3e0..8f2e33903af 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyi +++ b/python/pylibcudf/pylibcudf/reshape.pyi @@ -1,7 +1,15 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +from rmm.pylibrmm.stream import Stream + from pylibcudf.column import Column from pylibcudf.table import Table def interleave_columns(source_table: Table) -> Column: ... def tile(source_table: Table, count: int) -> Table: ... +def table_to_array( + input_table: Table, + ptr: int, + size: int, + stream: Stream, +) -> None: ... diff --git a/python/pylibcudf/pylibcudf/reshape.pyx b/python/pylibcudf/pylibcudf/reshape.pyx index bdc212a1985..0ebe61af713 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyx +++ b/python/pylibcudf/pylibcudf/reshape.pyx @@ -1,19 +1,29 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +from libc.stddef cimport size_t +from libc.stdint cimport uintptr_t from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from libcpp.limits cimport numeric_limits from pylibcudf.libcudf.column.column cimport column from pylibcudf.libcudf.reshape cimport ( interleave_columns as cpp_interleave_columns, tile as cpp_tile, + table_to_array as cpp_table_to_array, + byte, ) from pylibcudf.libcudf.table.table cimport table from pylibcudf.libcudf.types cimport size_type +from pylibcudf.libcudf.utilities.span cimport device_span + +from rmm.pylibrmm.stream cimport Stream + from .column cimport Column from .table cimport Table +from .utils cimport _get_stream -__all__ = ["interleave_columns", "tile"] +__all__ = ["interleave_columns", "tile", "table_to_array"] cpdef Column interleave_columns(Table source_table): """Interleave columns of a table into a single column. @@ -67,3 +77,42 @@ cpdef Table tile(Table source_table, size_type count): c_result = cpp_tile(source_table.view(), count) return Table.from_libcudf(move(c_result)) + + +cpdef void table_to_array( + Table input_table, + uintptr_t ptr, + size_t size, + Stream stream=None +): + """ + Copy a table into a preallocated column-major device array. + + Parameters + ---------- + input_table : Table + A table with fixed-width, non-nullable columns of the same type. + ptr : uintptr_t + A device pointer to the beginning of the output buffer. + size : size_t + The total number of bytes available at `ptr`. + Must be at least `num_rows * num_columns * sizeof(dtype)`. + stream : Stream | None + CUDA stream on which to perform the operation. + """ + if size > numeric_limits[size_t].max(): + raise ValueError( + "Size exceeds the size_t limit." + ) + stream = _get_stream(stream) + + cdef device_span[byte] span = device_span[byte]( + ptr, size + ) + + with nogil: + cpp_table_to_array( + input_table.view(), + span, + stream.view() + ) diff --git a/python/pylibcudf/pylibcudf/tests/test_reshape.py b/python/pylibcudf/pylibcudf/tests/test_reshape.py index f77277c0c1e..3802958e5c4 100644 --- a/python/pylibcudf/pylibcudf/tests/test_reshape.py +++ b/python/pylibcudf/pylibcudf/tests/test_reshape.py @@ -1,10 +1,12 @@ # Copyright (c) 2024-2025, NVIDIA CORPORATION. +import cupy as cp import pyarrow as pa import pytest from utils import assert_column_eq, assert_table_eq import pylibcudf as plc +from pylibcudf.types import TypeId @pytest.fixture(scope="module") @@ -37,3 +39,34 @@ def test_tile(reshape_data, cnt): ) assert_table_eq(expect, got) + + +@pytest.mark.parametrize( + "dtype, type_id", + [ + ("int32", TypeId.INT32), + ("int64", TypeId.INT64), + ("float32", TypeId.FLOAT32), + ("float64", TypeId.FLOAT64), + ], +) +def test_table_to_array(dtype, type_id): + arrow_type = pa.from_numpy_dtype(getattr(cp, dtype)) + arrs = [ + pa.array([1, 2, 3], type=arrow_type), + pa.array([4, 5, 6], type=arrow_type), + ] + arrow_tbl = pa.Table.from_arrays(arrs, names=["a", "b"]) + tbl = plc.interop.from_arrow(arrow_tbl) + + rows, cols = tbl.num_rows(), tbl.num_columns() + got = cp.empty((rows, cols), dtype=dtype, order="F") + + plc.reshape.table_to_array( + tbl, + got.data.ptr, + got.nbytes, + ) + + expect = cp.array([[1, 4], [2, 5], [3, 6]], dtype=dtype) + cp.testing.assert_array_equal(expect, got)