Support dask_expr migration into dask.dataframe (#17704)
Follow up to #17558
This PR cleans up some imports and provides support for both `dask:2024.12.1` and `dask:main` (in which `dask_expr` has been moved into the `dask.dataframe` module).

See also: rapidsai/dask-cuda#1424
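In essence, the change gates its `dask_expr` imports on the installed dask version. A minimal sketch of the pattern, mirroring the new `dask_cudf/_expr/__init__.py` shown in the diff below:

```python
# Minimal sketch of the version-gated import pattern (the full set of
# shimmed names lives in dask_cudf/_expr/__init__.py below).
from packaging.version import Version

import dask

if Version(dask.__version__) > Version("2024.12.1"):
    # dask:main has absorbed dask_expr into dask.dataframe.
    from dask.dataframe.dask_expr import new_collection
else:
    # dask<=2024.12.1 still ships dask_expr as a standalone package.
    from dask_expr import new_collection  # noqa: F401
```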

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Peter Andreas Entschev (https://github.com/pentschev)
  - Bradley Dice (https://github.com/bdice)

URL: #17704
rjzamora authored Jan 16, 2025
1 parent 3cad1a6 commit 7f2b2ba
Showing 8 changed files with 159 additions and 100 deletions.
7 changes: 1 addition & 6 deletions python/dask_cudf/dask_cudf/__init__.py
@@ -7,16 +7,11 @@
import cudf

from . import backends, io # noqa: F401
from ._expr import collection # noqa: F401
from ._expr.expr import _patch_dask_expr
from ._version import __git_commit__, __version__ # noqa: F401
from .core import DataFrame, Index, Series, _deprecated_api, concat, from_cudf

if not (QUERY_PLANNING_ON := dd._dask_expr_enabled()):
raise ValueError(
"The legacy DataFrame API is not supported in dask_cudf>24.12. "
"Please enable query-planning, or downgrade to dask_cudf<=24.12"
)


def read_csv(*args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
97 changes: 96 additions & 1 deletion python/dask_cudf/dask_cudf/_expr/__init__.py
@@ -1 +1,96 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.

from packaging.version import Version

import dask

if Version(dask.__version__) > Version("2024.12.1"):
import dask.dataframe.dask_expr._shuffle as _shuffle_module
from dask.dataframe.dask_expr import (
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
from_dict,
get_collection_type,
new_collection,
)
from dask.dataframe.dask_expr._cumulative import (
CumulativeBlockwise,
)
from dask.dataframe.dask_expr._expr import (
Elemwise,
Expr,
RenameAxis,
VarColumns,
)
from dask.dataframe.dask_expr._groupby import (
DecomposableGroupbyAggregation,
GroupBy as DXGroupBy,
GroupbyAggregation,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask.dataframe.dask_expr._reductions import (
Reduction,
Var,
)
from dask.dataframe.dask_expr._util import (
_convert_to_list,
_raise_if_object_series,
is_scalar,
)
from dask.dataframe.dask_expr.io.io import (
FusedIO,
FusedParquetIO,
)
from dask.dataframe.dask_expr.io.parquet import (
FragmentWrapper,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
)
else:
import dask_expr._shuffle as _shuffle_module # noqa: F401
from dask_expr import ( # noqa: F401
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
from_dict,
get_collection_type,
new_collection,
)
from dask_expr._cumulative import CumulativeBlockwise # noqa: F401
from dask_expr._expr import ( # noqa: F401
Elemwise,
Expr,
RenameAxis,
VarColumns,
)
from dask_expr._groupby import ( # noqa: F401
DecomposableGroupbyAggregation,
GroupBy as DXGroupBy,
GroupbyAggregation,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask_expr._reductions import Reduction, Var # noqa: F401
from dask_expr._util import ( # noqa: F401
_convert_to_list,
_raise_if_object_series,
is_scalar,
)
from dask_expr.io.io import FusedIO, FusedParquetIO # noqa: F401
from dask_expr.io.parquet import ( # noqa: F401
FragmentWrapper,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
)

from dask.dataframe import _dask_expr_enabled

if not _dask_expr_enabled():
raise ValueError(
"The legacy DataFrame API is not supported for RAPIDS >24.12. "
"The 'dataframe.query-planning' config must be True or None."
)
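Downstream modules now pull these names from the shim instead of importing `dask_expr` directly, as the `collection.py`, `expr.py`, `groupby.py`, and `io/parquet.py` diffs below show. A minimal sketch of the new import style:

```python
# Import shimmed names from dask_cudf._expr so the same code runs on
# both dask<=2024.12.1 (standalone dask_expr) and newer dask releases.
from dask_cudf._expr import DXDataFrame, new_collection
```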
20 changes: 10 additions & 10 deletions python/dask_cudf/dask_cudf/_expr/collection.py
@@ -3,23 +3,23 @@
import warnings
from functools import cached_property

from dask_expr import (
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
get_collection_type,
)
from dask_expr._collection import new_collection
from dask_expr._util import _raise_if_object_series

from dask import config
from dask.dataframe.core import is_dataframe_like
from dask.dataframe.dispatch import get_parallel_type
from dask.typing import no_default

import cudf

from dask_cudf._expr import (
DXDataFrame,
DXIndex,
DXSeries,
FrameBase,
_raise_if_object_series,
get_collection_type,
new_collection,
)

##
## Custom collection classes
##
18 changes: 12 additions & 6 deletions python/dask_cudf/dask_cudf/_expr/expr.py
@@ -1,12 +1,6 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
import functools

import dask_expr._shuffle as _shuffle_module
from dask_expr import new_collection
from dask_expr._cumulative import CumulativeBlockwise
from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns
from dask_expr._reductions import Reduction, Var

from dask.dataframe.dispatch import (
is_categorical_dtype,
make_meta,
@@ -17,6 +11,18 @@

import cudf

from dask_cudf._expr import (
CumulativeBlockwise,
Elemwise,
Expr,
Reduction,
RenameAxis,
Var,
VarColumns,
_shuffle_module,
new_collection,
)

##
## Custom expressions
##
19 changes: 10 additions & 9 deletions python/dask_cudf/dask_cudf/_expr/groupby.py
@@ -3,22 +3,23 @@

import numpy as np
import pandas as pd
from dask_expr._collection import new_collection
from dask_expr._groupby import (
DecomposableGroupbyAggregation,
GroupBy as DXGroupBy,
GroupbyAggregation,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask_expr._util import is_scalar

from dask.dataframe.core import _concat
from dask.dataframe.groupby import Aggregation

from cudf.core.groupby.groupby import _deprecate_collect
from cudf.utils.performance_tracking import _dask_cudf_performance_tracking

from dask_cudf._expr import (
DecomposableGroupbyAggregation,
DXGroupBy,
DXSeriesGroupBy,
GroupbyAggregation,
SingleAggregation,
is_scalar,
new_collection,
)

##
## Fused groupby aggregations
##
61 changes: 15 additions & 46 deletions python/dask_cudf/dask_cudf/backends.py
@@ -543,16 +543,6 @@ def to_cudf_dispatch_from_cudf(data, **kwargs):
return data


# Define the "cudf" backend for "legacy" Dask DataFrame
class LegacyCudfBackendEntrypoint(DataFrameBackendEntrypoint):
"""Backend-entrypoint class for legacy Dask-DataFrame
This class is registered under the name "cudf" for the
``dask.dataframe.backends`` entrypoint in ``pyproject.toml``.
This "legacy" backend is only used for CSV support.
"""


# Define the "cudf" backend for expr-based Dask DataFrame
class CudfBackendEntrypoint(DataFrameBackendEntrypoint):
"""Backend-entrypoint class for Dask-Expressions
@@ -566,20 +556,19 @@ class CudfBackendEntrypoint(DataFrameBackendEntrypoint):
Examples
--------
>>> import dask
>>> import dask_expr as dx
>>> import dask.dataframe as dd
>>> with dask.config.set({"dataframe.backend": "cudf"}):
... ddf = dx.from_dict({"a": range(10)})
... ddf = dd.from_dict({"a": range(10)})
>>> type(ddf._meta)
<class 'cudf.core.dataframe.DataFrame'>
"""

@staticmethod
def to_backend(data, **kwargs):
import dask_expr as dx

from dask_cudf._expr import new_collection
from dask_cudf._expr.expr import ToCudfBackend

return dx.new_collection(ToCudfBackend(data, kwargs))
return new_collection(ToCudfBackend(data, kwargs))

@staticmethod
def from_dict(
@@ -590,10 +579,10 @@ def from_dict(
columns=None,
constructor=cudf.DataFrame,
):
import dask_expr as dx
from dask_cudf._expr import from_dict

return _default_backend(
dx.from_dict,
from_dict,
data,
npartitions=npartitions,
orient=orient,
@@ -617,35 +606,15 @@ def read_csv(
storage_options=None,
**kwargs,
):
try:
# TODO: Remove when cudf is pinned to dask>2024.12.0
import dask_expr as dx
from dask_expr.io.csv import ReadCSV
from fsspec.utils import stringify_path

if not isinstance(path, str):
path = stringify_path(path)
return dx.new_collection(
ReadCSV(
path,
dtype_backend=dtype_backend,
storage_options=storage_options,
kwargs=kwargs,
header=header,
dataframe_backend="cudf",
)
)
except ImportError:
# Requires dask>2024.12.0
from dask_cudf.io.csv import read_csv

return read_csv(
path,
*args,
header=header,
storage_options=storage_options,
**kwargs,
)
from dask_cudf.io.csv import read_csv

return read_csv(
path,
*args,
header=header,
storage_options=storage_options,
**kwargs,
)

@staticmethod
def read_json(*args, **kwargs):
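With the try/except fallback removed, CSV reads always route through `dask_cudf.io.csv.read_csv` (dask>2024.12.0 is now assumed). A hedged usage sketch, with a placeholder path:

```python
# Sketch: read CSV files into a cudf-backed Dask DataFrame.
# "data/*.csv" is a placeholder path, not from this PR.
import dask_cudf

ddf = dask_cudf.read_csv("data/*.csv")
```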
35 changes: 14 additions & 21 deletions python/dask_cudf/dask_cudf/io/parquet.py
@@ -11,32 +11,26 @@

import numpy as np
import pandas as pd
from dask_expr._expr import Elemwise
from dask_expr._util import _convert_to_list
from dask_expr.io.io import FusedIO, FusedParquetIO
from dask_expr.io.parquet import (
FragmentWrapper,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
)

from dask._task_spec import Task
from dask._task_spec import List as TaskList, Task
from dask.dataframe.io.parquet.arrow import _filters_to_expression
from dask.dataframe.io.parquet.core import ParquetFunctionWrapper
from dask.tokenize import tokenize
from dask.utils import parse_bytes

try:
# TODO: Remove try/except when dask>2024.11.2
from dask._task_spec import List as TaskList
except ImportError:

def TaskList(*x):
return list(x)


import cudf

from dask_cudf._expr import (
Elemwise,
FragmentWrapper,
FusedIO,
FusedParquetIO,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
_convert_to_list,
new_collection,
)

# Dask-expr imports CudfEngine from this module
from dask_cudf._legacy.io.parquet import CudfEngine
from dask_cudf.core import _deprecated_api
Expand Down Expand Up @@ -698,7 +692,6 @@ def read_parquet_expr(
using the ``read`` key-word argument.
"""

import dask_expr as dx
from fsspec.utils import stringify_path
from pyarrow import fs as pa_fs

@@ -785,7 +778,7 @@
"parquet_file_extension is not supported when using the pyarrow filesystem."
)

return dx.new_collection(
return new_collection(
NoOp(
CudfReadParquetPyarrowFS(
path,
Expand All @@ -806,7 +799,7 @@ def read_parquet_expr(
)
)

return dx.new_collection(
return new_collection(
NoOp(
CudfReadParquetFSSpec(
path,
2 changes: 1 addition & 1 deletion python/dask_cudf/pyproject.toml
@@ -39,7 +39,7 @@ classifiers = [
]

[project.entry-points."dask.dataframe.backends"]
cudf = "dask_cudf.backends:LegacyCudfBackendEntrypoint"
cudf = "dask_cudf.backends:CudfBackendEntrypoint"

[project.entry-points."dask_expr.dataframe.backends"]
cudf = "dask_cudf.backends:CudfBackendEntrypoint"
