Add basic multi-partition GroupBy support to cuDF-Polars #17503

Merged: 35 commits, Mar 11, 2025

Commits
f0964a6
basic groupby-aggregation support
rjzamora Dec 4, 2024
1329cf1
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora Dec 4, 2024
11a03f8
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora Dec 4, 2024
a9fa486
Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars…
rjzamora Dec 4, 2024
b1224a0
remove GroupbyTree
rjzamora Dec 4, 2024
385f03a
simplify lower
rjzamora Dec 6, 2024
8956215
Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars…
rjzamora Dec 6, 2024
70b29b2
Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars…
rjzamora Dec 19, 2024
3f04eca
cleanup
rjzamora Dec 19, 2024
e090de5
no cover
rjzamora Dec 19, 2024
24b88f2
tweak error message
rjzamora Dec 19, 2024
161a53b
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora Jan 9, 2025
69f6336
update copyright dates
rjzamora Jan 9, 2025
22cebeb
add test coverage for single-partition
rjzamora Jan 11, 2025
45ac8ec
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora Jan 11, 2025
f5205bd
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Jan 27, 2025
a7cd29f
Merge branch 'branch-25.04' into cudf-polars-multi-groupby
rjzamora Jan 29, 2025
ef79e90
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Feb 25, 2025
b8a20e6
formatting
rjzamora Feb 25, 2025
fde4231
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Feb 27, 2025
7d18e7b
add shuffle-based groupby
rjzamora Feb 27, 2025
f21e1cd
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Feb 28, 2025
ccd1029
improve test coverage
rjzamora Feb 28, 2025
309757c
Merge branch 'branch-25.04' into cudf-polars-multi-groupby
rjzamora Mar 3, 2025
75a7257
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Mar 4, 2025
385c68a
break out the decomposition of a single groupby request into a stand-…
rjzamora Mar 4, 2025
7c96482
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Mar 4, 2025
fe3ca7c
break out the decomposition of a single groupby request into a stand-…
rjzamora Mar 4, 2025
16cf883
Merge branch 'branch-25.04' into cudf-polars-multi-groupby
rjzamora Mar 7, 2025
9f9c097
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Mar 11, 2025
f959988
address schema and maintain_order issues
rjzamora Mar 11, 2025
bca71d6
use lawrences suggestions
rjzamora Mar 11, 2025
568a3f0
Merge branch 'branch-25.04' into cudf-polars-multi-groupby
rjzamora Mar 11, 2025
ef10e25
Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars…
rjzamora Mar 11, 2025
7f90ded
address small code-review comments
rjzamora Mar 11, 2025
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/callback.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Callback for the polars collect function to execute on device."""
@@ -233,6 +233,8 @@ def validate_config_options(config: dict) -> None:
unsupported = config.get("executor_options", {}).keys() - {
"max_rows_per_partition",
"parquet_blocksize",
"cardinality_factor",
"groupby_n_ary",
}
else:
unsupported = config.get("executor_options", {}).keys()
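These new executor options are passed through `pl.GPUEngine`. A minimal sketch of how a query might opt in (the executor name, data, and option values here are illustrative assumptions, not part of this diff):

```python
# Sketch: enabling multi-partition execution with the options validated above.
# Assumes the experimental multi-partition executor name on this branch.
import polars as pl

ldf = pl.LazyFrame({"x": [0, 1, 1, 2], "y": [1.0, 2.0, 3.0, 4.0]})

engine = pl.GPUEngine(
    raise_on_fail=True,
    executor="dask-experimental",
    executor_options={
        "max_rows_per_partition": 2,       # split the input into partitions
        "cardinality_factor": {"x": 0.5},  # assume ~50% of "x" values are unique
        "groupby_n_ary": 32,               # fan-in of the groupby tree reduction
    },
)

result = ldf.group_by("x").agg(pl.col("y").mean()).collect(engine=engine)
```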
26 changes: 25 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
@@ -855,11 +855,19 @@ def __init__(self, polars_groupby_options: Any):
__slots__ = (
"agg_infos",
"agg_requests",
"config_options",
"keys",
"maintain_order",
"options",
)
_non_child = ("schema", "keys", "agg_requests", "maintain_order", "options")
_non_child = (
"schema",
"keys",
"agg_requests",
"maintain_order",
"options",
"config_options",
)
keys: tuple[expr.NamedExpr, ...]
"""Grouping keys."""
agg_requests: tuple[expr.NamedExpr, ...]
@@ -868,6 +876,8 @@ def __init__(self, polars_groupby_options: Any):
"""Preserve order in groupby."""
options: GroupbyOptions
"""Arbitrary options."""
config_options: dict[str, Any]
"""GPU-specific configuration options"""

def __init__(
self,
@@ -876,13 +886,15 @@ def __init__(
agg_requests: Sequence[expr.NamedExpr],
maintain_order: bool, # noqa: FBT001
options: Any,
config_options: dict[str, Any],
df: IR,
):
self.schema = schema
self.keys = tuple(keys)
self.agg_requests = tuple(agg_requests)
self.maintain_order = maintain_order
self.options = self.GroupbyOptions(options)
self.config_options = config_options
self.children = (df,)
if self.options.rolling:
raise NotImplementedError(
@@ -900,6 +912,18 @@ def __init__(
self.AggInfos(self.agg_requests),
)

def get_hashable(self) -> Hashable:
"""Hashable representation of the node."""
return (
type(self),
tuple(self.schema.items()),
self.keys,
self.maintain_order,
self.options,
json.dumps(self.config_options),
self.children,
)

@staticmethod
def check_agg(agg: expr.Expr) -> int:
"""
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/dsl/translate.py
@@ -288,6 +288,7 @@ def _(
aggs,
node.maintain_order,
node.options,
translator.config.config.copy(),
inp,
)

304 changes: 304 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/groupby.py
@@ -0,0 +1,304 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel GroupBy Logic."""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any

import pylibcudf as plc

from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr, UnaryFunction
from cudf_polars.dsl.ir import GroupBy, Select
from cudf_polars.dsl.traversal import traversal
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
from cudf_polars.experimental.shuffle import Shuffle

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.expr import Expr
from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.parallel import LowerIRTransformer


# Supported multi-partition aggregations
_GB_AGG_SUPPORTED = ("sum", "count", "mean")


def combine(
*decompositions: tuple[NamedExpr, list[NamedExpr], list[NamedExpr]],
) -> tuple[list[NamedExpr], list[NamedExpr], list[NamedExpr]]:
"""
Combine multiple groupby-aggregation decompositions.

Parameters
----------
decompositions
Packed sequence of `decompose` results.

Returns
-------
Unified groupby-aggregation decomposition.
"""
selections, aggregations, reductions = zip(*decompositions, strict=True)
assert all(isinstance(ne, NamedExpr) for ne in selections)
return (
list(selections),
list(itertools.chain.from_iterable(aggregations)),
list(itertools.chain.from_iterable(reductions)),
)


def decompose(
name: str, expr: Expr
) -> tuple[NamedExpr, list[NamedExpr], list[NamedExpr]]:
"""
Decompose a groupby-aggregation expression.

Parameters
----------
name
Output schema name.
expr
The aggregation expression for a single column.

Returns
-------
Tuple containing the `NamedExpr`s for each of the
three parallel-aggregation phases: (1) selection,
(2) initial aggregation, and (3) reduction.
"""
dtype = expr.dtype
expr = expr.children[0] if isinstance(expr, Cast) else expr

unary_op: list[Any] = []
if isinstance(expr, UnaryFunction) and expr.is_pointwise:
# TODO: Handle multiple/sequential unary ops
unary_op = [expr.name, expr.options]
expr = expr.children[0]

def _wrap_unary(select):
# Helper function to wrap the final selection
# in a UnaryFunction (when necessary)
if unary_op:
return UnaryFunction(select.dtype, *unary_op, select)
return select

if isinstance(expr, Len):
selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
aggregation = [NamedExpr(name, expr)]
reduction = [
NamedExpr(
name,
# Sum reduction may require casting.
# Do it for all cases to be safe (for now)
Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))),
)
]
return selection, aggregation, reduction
if isinstance(expr, Agg):
if expr.name in ("sum", "count"):
selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
aggregation = [NamedExpr(name, expr)]
reduction = [
NamedExpr(
name,
# Sum reduction may require casting.
# Do it for all cases to be safe (for now)
Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))),
)
]
return selection, aggregation, reduction
elif expr.name == "mean":
(child,) = expr.children
(sum, count), aggregations, reductions = combine(
# TODO: Avoid possibility of a name collision
# (even though the likelihood is small)
decompose(f"{name}__mean_sum", Agg(dtype, "sum", None, child)),
decompose(f"{name}__mean_count", Len(dtype)),
)
selection = NamedExpr(
name,
_wrap_unary(
BinOp(
dtype, plc.binaryop.BinaryOperator.DIV, sum.value, count.value
)
),
)
return selection, aggregations, reductions
else:
raise NotImplementedError(
"GroupBy does not support multiple partitions "
f"for this aggregation type:\n{type(expr)}\n"
f"Only {_GB_AGG_SUPPORTED} are supported."
)
else: # pragma: no cover
# Unsupported expression
raise NotImplementedError(
f"GroupBy does not support multiple partitions for this expression:\n{expr}"
)


@lower_ir_node.register(GroupBy)
def _(
ir: GroupBy, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
# Extract child partitioning
child, partition_info = rec(ir.children[0])

# Handle single-partition case
if partition_info[child].count == 1:
single_part_node = ir.reconstruct([child])
partition_info[single_part_node] = partition_info[child]
return single_part_node, partition_info

# Check group-by keys
if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])):
raise NotImplementedError(
f"GroupBy does not support multiple partitions for keys:\n{ir.keys}"
) # pragma: no cover

# Check if we are dealing with any high-cardinality columns
post_aggregation_count = 1 # Default tree reduction
groupby_key_columns = [ne.name for ne in ir.keys]
cardinality_factor = {
c: min(f, 1.0)
for c, f in ir.config_options.get("executor_options", {})
.get("cardinality_factor", {})
.items()
if c in groupby_key_columns
}
if cardinality_factor:
rjzamora (Member, Author) commented:

This "cardinality factor" logic is new, and can be pulled out of the PR if necessary. However, we do need a mechanism to trigger shuffle-based groupby aggregations in practice.

# The `cardinality_factor` dictionary can be used
# to specify a mapping between column names and
# cardinality "factors". Each factor estimates the
# fractional number of unique values in the column.
# Each value should be in the range (0, 1].
child_count = partition_info[child].count
post_aggregation_count = max(
int(max(cardinality_factor.values()) * child_count),
1,
)

# Decompose the aggregation requests into three distinct phases
selection_exprs, piecewise_exprs, reduction_exprs = combine(
*(decompose(agg.name, agg.value) for agg in ir.agg_requests)
)

# Partition-wise groupby operation
pwise_schema = {k.name: k.value.dtype for k in ir.keys} | {
k.name: k.value.dtype for k in piecewise_exprs
}
gb_pwise = GroupBy(
pwise_schema,
ir.keys,
piecewise_exprs,
ir.maintain_order,
Contributor commented:
I don't think we can support maintain_order == True (at least easily). So perhaps we should raise in that case.

rjzamora (Member, Author) replied:
I think we support it for the tree reduction, but not for a shuffle-based reduction, right? The tree-reduction tasks should be ordered appropriately.

Contributor replied:
Ah true.
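
(Context: the tree reduction concatenates partial results in partition order before re-aggregating, so first-appearance order can be preserved there; a hash shuffle redistributes rows by key and cannot, which is why the lowering below raises `NotImplementedError` when `maintain_order=True` would require more than one output partition.)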

ir.options,
ir.config_options,
child,
)
child_count = partition_info[child].count
partition_info[gb_pwise] = PartitionInfo(count=child_count)

# Add Shuffle node if necessary
gb_inter: GroupBy | Shuffle = gb_pwise
if post_aggregation_count > 1:
if ir.maintain_order: # pragma: no cover
raise NotImplementedError(
"maintain_order not supported for multiple output partitions."
)

shuffle_options: dict[str, Any] = {}
gb_inter = Shuffle(
pwise_schema,
ir.keys,
shuffle_options,
gb_pwise,
)
partition_info[gb_inter] = PartitionInfo(count=post_aggregation_count)

# Tree reduction if post_aggregation_count==1
# (Otherwise, this is another partition-wise op)
gb_reduce = GroupBy(
{k.name: k.value.dtype for k in ir.keys}
| {k.name: k.value.dtype for k in reduction_exprs},
ir.keys,
reduction_exprs,
ir.maintain_order,
ir.options,
ir.config_options,
gb_inter,
)
partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count)

# Final Select phase
aggregated = {ne.name: ne for ne in selection_exprs}
new_node = Select(
ir.schema,
[
# Select the aggregated data or the original column
aggregated.get(name, NamedExpr(name, Col(dtype, name)))
for name, dtype in ir.schema.items()
],
False, # noqa: FBT003
gb_reduce,
)
partition_info[new_node] = PartitionInfo(count=post_aggregation_count)
return new_node, partition_info


def _tree_node(do_evaluate, batch, *args):
return do_evaluate(*args, _concat(batch))


@generate_ir_tasks.register(GroupBy)
def _(
ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
(child,) = ir.children
child_count = partition_info[child].count
child_name = get_key_name(child)
output_count = partition_info[ir].count

if output_count == child_count:
return {
key: (
ir.do_evaluate,
*ir._non_child_args,
(child_name, i),
)
for i, key in enumerate(partition_info[ir].keys(ir))
}
elif output_count != 1: # pragma: no cover
raise ValueError(f"Expected single partition, got {output_count}")

# Simple N-ary tree reduction
j = 0
n_ary = ir.config_options.get("executor_options", {}).get("groupby_n_ary", 32)
graph: MutableMapping[Any, Any] = {}
name = get_key_name(ir)
keys: list[Any] = [(child_name, i) for i in range(child_count)]
while len(keys) > n_ary:
new_keys: list[Any] = []
for i, k in enumerate(range(0, len(keys), n_ary)):
batch = keys[k : k + n_ary]
graph[(name, j, i)] = (
_tree_node,
ir.do_evaluate,
batch,
*ir._non_child_args,
)
new_keys.append((name, j, i))
j += 1
keys = new_keys
graph[(name, 0)] = (
_tree_node,
ir.do_evaluate,
keys,
*ir._non_child_args,
)
return graph
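The task-graph construction above reduces the partition-wise groupby results level by level. A minimal, self-contained sketch of the same N-ary batching pattern, with plain Python lists standing in for task-graph keys and `sum` standing in for the concat-and-reaggregate step (all names here are illustrative, not cudf-polars API):

```python
# Sketch of the N-ary tree reduction built by generate_ir_tasks.
def _combine(batch: list[int]) -> int:
    # Stands in for `_tree_node(do_evaluate, batch, ...)`: concatenate a
    # batch of partial results and re-aggregate them into one partition.
    return sum(batch)

def tree_reduce(parts: list[int], n_ary: int = 32) -> int:
    # Group up to `n_ary` partial results per task at each level until a
    # single output remains, mirroring the while-loop in the diff above.
    while len(parts) > n_ary:
        parts = [
            _combine(parts[k : k + n_ary])
            for k in range(0, len(parts), n_ary)
        ]
    return _combine(parts)

# 100 input partitions with fan-in 32 -> 4 intermediate tasks -> 1 output.
assert tree_reduce(list(range(100))) == sum(range(100))
```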