Add basic multi-partition GroupBy support to cuDF-Polars #17503

Merged Mar 11, 2025 (35 commits)

Commits
f0964a6  basic groupby-aggregation support (rjzamora, Dec 4, 2024)
1329cf1  Merge branch 'branch-25.02' into cudf-polars-multi-groupby (rjzamora, Dec 4, 2024)
11a03f8  Merge branch 'branch-25.02' into cudf-polars-multi-groupby (rjzamora, Dec 4, 2024)
a9fa486  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Dec 4, 2024)
b1224a0  remove GroupbyTree (rjzamora, Dec 4, 2024)
385f03a  simplify lower (rjzamora, Dec 6, 2024)
8956215  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Dec 6, 2024)
70b29b2  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Dec 19, 2024)
3f04eca  cleanup (rjzamora, Dec 19, 2024)
e090de5  no cover (rjzamora, Dec 19, 2024)
24b88f2  tweak error message (rjzamora, Dec 19, 2024)
161a53b  Merge branch 'branch-25.02' into cudf-polars-multi-groupby (rjzamora, Jan 9, 2025)
69f6336  update copyright dates (rjzamora, Jan 9, 2025)
22cebeb  add test coverage for single-partition (rjzamora, Jan 11, 2025)
45ac8ec  Merge branch 'branch-25.02' into cudf-polars-multi-groupby (rjzamora, Jan 11, 2025)
f5205bd  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Jan 27, 2025)
a7cd29f  Merge branch 'branch-25.04' into cudf-polars-multi-groupby (rjzamora, Jan 29, 2025)
ef79e90  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Feb 25, 2025)
b8a20e6  formatting (rjzamora, Feb 25, 2025)
fde4231  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Feb 27, 2025)
7d18e7b  add shuffle-based groupby (rjzamora, Feb 27, 2025)
f21e1cd  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Feb 28, 2025)
ccd1029  improve test coverage (rjzamora, Feb 28, 2025)
309757c  Merge branch 'branch-25.04' into cudf-polars-multi-groupby (rjzamora, Mar 3, 2025)
75a7257  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Mar 4, 2025)
385c68a  break out the decomposition of a single groupby request into a stand-… (rjzamora, Mar 4, 2025)
7c96482  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Mar 4, 2025)
fe3ca7c  break out the decomposition of a single groupby request into a stand-… (rjzamora, Mar 4, 2025)
16cf883  Merge branch 'branch-25.04' into cudf-polars-multi-groupby (rjzamora, Mar 7, 2025)
9f9c097  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Mar 11, 2025)
f959988  address schema and maintain_order issues (rjzamora, Mar 11, 2025)
bca71d6  use lawrences suggestions (rjzamora, Mar 11, 2025)
568a3f0  Merge branch 'branch-25.04' into cudf-polars-multi-groupby (rjzamora, Mar 11, 2025)
ef10e25  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Mar 11, 2025)
7f90ded  address small code-review comments (rjzamora, Mar 11, 2025)
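Before the diff, a minimal usage sketch of the feature this PR adds. The engine options are taken verbatim from the test fixture below; treat them as experimental knobs rather than a stable public API.

import polars as pl

# Multi-partition execution is opt-in via the experimental Dask executor;
# a small max_rows_per_partition forces several partitions for illustration.
engine = pl.GPUEngine(
    raise_on_fail=True,
    executor="dask-experimental",
    executor_options={"max_rows_per_partition": 4},
)

df = pl.LazyFrame({"y": ["cat", "dog", "fish"] * 50, "x": range(150)})

# sum/count/mean (and len) are decomposed into partition-wise groupbys plus
# a tree reduction; unsupported aggregations raise during lowering instead.
result = df.group_by("y").agg(pl.col("x").sum()).collect(engine=engine)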
python/cudf_polars/cudf_polars/experimental/groupby.py (210 additions, 0 deletions)
@@ -0,0 +1,210 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel GroupBy Logic."""

from __future__ import annotations

import operator
from functools import reduce
from typing import TYPE_CHECKING, Any

import pylibcudf as plc

from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr
from cudf_polars.dsl.ir import GroupBy, Select
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node

if TYPE_CHECKING:
    from collections.abc import MutableMapping

    from cudf_polars.dsl.expr import Expr
    from cudf_polars.dsl.ir import IR
    from cudf_polars.experimental.parallel import LowerIRTransformer


class GroupByTree(GroupBy):
    """Groupby tree-reduction operation."""


_GB_AGG_SUPPORTED = ("sum", "count", "mean")


def _single_fallback(
    ir: IR,
    children: tuple[IR],
    partition_info: MutableMapping[IR, PartitionInfo],
    unsupported_agg: Expr | None = None,
):
    if any(partition_info[child].count > 1 for child in children):  # pragma: no cover
        msg = f"Class {type(ir)} does not support multiple partitions."
        if unsupported_agg:
            msg = msg[:-1] + f" with {unsupported_agg} expression."
        raise NotImplementedError(msg)

    new_node = ir.reconstruct(children)
    partition_info[new_node] = PartitionInfo(count=1)
    return new_node, partition_info


@lower_ir_node.register(GroupBy)
def _(
    ir: GroupBy, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    # Lower children
    children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
    partition_info = reduce(operator.or_, _partition_info)

    if partition_info[children[0]].count == 1:
        # Single partition
        return _single_fallback(ir, children, partition_info)

    # Check that we are grouping on element-wise
    # keys (is this already guaranteed?)
    for ne in ir.keys:
        if not isinstance(ne.value, Col):  # pragma: no cover
            return _single_fallback(ir, children, partition_info)

Contributor:
What do you mean by elementwise keys? It's certainly not the case that we always group on columns. But I think it is the case that the group keys (if expressions) are trivially elementwise (e.g. a + b as a key is fine, but a.unique() or a.sort() is not).

Member (author):
Right. I'm being extra cautious by requiring the keys to be Col. This comment is essentially asking: "Can we drop this check altogether? I.e., will the keys always be element-wise?"

Contributor:
I believe so, yes.

Contributor:
Opened pola-rs/polars#20152 as well.

    name_map: MutableMapping[str, Any] = {}
    agg_tree: Cast | Agg | None = None
    agg_requests_pwise = []  # Partition-wise requests
    agg_requests_tree = []  # Tree-node requests

    for ne in ir.agg_requests:
        name = ne.name
        agg: Expr = ne.value
        dtype = agg.dtype
        agg = agg.children[0] if isinstance(agg, Cast) else agg
        if isinstance(agg, Len):
            agg_requests_pwise.append(ne)
            agg_requests_tree.append(
                NamedExpr(
                    name,
                    Cast(
                        dtype,
                        Agg(dtype, "sum", None, Col(dtype, name)),
                    ),
                )
            )
        elif isinstance(agg, Agg):
            if agg.name not in _GB_AGG_SUPPORTED:
                return _single_fallback(ir, children, partition_info, agg)

            if agg.name in ("sum", "count"):
                agg_requests_pwise.append(ne)
                agg_requests_tree.append(
                    NamedExpr(
                        name,
                        Cast(
                            dtype,
                            Agg(dtype, "sum", agg.options, Col(dtype, name)),
                        ),
                    )
                )
            elif agg.name == "mean":
                name_map[name] = {agg.name: {}}
                for sub in ["sum", "count"]:
                    # Partwise
                    tmp_name = f"{name}__{sub}"
                    name_map[name][agg.name][sub] = tmp_name
                    agg_pwise = Agg(dtype, sub, agg.options, *agg.children)
                    agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise))
                    # Tree
                    child = Col(dtype, tmp_name)
                    agg_tree = Agg(dtype, "sum", agg.options, child)
                    agg_requests_tree.append(NamedExpr(tmp_name, agg_tree))
        else:
            # Unsupported
            return _single_fallback(
                ir, children, partition_info, agg
            )  # pragma: no cover

    gb_pwise = GroupBy(
        ir.schema,
        ir.keys,
        agg_requests_pwise,
        ir.maintain_order,

Contributor:
I don't think we can support maintain_order == True (at least easily). So perhaps we should raise in that case.

Member (author):
I think we support it for the tree reduction, but not for a shuffle-based reduction, right? The tree-reduction tasks should be ordered appropriately.

Contributor:
Ah, true.

        ir.options,
        *children,
    )
    child_count = partition_info[children[0]].count
    partition_info[gb_pwise] = PartitionInfo(count=child_count)

    gb_tree = GroupByTree(
        ir.schema,
        ir.keys,
        agg_requests_tree,
        ir.maintain_order,
        ir.options,
        gb_pwise,
    )
    partition_info[gb_tree] = PartitionInfo(count=1)

    schema = ir.schema
    output_exprs = []
    for name, dtype in schema.items():
        agg_mapping = name_map.get(name, None)
        if agg_mapping is None:
            output_exprs.append(NamedExpr(name, Col(dtype, name)))
        elif "mean" in agg_mapping:
            mean_cols = agg_mapping["mean"]
            output_exprs.append(
                NamedExpr(
                    name,
                    BinOp(
                        dtype,
                        plc.binaryop.BinaryOperator.DIV,
                        Col(dtype, mean_cols["sum"]),
                        Col(dtype, mean_cols["count"]),
                    ),
                )
            )
    should_broadcast: bool = False
    new_node = Select(
        schema,
        output_exprs,
        should_broadcast,
        gb_tree,
    )
    partition_info[new_node] = PartitionInfo(count=1)
    return new_node, partition_info


def _tree_node(do_evaluate, batch, *args):
    return do_evaluate(*args, _concat(batch))


@generate_ir_tasks.register(GroupByTree)
def _(
    ir: GroupByTree, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
    child = ir.children[0]
    child_count = partition_info[child].count
    child_name = get_key_name(child)
    name = get_key_name(ir)

    # Simple tree reduction.
    j = 0
    graph: MutableMapping[Any, Any] = {}
    split_every = 32
    keys: list[Any] = [(child_name, i) for i in range(child_count)]
    while len(keys) > split_every:
        new_keys: list[Any] = []
        for i, k in enumerate(range(0, len(keys), split_every)):
            batch = keys[k : k + split_every]
            graph[(name, j, i)] = (
                _tree_node,
                ir.do_evaluate,
                batch,
                *ir._non_child_args,
            )
            new_keys.append((name, j, i))
        j += 1
        keys = new_keys
    graph[(name, 0)] = (
        _tree_node,
        ir.do_evaluate,
        keys,
        *ir._non_child_args,
    )
    return graph
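
As an aside, a standalone sketch (hypothetical helper, not part of the PR) of how the fan-in above behaves with the hard-coded split_every = 32:

def tree_levels(n_parts: int, split_every: int = 32) -> list[int]:
    # Partition counts at each reduction level, mirroring the while-loop above.
    counts = [n_parts]
    while counts[-1] > split_every:
        counts.append(-(-counts[-1] // split_every))  # ceiling division
    counts.append(1)  # the final (name, 0) task concatenates whatever remains
    return counts

print(tree_levels(1000))  # [1000, 32, 1]: one intermediate level, then the root
print(tree_levels(8))     # [8, 1]: small inputs reduce in a single task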
python/cudf_polars/cudf_polars/experimental/parallel.py (3 additions, 1 deletion)
@@ -9,8 +9,9 @@
 from functools import reduce
 from typing import TYPE_CHECKING, Any
 
+import cudf_polars.experimental.groupby
 import cudf_polars.experimental.io  # noqa: F401
-from cudf_polars.dsl.ir import IR, Cache, Projection, Union
+from cudf_polars.dsl.ir import IR, Cache, GroupBy, Projection, Union
 from cudf_polars.dsl.traversal import CachingVisitor, traversal
 from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
 from cudf_polars.experimental.dispatch import (
@@ -243,5 +244,6 @@ def _generate_ir_tasks_pwise(
     }
 
 
+generate_ir_tasks.register(GroupBy, _generate_ir_tasks_pwise)
 generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise)
 generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise)
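
For context: the new registration reuses the generic partition-wise task generator for the embarrassingly parallel GroupBy stage, while GroupByTree (a GroupBy subclass) keeps its own generate_ir_tasks handler. A toy sketch of this single-dispatch pattern, with hypothetical names rather than the cudf-polars API:

from functools import singledispatch

class Node: ...
class PartitionWise(Node): ...
class TreeReduce(PartitionWise): ...

@singledispatch
def generate_tasks(node: Node) -> str:
    raise NotImplementedError(type(node))

def _partition_wise(node: PartitionWise) -> str:
    return "one task per input partition"

generate_tasks.register(PartitionWise, _partition_wise)

@generate_tasks.register(TreeReduce)
def _(node: TreeReduce) -> str:
    return "custom fan-in task graph"

print(generate_tasks(PartitionWise()))  # one task per input partition
print(generate_tasks(TreeReduce()))     # exact registration beats the base class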
python/cudf_polars/tests/experimental/test_groupby.py (53 additions, 0 deletions)
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(scope="module")
def engine():
    return pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options={"max_rows_per_partition": 4},
    )


@pytest.fixture(scope="module")
def df():
    return pl.LazyFrame(
        {
            "x": range(150),
            "y": ["cat", "dog", "fish"] * 50,
            "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 30,
        }
    )


@pytest.mark.parametrize("op", ["sum", "mean", "len"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby(df, engine, op, keys):
    q = getattr(df.group_by(*keys), op)()
    assert_gpu_result_equal(q, engine=engine, check_row_order=False)


@pytest.mark.parametrize("op", ["sum", "mean", "len", "count"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby_agg(df, engine, op, keys):
    q = df.group_by(*keys).agg(getattr(pl.col("x"), op)())
    assert_gpu_result_equal(q, engine=engine, check_row_order=False)


def test_groupby_raises(df, engine):
    q = df.group_by("y").median()
    with pytest.raises(
        pl.exceptions.ComputeError,
        match="NotImplementedError",
    ):
        assert_gpu_result_equal(q, engine=engine, check_row_order=False)
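
For completeness, a sketch of what the fallback in the last test looks like outside the test harness (engine options assumed from the fixture above): median is not in _GB_AGG_SUPPORTED, so lowering raises NotImplementedError, which polars surfaces as a ComputeError when raise_on_fail=True.

import polars as pl

engine = pl.GPUEngine(
    raise_on_fail=True,
    executor="dask-experimental",
    executor_options={"max_rows_per_partition": 4},
)

# Two partitions of four rows each, so the multi-partition path is exercised.
q = pl.LazyFrame({"y": ["a", "b"] * 4, "x": range(8)}).group_by("y").median()
try:
    q.collect(engine=engine)
except pl.exceptions.ComputeError as err:
    print(err)  # message should mention NotImplementedError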