From c539124cb1d6eaac56aa82317f22b6ef228aa000 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Tue, 13 May 2025 13:21:27 -0400 Subject: [PATCH 01/12] Add fast paths for DataFrame.to_cupy --- python/cudf/benchmarks/API/bench_dataframe.py | 7 +++ python/cudf/benchmarks/common/config.py | 6 +-- python/cudf/cudf/core/frame.py | 41 +++++++++++++++ .../pylibcudf/pylibcudf/libcudf/reshape.pxd | 16 +++++- python/pylibcudf/pylibcudf/reshape.pxd | 14 ++++- python/pylibcudf/pylibcudf/reshape.pyi | 8 +++ python/pylibcudf/pylibcudf/reshape.pyx | 51 ++++++++++++++++++- .../pylibcudf/pylibcudf/tests/test_reshape.py | 44 +++++++++++++++- 8 files changed, 178 insertions(+), 9 deletions(-) diff --git a/python/cudf/benchmarks/API/bench_dataframe.py b/python/cudf/benchmarks/API/bench_dataframe.py index 59bac871057..28c319db94b 100644 --- a/python/cudf/benchmarks/API/bench_dataframe.py +++ b/python/cudf/benchmarks/API/bench_dataframe.py @@ -349,3 +349,10 @@ def bench_nsmallest(benchmark, dataframe, num_cols_to_sort, n): ) def bench_where(benchmark, dataframe, cond, other): benchmark(dataframe.where, cond, other) + + +@benchmark_with_object(cls="dataframe", dtype="float", nulls=False, cols=20) +@pytest.mark.parametrize("fast", [True, False]) +@pytest.mark.pandas_incompatible +def bench_to_cupy(benchmark, dataframe, fast): + benchmark(dataframe.to_cupy, fast=fast) diff --git a/python/cudf/benchmarks/common/config.py b/python/cudf/benchmarks/common/config.py index 872ba424d20..0f4e122839c 100644 --- a/python/cudf/benchmarks/common/config.py +++ b/python/cudf/benchmarks/common/config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. """Module used for global configuration of benchmarks. @@ -64,7 +64,7 @@ def pytest_sessionfinish(session, exitstatus): # Constants used to define benchmarking standards. if "CUDF_BENCHMARKS_DEBUG_ONLY" in os.environ: NUM_ROWS = [10, 20] - NUM_COLS = [1, 6] + NUM_COLS = [1, 6, 20] else: NUM_ROWS = [100, 10_000, 1_000_000] - NUM_COLS = [1, 6] + NUM_COLS = [1, 6, 20] diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 1023877265a..6d2b337f683 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal import cupy +import cupy as cp import numpy import numpy as np import pyarrow as pa @@ -524,6 +525,15 @@ def to_array( matrix[:, i] = to_array(col, dtype) return matrix + @_performance_tracking + def to_pylibcudf(self) -> tuple[plc.Table, dict[str, Any]]: + """ + Converts Frame to a pylibcudf.Table. + Note: This method should not be called directly on a Frame object + Instead, it should be called on subclasses like DataFrame/Series. + """ + raise NotImplementedError(f"{type(self)} must implement to_pylibcudf") + @_performance_tracking def to_cupy( self, @@ -550,6 +560,37 @@ def to_cupy( ------- cupy.ndarray """ + if ( + self._num_columns > 1 + and na_value is None + and not isinstance(self._columns[0].dtype, cudf.CategoricalDtype) + and all( + not col.nullable and col.dtype == self._columns[0].dtype + for col in self._columns + ) + ): + if dtype is None: + dtype = np.dtype(self._columns[0].dtype) + + shape = (len(self), self._num_columns) + out = cupy.empty(shape, dtype=dtype, order="F") + + table = self.to_pylibcudf()[0] + if isinstance(table, plc.Column): + table = plc.Table([table]) + plc.reshape.table_to_array( + table, + out.data.ptr, + out.nbytes, + ) + return out + elif ( + self._num_columns == 1 + and na_value is None + and not isinstance(self._columns[0].dtype, cudf.CategoricalDtype) + and not self._columns[0].nullable + ): + return cp.asarray(self._columns[0]).reshape((-1, 1)) return self._to_array( lambda col: col.values, cupy, diff --git a/python/pylibcudf/pylibcudf/libcudf/reshape.pxd b/python/pylibcudf/pylibcudf/libcudf/reshape.pxd index 92ab4773940..84fcd7cdaa8 100644 --- a/python/pylibcudf/pylibcudf/libcudf/reshape.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/reshape.pxd @@ -1,10 +1,17 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from pylibcudf.exception_handler cimport libcudf_exception_handler from pylibcudf.libcudf.column.column cimport column from pylibcudf.libcudf.table.table cimport table from pylibcudf.libcudf.table.table_view cimport table_view -from pylibcudf.libcudf.types cimport size_type +from pylibcudf.libcudf.types cimport size_type, data_type +from pylibcudf.libcudf.utilities.span cimport device_span + +from rmm.librmm.cuda_stream_view cimport cuda_stream_view + +cdef extern from "cuda/functional" namespace "cuda::std": + cdef cppclass byte: + pass cdef extern from "cudf/reshape.hpp" namespace "cudf" nogil: @@ -14,3 +21,8 @@ cdef extern from "cudf/reshape.hpp" namespace "cudf" nogil: cdef unique_ptr[table] tile( table_view source_table, size_type count ) except +libcudf_exception_handler + cdef void table_to_array( + table_view input_table, + device_span[byte] output, + cuda_stream_view stream + ) except +libcudf_exception_handler diff --git a/python/pylibcudf/pylibcudf/reshape.pxd b/python/pylibcudf/pylibcudf/reshape.pxd index c4d3d375f7a..1e7aa957c1a 100644 --- a/python/pylibcudf/pylibcudf/reshape.pxd +++ b/python/pylibcudf/pylibcudf/reshape.pxd @@ -1,11 +1,23 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. + +from libc.stdint cimport uintptr_t from pylibcudf.libcudf.types cimport size_type +from rmm.pylibrmm.stream cimport Stream +from rmm.pylibrmm.device_buffer cimport DeviceBuffer + from .column cimport Column from .scalar cimport Scalar from .table cimport Table +from .types cimport DataType cpdef Column interleave_columns(Table source_table) cpdef Table tile(Table source_table, size_type count) +cpdef void table_to_array( + Table input_table, + uintptr_t ptr, + size_type size, + Stream stream=* +) diff --git a/python/pylibcudf/pylibcudf/reshape.pyi b/python/pylibcudf/pylibcudf/reshape.pyi index d8d0ffcc3e0..8f2e33903af 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyi +++ b/python/pylibcudf/pylibcudf/reshape.pyi @@ -1,7 +1,15 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +from rmm.pylibrmm.stream import Stream + from pylibcudf.column import Column from pylibcudf.table import Table def interleave_columns(source_table: Table) -> Column: ... def tile(source_table: Table, count: int) -> Table: ... +def table_to_array( + input_table: Table, + ptr: int, + size: int, + stream: Stream, +) -> None: ... diff --git a/python/pylibcudf/pylibcudf/reshape.pyx b/python/pylibcudf/pylibcudf/reshape.pyx index bdc212a1985..7ece4b105ae 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyx +++ b/python/pylibcudf/pylibcudf/reshape.pyx @@ -1,19 +1,28 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +from libc.stdint cimport uintptr_t from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from libcpp.limits cimport numeric_limits from pylibcudf.libcudf.column.column cimport column from pylibcudf.libcudf.reshape cimport ( interleave_columns as cpp_interleave_columns, tile as cpp_tile, + table_to_array as cpp_table_to_array, + byte, ) from pylibcudf.libcudf.table.table cimport table from pylibcudf.libcudf.types cimport size_type +from pylibcudf.libcudf.utilities.span cimport device_span + +from rmm.pylibrmm.stream cimport Stream + from .column cimport Column from .table cimport Table +from .utils cimport _get_stream -__all__ = ["interleave_columns", "tile"] +__all__ = ["interleave_columns", "tile", "table_to_array"] cpdef Column interleave_columns(Table source_table): """Interleave columns of a table into a single column. @@ -67,3 +76,41 @@ cpdef Table tile(Table source_table, size_type count): c_result = cpp_tile(source_table.view(), count) return Table.from_libcudf(move(c_result)) + + +cpdef void table_to_array( + Table input_table, + uintptr_t ptr, + size_type size, + Stream stream=None +): + """ + Copy a table into a preallocated column-major device array. + Parameters + ---------- + input_table : Table + A table with fixed-width, non-nullable columns of the same type. + ptr : uintptr_t + A device pointer to the beginning of the output buffer. + size : size_type + The total number of bytes available at `ptr`. + Must be at least `num_rows * num_columns * sizeof(dtype)`. + stream : Stream | None + CUDA stream on which to perform the operation. + """ + if size > numeric_limits[size_type].max(): + raise ValueError( + "Size exceeds the int32_t limit." + ) + stream = _get_stream(stream) + + cdef device_span[byte] span = device_span[byte]( + ptr, size + ) + + with nogil: + cpp_table_to_array( + input_table.view(), + span, + stream.view() + ) diff --git a/python/pylibcudf/pylibcudf/tests/test_reshape.py b/python/pylibcudf/pylibcudf/tests/test_reshape.py index ef23e23766a..718879dc202 100644 --- a/python/pylibcudf/pylibcudf/tests/test_reshape.py +++ b/python/pylibcudf/pylibcudf/tests/test_reshape.py @@ -1,10 +1,21 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. import pyarrow as pa import pytest from utils import assert_column_eq, assert_table_eq import pylibcudf as plc +from pylibcudf.types import TypeId + + +@pytest.fixture(scope="module") +def np(): + return pytest.importorskip("cupy") + + +@pytest.fixture(scope="module") +def cp(): + return pytest.importorskip("cupy") @pytest.fixture(scope="module") @@ -37,3 +48,34 @@ def test_tile(reshape_data, cnt): ) assert_table_eq(expect, res) + + +@pytest.mark.parametrize( + "dtype, type_id", + [ + ("int32", TypeId.INT32), + ("int64", TypeId.INT64), + ("float32", TypeId.FLOAT32), + ("float64", TypeId.FLOAT64), + ], +) +def test_table_to_array(dtype, type_id, np, cp): + arrow_type = pa.from_numpy_dtype(getattr(np, dtype)) + arrs = [ + pa.array([1, 2, 3], type=arrow_type), + pa.array([4, 5, 6], type=arrow_type), + ] + arrow_tbl = pa.Table.from_arrays(arrs, names=["a", "b"]) + tbl = plc.interop.from_arrow(arrow_tbl) + + rows, cols = tbl.num_rows(), tbl.num_columns() + out = cp.empty((rows, cols), dtype=dtype, order="F") + + plc.reshape.table_to_array( + tbl, + out.data.ptr, + out.nbytes, + ) + + expect = cp.array([[1, 4], [2, 5], [3, 6]], dtype=dtype) + cp.testing.assert_array_equal(out, expect) From 2c298df691ff34c19bc68b60951d99580fe24b5b Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Tue, 13 May 2025 13:56:21 -0400 Subject: [PATCH 02/12] clean up --- python/cudf/benchmarks/API/bench_dataframe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/cudf/benchmarks/API/bench_dataframe.py b/python/cudf/benchmarks/API/bench_dataframe.py index 28c319db94b..6caa4ba4762 100644 --- a/python/cudf/benchmarks/API/bench_dataframe.py +++ b/python/cudf/benchmarks/API/bench_dataframe.py @@ -352,7 +352,6 @@ def bench_where(benchmark, dataframe, cond, other): @benchmark_with_object(cls="dataframe", dtype="float", nulls=False, cols=20) -@pytest.mark.parametrize("fast", [True, False]) @pytest.mark.pandas_incompatible -def bench_to_cupy(benchmark, dataframe, fast): - benchmark(dataframe.to_cupy, fast=fast) +def bench_to_cupy(benchmark, dataframe): + benchmark(dataframe.to_cupy) From 59cee55e145da2f02f8178cfa9f975f280953271 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Tue, 13 May 2025 14:03:36 -0400 Subject: [PATCH 03/12] more clean up --- python/cudf/cudf/core/frame.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 6d2b337f683..71795bae228 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -563,7 +563,7 @@ def to_cupy( if ( self._num_columns > 1 and na_value is None - and not isinstance(self._columns[0].dtype, cudf.CategoricalDtype) + and self._columns[0].dtype.kind in {"i", "u", "f", "b"} and all( not col.nullable and col.dtype == self._columns[0].dtype for col in self._columns @@ -587,7 +587,7 @@ def to_cupy( elif ( self._num_columns == 1 and na_value is None - and not isinstance(self._columns[0].dtype, cudf.CategoricalDtype) + and self._columns[0].dtype.kind in {"i", "u", "f", "b"} and not self._columns[0].nullable ): return cp.asarray(self._columns[0]).reshape((-1, 1)) From 0c92c1842d761ea6f8ece07d512e4d59e054537b Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Tue, 13 May 2025 14:44:12 -0400 Subject: [PATCH 04/12] add tests and fix bug --- python/cudf/cudf/core/frame.py | 28 +++++++++++++++----- python/cudf/cudf/core/single_column_frame.py | 1 - python/cudf/cudf/tests/test_dataframe.py | 28 ++++++++++++++++++++ python/cudf/cudf/tests/test_series.py | 1 + 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 71795bae228..f4f8b4a6193 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -584,13 +584,27 @@ def to_cupy( out.nbytes, ) return out - elif ( - self._num_columns == 1 - and na_value is None - and self._columns[0].dtype.kind in {"i", "u", "f", "b"} - and not self._columns[0].nullable - ): - return cp.asarray(self._columns[0]).reshape((-1, 1)) + elif self._num_columns == 1: + col = self._columns[0] + final_dtype = col.dtype if dtype is None else dtype + + if ( + not copy + and col.dtype.kind in {"i", "u", "f", "b"} + and cp.can_cast(col.dtype, final_dtype) + ): + if col.has_nulls(): + if na_value is not None: + col = col.fillna(na_value) + else: + return self._to_array( + lambda col: col.values, + cupy, + copy, + dtype, + na_value, + ) + return cp.asarray(col, dtype=final_dtype).reshape((-1, 1)) return self._to_array( lambda col: col.values, cupy, diff --git a/python/cudf/cudf/core/single_column_frame.py b/python/cudf/cudf/core/single_column_frame.py index ad27ab01a61..fb83fc6eeee 100644 --- a/python/cudf/cudf/core/single_column_frame.py +++ b/python/cudf/cudf/core/single_column_frame.py @@ -147,7 +147,6 @@ def to_cupy( not copy and col.dtype.kind in {"i", "u", "f", "b"} and cp.can_cast(col.dtype, final_dtype) - and not col.has_nulls() ): if col.has_nulls(): if na_value is not None: diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 216f3fd9ab9..4a77cebc810 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1270,6 +1270,34 @@ def test_dataframe_to_cupy(): np.testing.assert_array_equal(df[k].to_numpy(), mat[:, i]) +@pytest.mark.parametrize("has_nulls", [False, True]) +@pytest.mark.parametrize("use_na_value", [False, True]) +def test_dataframe_to_cupy_single_column(has_nulls, use_na_value): + nelem = 10 + data = np.arange(nelem, dtype=np.float64) + + if has_nulls: + data = data.astype("object") + data[::2] = None + + df = cudf.DataFrame({"a": data}) + + if has_nulls and not use_na_value: + with pytest.raises(ValueError, match="Column must have no nulls"): + df.to_cupy() + return + + na_value = 0.0 if use_na_value else None + expected = ( + cupy.asarray(df["a"].fillna(na_value)) + if has_nulls + else cupy.asarray(df["a"]) + ) + result = df.to_cupy(na_value=na_value) + assert result.shape == (nelem, 1) + assert_eq(result.ravel(), expected) + + def test_dataframe_to_cupy_null_values(): df = cudf.DataFrame() diff --git a/python/cudf/cudf/tests/test_series.py b/python/cudf/cudf/tests/test_series.py index d0edd3d9646..c23f65d8dd4 100644 --- a/python/cudf/cudf/tests/test_series.py +++ b/python/cudf/cudf/tests/test_series.py @@ -3091,6 +3091,7 @@ def test_series_to_cupy(dtype, has_nulls, use_na_value): if not has_nulls: assert_eq(sr.values, cp.asarray(sr)) + return if has_nulls and not use_na_value: with pytest.raises(ValueError, match="Column must have no nulls"): From 366005d559135f06fc88a3a67f2eb995950bc19f Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 15:33:48 -0400 Subject: [PATCH 05/12] address review --- python/cudf/benchmarks/API/bench_dataframe.py | 4 +++- python/cudf/benchmarks/API/bench_series.py | 12 ++++++---- python/cudf/benchmarks/common/config.py | 4 ++-- python/cudf/benchmarks/conftest.py | 9 +------ python/cudf/cudf/core/frame.py | 5 ++-- python/cudf/cudf/core/single_column_frame.py | 24 ++++--------------- python/pylibcudf/pylibcudf/reshape.pxd | 3 ++- python/pylibcudf/pylibcudf/reshape.pyx | 6 +++-- .../pylibcudf/pylibcudf/tests/test_reshape.py | 16 ++++--------- 9 files changed, 31 insertions(+), 52 deletions(-) diff --git a/python/cudf/benchmarks/API/bench_dataframe.py b/python/cudf/benchmarks/API/bench_dataframe.py index 6caa4ba4762..c39da43cbfd 100644 --- a/python/cudf/benchmarks/API/bench_dataframe.py +++ b/python/cudf/benchmarks/API/bench_dataframe.py @@ -351,7 +351,9 @@ def bench_where(benchmark, dataframe, cond, other): benchmark(dataframe.where, cond, other) -@benchmark_with_object(cls="dataframe", dtype="float", nulls=False, cols=20) +@benchmark_with_object( + cls="dataframe", dtype="float", nulls=False, cols=20, rows=20 +) @pytest.mark.pandas_incompatible def bench_to_cupy(benchmark, dataframe): benchmark(dataframe.to_cupy) diff --git a/python/cudf/benchmarks/API/bench_series.py b/python/cudf/benchmarks/API/bench_series.py index 3516888838e..5b82c38d0d0 100644 --- a/python/cudf/benchmarks/API/bench_series.py +++ b/python/cudf/benchmarks/API/bench_series.py @@ -23,13 +23,17 @@ def bench_series_nsmallest(benchmark, series, n): benchmark(series.nsmallest, n) -@benchmark_with_object(cls="series", dtype="int") +@benchmark_with_object(cls="series", dtype="int", nulls=False) def bench_series_cp_asarray(benchmark, series): - series = series.dropna() benchmark(cupy.asarray, series) -@benchmark_with_object(cls="series", dtype="int") +@benchmark_with_object(cls="series", dtype="int", nulls=False) +@pytest.mark.pandas_incompatible +def bench_to_cupy(benchmark, series): + benchmark(lambda: series.values) + + +@benchmark_with_object(cls="series", dtype="int", nulls=False) def bench_series_values(benchmark, series): - series = series.dropna() benchmark(lambda: series.values) diff --git a/python/cudf/benchmarks/common/config.py b/python/cudf/benchmarks/common/config.py index 0f4e122839c..fc08a03d8d8 100644 --- a/python/cudf/benchmarks/common/config.py +++ b/python/cudf/benchmarks/common/config.py @@ -66,5 +66,5 @@ def pytest_sessionfinish(session, exitstatus): NUM_ROWS = [10, 20] NUM_COLS = [1, 6, 20] else: - NUM_ROWS = [100, 10_000, 1_000_000] - NUM_COLS = [1, 6, 20] + NUM_ROWS = [100, 1_000, 10_000, 50_000, 1_000_000] + NUM_COLS = [1, 6, 20, 1_000, 10_000, 50_000] diff --git a/python/cudf/benchmarks/conftest.py b/python/cudf/benchmarks/conftest.py index 7561bdc41b4..339311be744 100644 --- a/python/cudf/benchmarks/conftest.py +++ b/python/cudf/benchmarks/conftest.py @@ -45,7 +45,6 @@ """ import os -import string import sys import pytest_cases @@ -83,14 +82,8 @@ def axis(request): for dtype, column_generator in column_generators.items(): def make_dataframe(nr, nc, column_generator=column_generator): - assert nc <= len(string.ascii_lowercase), ( - "make_dataframe only supports a maximum of 26 columns" - ) return cudf.DataFrame( - { - f"{string.ascii_lowercase[i]}": column_generator(nr) - for i in range(nc) - } + {f"col{i}": column_generator(nr) for i in range(nc)} ) for nr in NUM_ROWS: diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index f4f8b4a6193..10dfd19230a 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Literal import cupy -import cupy as cp import numpy import numpy as np import pyarrow as pa @@ -591,7 +590,7 @@ def to_cupy( if ( not copy and col.dtype.kind in {"i", "u", "f", "b"} - and cp.can_cast(col.dtype, final_dtype) + and cupy.can_cast(col.dtype, final_dtype) ): if col.has_nulls(): if na_value is not None: @@ -604,7 +603,7 @@ def to_cupy( dtype, na_value, ) - return cp.asarray(col, dtype=final_dtype).reshape((-1, 1)) + return cupy.asarray(col, dtype=final_dtype).reshape((-1, 1)) return self._to_array( lambda col: col.values, cupy, diff --git a/python/cudf/cudf/core/single_column_frame.py b/python/cudf/cudf/core/single_column_frame.py index fb83fc6eeee..f6158727a04 100644 --- a/python/cudf/cudf/core/single_column_frame.py +++ b/python/cudf/cudf/core/single_column_frame.py @@ -139,25 +139,11 @@ def to_cupy( ------- cupy.ndarray """ - col = self._column - final_dtype = ( - col.dtype if dtype is None else dtype - ) # some types do not support | operator - if ( - not copy - and col.dtype.kind in {"i", "u", "f", "b"} - and cp.can_cast(col.dtype, final_dtype) - ): - if col.has_nulls(): - if na_value is not None: - col = col.fillna(na_value) - else: - return super().to_cupy( - dtype=dtype, copy=copy, na_value=na_value - ) - return cp.asarray(col, dtype=final_dtype) - - return super().to_cupy(dtype=dtype, copy=copy, na_value=na_value) + return ( + super() + .to_cupy(dtype=dtype, copy=copy, na_value=na_value) + .reshape(len(self), order="F") + ) @property # type: ignore @_performance_tracking diff --git a/python/pylibcudf/pylibcudf/reshape.pxd b/python/pylibcudf/pylibcudf/reshape.pxd index 1e7aa957c1a..efb7217f0e8 100644 --- a/python/pylibcudf/pylibcudf/reshape.pxd +++ b/python/pylibcudf/pylibcudf/reshape.pxd @@ -1,5 +1,6 @@ # Copyright (c) 2024-2025, NVIDIA CORPORATION. +from libc.stddef cimport size_t from libc.stdint cimport uintptr_t from pylibcudf.libcudf.types cimport size_type @@ -18,6 +19,6 @@ cpdef Table tile(Table source_table, size_type count) cpdef void table_to_array( Table input_table, uintptr_t ptr, - size_type size, + size_t size, Stream stream=* ) diff --git a/python/pylibcudf/pylibcudf/reshape.pyx b/python/pylibcudf/pylibcudf/reshape.pyx index 7ece4b105ae..df84c6858b3 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyx +++ b/python/pylibcudf/pylibcudf/reshape.pyx @@ -1,5 +1,6 @@ # Copyright (c) 2024-2025, NVIDIA CORPORATION. +from libc.stddef cimport size_t from libc.stdint cimport uintptr_t from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -81,11 +82,12 @@ cpdef Table tile(Table source_table, size_type count): cpdef void table_to_array( Table input_table, uintptr_t ptr, - size_type size, + size_t size, Stream stream=None ): """ Copy a table into a preallocated column-major device array. + Parameters ---------- input_table : Table @@ -98,7 +100,7 @@ cpdef void table_to_array( stream : Stream | None CUDA stream on which to perform the operation. """ - if size > numeric_limits[size_type].max(): + if size > numeric_limits[size_t].max(): raise ValueError( "Size exceeds the int32_t limit." ) diff --git a/python/pylibcudf/pylibcudf/tests/test_reshape.py b/python/pylibcudf/pylibcudf/tests/test_reshape.py index a9a6e808d96..06c9ebec0f9 100644 --- a/python/pylibcudf/pylibcudf/tests/test_reshape.py +++ b/python/pylibcudf/pylibcudf/tests/test_reshape.py @@ -8,16 +8,6 @@ from pylibcudf.types import TypeId -@pytest.fixture(scope="module") -def np(): - return pytest.importorskip("cupy") - - -@pytest.fixture(scope="module") -def cp(): - return pytest.importorskip("cupy") - - @pytest.fixture(scope="module") def reshape_data(): data = [[1, 2, 3], [4, 5, 6]] @@ -59,8 +49,10 @@ def test_tile(reshape_data, cnt): ("float64", TypeId.FLOAT64), ], ) -def test_table_to_array(dtype, type_id, np, cp): - arrow_type = pa.from_numpy_dtype(getattr(np, dtype)) +def test_table_to_array(dtype, type_id): + import cupy as cp + + arrow_type = pa.from_numpy_dtype(getattr(cp, dtype)) arrs = [ pa.array([1, 2, 3], type=arrow_type), pa.array([4, 5, 6], type=arrow_type), From 03ec7775ae69be28ddd8de714bf3afefcfc0f829 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 16:24:31 -0400 Subject: [PATCH 06/12] address review --- python/cudf/cudf/core/frame.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 10dfd19230a..ccbec0af854 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -569,14 +569,14 @@ def to_cupy( ) ): if dtype is None: - dtype = np.dtype(self._columns[0].dtype) + dtype = self._columns[0].dtype shape = (len(self), self._num_columns) out = cupy.empty(shape, dtype=dtype, order="F") - table = self.to_pylibcudf()[0] - if isinstance(table, plc.Column): - table = plc.Table([table]) + table = plc.Table( + [col.to_pylibcudf(mode="read") for col in self._columns] + ) plc.reshape.table_to_array( table, out.data.ptr, From e83c07f11dd4c3070528c4d540a63a8ba6709c20 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 16:30:45 -0400 Subject: [PATCH 07/12] address review --- python/pylibcudf/pylibcudf/tests/test_reshape.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pylibcudf/pylibcudf/tests/test_reshape.py b/python/pylibcudf/pylibcudf/tests/test_reshape.py index 06c9ebec0f9..3802958e5c4 100644 --- a/python/pylibcudf/pylibcudf/tests/test_reshape.py +++ b/python/pylibcudf/pylibcudf/tests/test_reshape.py @@ -1,5 +1,6 @@ # Copyright (c) 2024-2025, NVIDIA CORPORATION. +import cupy as cp import pyarrow as pa import pytest from utils import assert_column_eq, assert_table_eq @@ -50,8 +51,6 @@ def test_tile(reshape_data, cnt): ], ) def test_table_to_array(dtype, type_id): - import cupy as cp - arrow_type = pa.from_numpy_dtype(getattr(cp, dtype)) arrs = [ pa.array([1, 2, 3], type=arrow_type), From 614492fbc809dff43bb9421e0af0969a1bef4ca8 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 16:50:31 -0400 Subject: [PATCH 08/12] address review --- python/cudf/benchmarks/common/config.py | 4 ++-- python/pylibcudf/pylibcudf/reshape.pyx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf/benchmarks/common/config.py b/python/cudf/benchmarks/common/config.py index fc08a03d8d8..0f4e122839c 100644 --- a/python/cudf/benchmarks/common/config.py +++ b/python/cudf/benchmarks/common/config.py @@ -66,5 +66,5 @@ def pytest_sessionfinish(session, exitstatus): NUM_ROWS = [10, 20] NUM_COLS = [1, 6, 20] else: - NUM_ROWS = [100, 1_000, 10_000, 50_000, 1_000_000] - NUM_COLS = [1, 6, 20, 1_000, 10_000, 50_000] + NUM_ROWS = [100, 10_000, 1_000_000] + NUM_COLS = [1, 6, 20] diff --git a/python/pylibcudf/pylibcudf/reshape.pyx b/python/pylibcudf/pylibcudf/reshape.pyx index df84c6858b3..f97aede6e59 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyx +++ b/python/pylibcudf/pylibcudf/reshape.pyx @@ -102,7 +102,7 @@ cpdef void table_to_array( """ if size > numeric_limits[size_t].max(): raise ValueError( - "Size exceeds the int32_t limit." + "Size exceeds the size_t limit." ) stream = _get_stream(stream) From eee909ed0ec004b0e420717a916f4a6b6ce5bbe9 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 16:54:55 -0400 Subject: [PATCH 09/12] docs --- python/pylibcudf/pylibcudf/reshape.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibcudf/pylibcudf/reshape.pyx b/python/pylibcudf/pylibcudf/reshape.pyx index f97aede6e59..0ebe61af713 100644 --- a/python/pylibcudf/pylibcudf/reshape.pyx +++ b/python/pylibcudf/pylibcudf/reshape.pyx @@ -94,7 +94,7 @@ cpdef void table_to_array( A table with fixed-width, non-nullable columns of the same type. ptr : uintptr_t A device pointer to the beginning of the output buffer. - size : size_type + size : size_t The total number of bytes available at `ptr`. Must be at least `num_rows * num_columns * sizeof(dtype)`. stream : Stream | None From 3db48d8c94f1340efb90e03303457f46f1df4a30 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 22:01:41 -0400 Subject: [PATCH 10/12] doc failure and benchmarks failures --- python/cudf/benchmarks/conftest.py | 9 +++- python/cudf/cudf/core/index.py | 66 ++++++++++++++++++++++++++++ python/cudf/cudf/tests/test_index.py | 7 +++ 3 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/cudf/benchmarks/conftest.py b/python/cudf/benchmarks/conftest.py index 339311be744..7561bdc41b4 100644 --- a/python/cudf/benchmarks/conftest.py +++ b/python/cudf/benchmarks/conftest.py @@ -45,6 +45,7 @@ """ import os +import string import sys import pytest_cases @@ -82,8 +83,14 @@ def axis(request): for dtype, column_generator in column_generators.items(): def make_dataframe(nr, nc, column_generator=column_generator): + assert nc <= len(string.ascii_lowercase), ( + "make_dataframe only supports a maximum of 26 columns" + ) return cudf.DataFrame( - {f"col{i}": column_generator(nr) for i in range(nc)} + { + f"{string.ascii_lowercase[i]}": column_generator(nr) + for i in range(nc) + } ) for nr in NUM_ROWS: diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 04cd753e7d6..48c0bb5bf5d 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -270,6 +270,72 @@ def _from_data(cls, data: MutableMapping, name: Any = no_default) -> Self: def _from_data_like_self(self, data: MutableMapping) -> Self: return _index_from_data(data, self.name) + @_performance_tracking + def to_pylibcudf(self, copy=False) -> tuple[plc.Column, dict]: + """ + Convert this Index to a pylibcudf.Column. + + Parameters + ---------- + copy : bool + Whether or not to generate a new copy of the underlying device data + + Returns + ------- + pylibcudf.Column + A new pylibcudf.Column referencing the same data. + dict + Dict of metadata (includes name) + + Notes + ----- + User requests to convert to pylibcudf must assume that the + data may be modified afterwards. + """ + if copy: + raise NotImplementedError("copy=True is not supported") + metadata = {"name": self.name} + return self._column.to_pylibcudf(mode="write"), metadata + + @classmethod + @_performance_tracking + def from_pylibcudf( + cls, col: plc.Column, metadata: dict | None = None + ) -> Self: + """ + Create a Index from a pylibcudf.Column. + + Parameters + ---------- + col : pylibcudf.Column + The input Column. + + Returns + ------- + pylibcudf.Column + A new pylibcudf.Column referencing the same data. + + Notes + ----- + This function will generate an Index which contains a Column + pointing to the provided pylibcudf Column. It will directly access + the data and mask buffers of the pylibcudf Column, so the newly created + object is not tied to the lifetime of the original pylibcudf.Column. + """ + name = None + if metadata is not None: + if not ( + isinstance(metadata, dict) + and len(metadata) == 1 + and set(metadata) == {"name"} + ): + raise ValueError("Metadata dict must only contain a name") + name = metadata.get("name") + return cls._from_column( + ColumnBase.from_pylibcudf(col, data_ptr_exposed=True), + name=name, + ) + @classmethod @_performance_tracking def from_arrow(cls, obj: pa.Array) -> Index | cudf.MultiIndex: diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index 2e0a8e4149e..cdd9d121ee7 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -3347,3 +3347,10 @@ def test_categoricalindex_from_codes(ordered, name): name=name, ) assert_eq(result, expected) + + +def test_roundtrip_index_plc_column(): + index = cudf.Index([1]) + expect = cudf.Index(index) + actual = cudf.Index.from_pylibcudf(*expect.to_pylibcudf()) + assert_eq(expect, actual) From 5bca32d05e38654b3b1471a03df7a522e5eac0b5 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 22:08:18 -0400 Subject: [PATCH 11/12] docs --- python/cudf/cudf/core/index.py | 2 ++ python/cudf/cudf/core/series.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 46ddce6ca82..a037c790911 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -367,6 +367,8 @@ def from_pylibcudf( ------- pylibcudf.Column A new pylibcudf.Column referencing the same data. + metadata : dict | None + The Index metadata. Notes ----- diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index a098840346a..44aa54175b5 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -3820,6 +3820,8 @@ def from_pylibcudf( ---------- col : pylibcudf.Column The input Column. + metadata : dict | None + The Series metadata. Returns ------- From ea5e0cdddb7bdca6ec2dd70aa6b137e77fe80baf Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 14 May 2025 23:06:56 -0400 Subject: [PATCH 12/12] docs --- docs/cudf/source/user_guide/api_docs/index_objects.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/cudf/source/user_guide/api_docs/index_objects.rst b/docs/cudf/source/user_guide/api_docs/index_objects.rst index 9c84f206010..ddf6c69af7d 100644 --- a/docs/cudf/source/user_guide/api_docs/index_objects.rst +++ b/docs/cudf/source/user_guide/api_docs/index_objects.rst @@ -110,6 +110,8 @@ Conversion Index.to_frame Index.to_pandas Index.to_dlpack + Index.to_pylibcudf + Index.from_pylibcudf Index.from_pandas Index.from_arrow