From f0964a633f9f8258d6e75b0d217508e76f0666a9 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 3 Dec 2024 16:29:32 -0800 Subject: [PATCH 01/30] basic groupby-aggregation support --- .../cudf_polars/experimental/groupby.py | 210 ++++++++++++++++++ .../cudf_polars/experimental/parallel.py | 4 +- .../tests/experimental/test_groupby.py | 53 +++++ 3 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 python/cudf_polars/cudf_polars/experimental/groupby.py create mode 100644 python/cudf_polars/tests/experimental/test_groupby.py diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py new file mode 100644 index 00000000000..2f151dc8886 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Parallel GroupBy Logic.""" + +from __future__ import annotations + +import operator +from functools import reduce +from typing import TYPE_CHECKING, Any + +import pylibcudf as plc + +from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr +from cudf_polars.dsl.ir import GroupBy, Select +from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name +from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from cudf_polars.dsl.expr import Expr + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.parallel import LowerIRTransformer + + +class GroupByTree(GroupBy): + """Groupby tree-reduction operation.""" + + +_GB_AGG_SUPPORTED = ("sum", "count", "mean") + + +def _single_fallback( + ir: IR, + children: tuple[IR], + partition_info: MutableMapping[IR, PartitionInfo], + unsupported_agg: Expr | None = None, +): + if any(partition_info[child].count > 1 for child in children): # pragma: no cover + msg = f"Class {type(ir)} does not support multiple partitions." + if unsupported_agg: + msg = msg[:-1] + f" with {unsupported_agg} expression." + raise NotImplementedError(msg) + + new_node = ir.reconstruct(children) + partition_info[new_node] = PartitionInfo(count=1) + return new_node, partition_info + + +@lower_ir_node.register(GroupBy) +def _( + ir: GroupBy, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Lower children + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) + partition_info = reduce(operator.or_, _partition_info) + + if partition_info[children[0]].count == 1: + # Single partition + return _single_fallback(ir, children, partition_info) + + # Check that we are grouping on element-wise + # keys (is this already guaranteed?) 
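+    # (Partial groupby results can only be combined safely when
+    # every grouping key is a plain column reference, so any other
+    # key expression falls back to a single partition below.)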
+ for ne in ir.keys: + if not isinstance(ne.value, Col): # pragma: no cover + return _single_fallback(ir, children, partition_info) + + name_map: MutableMapping[str, Any] = {} + agg_tree: Cast | Agg | None = None + agg_requests_pwise = [] # Partition-wise requests + agg_requests_tree = [] # Tree-node requests + + for ne in ir.agg_requests: + name = ne.name + agg: Expr = ne.value + dtype = agg.dtype + agg = agg.children[0] if isinstance(agg, Cast) else agg + if isinstance(agg, Len): + agg_requests_pwise.append(ne) + agg_requests_tree.append( + NamedExpr( + name, + Cast( + dtype, + Agg(dtype, "sum", None, Col(dtype, name)), + ), + ) + ) + elif isinstance(agg, Agg): + if agg.name not in _GB_AGG_SUPPORTED: + return _single_fallback(ir, children, partition_info, agg) + + if agg.name in ("sum", "count"): + agg_requests_pwise.append(ne) + agg_requests_tree.append( + NamedExpr( + name, + Cast( + dtype, + Agg(dtype, "sum", agg.options, Col(dtype, name)), + ), + ) + ) + elif agg.name == "mean": + name_map[name] = {agg.name: {}} + for sub in ["sum", "count"]: + # Partwise + tmp_name = f"{name}__{sub}" + name_map[name][agg.name][sub] = tmp_name + agg_pwise = Agg(dtype, sub, agg.options, *agg.children) + agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise)) + # Tree + child = Col(dtype, tmp_name) + agg_tree = Agg(dtype, "sum", agg.options, child) + agg_requests_tree.append(NamedExpr(tmp_name, agg_tree)) + else: + # Unsupported + return _single_fallback( + ir, children, partition_info, agg + ) # pragma: no cover + + gb_pwise = GroupBy( + ir.schema, + ir.keys, + agg_requests_pwise, + ir.maintain_order, + ir.options, + *children, + ) + child_count = partition_info[children[0]].count + partition_info[gb_pwise] = PartitionInfo(count=child_count) + + gb_tree = GroupByTree( + ir.schema, + ir.keys, + agg_requests_tree, + ir.maintain_order, + ir.options, + gb_pwise, + ) + partition_info[gb_tree] = PartitionInfo(count=1) + + schema = ir.schema + output_exprs = [] + for name, dtype in schema.items(): + agg_mapping = name_map.get(name, None) + if agg_mapping is None: + output_exprs.append(NamedExpr(name, Col(dtype, name))) + elif "mean" in agg_mapping: + mean_cols = agg_mapping["mean"] + output_exprs.append( + NamedExpr( + name, + BinOp( + dtype, + plc.binaryop.BinaryOperator.DIV, + Col(dtype, mean_cols["sum"]), + Col(dtype, mean_cols["count"]), + ), + ) + ) + should_broadcast: bool = False + new_node = Select( + schema, + output_exprs, + should_broadcast, + gb_tree, + ) + partition_info[new_node] = PartitionInfo(count=1) + return new_node, partition_info + + +def _tree_node(do_evaluate, batch, *args): + return do_evaluate(*args, _concat(batch)) + + +@generate_ir_tasks.register(GroupByTree) +def _( + ir: GroupByTree, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + child = ir.children[0] + child_count = partition_info[child].count + child_name = get_key_name(child) + name = get_key_name(ir) + + # Simple tree reduction. 
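+    # Batches of up to `split_every` partial results are concatenated
+    # and re-aggregated at each level until one partition remains.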
+ j = 0 + graph: MutableMapping[Any, Any] = {} + split_every = 32 + keys: list[Any] = [(child_name, i) for i in range(child_count)] + while len(keys) > split_every: + new_keys: list[Any] = [] + for i, k in enumerate(range(0, len(keys), split_every)): + batch = keys[k : k + split_every] + graph[(name, j, i)] = ( + _tree_node, + ir.do_evaluate, + batch, + *ir._non_child_args, + ) + new_keys.append((name, j, i)) + j += 1 + keys = new_keys + graph[(name, 0)] = ( + _tree_node, + ir.do_evaluate, + keys, + *ir._non_child_args, + ) + return graph diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index e5884f1c574..45b76419274 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -9,8 +9,9 @@ from functools import reduce from typing import TYPE_CHECKING, Any +import cudf_polars.experimental.groupby import cudf_polars.experimental.io # noqa: F401 -from cudf_polars.dsl.ir import IR, Cache, Projection, Union +from cudf_polars.dsl.ir import IR, Cache, GroupBy, Projection, Union from cudf_polars.dsl.traversal import CachingVisitor, traversal from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name from cudf_polars.experimental.dispatch import ( @@ -243,5 +244,6 @@ def _generate_ir_tasks_pwise( } +generate_ir_tasks.register(GroupBy, _generate_ir_tasks_pwise) generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise) generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise) diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py new file mode 100644 index 00000000000..6c4c62b0392 --- /dev/null +++ b/python/cudf_polars/tests/experimental/test_groupby.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture(scope="module") +def engine(): + return pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": 4}, + ) + + +@pytest.fixture(scope="module") +def df(): + return pl.LazyFrame( + { + "x": range(150), + "y": ["cat", "dog", "fish"] * 50, + "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 30, + } + ) + + +@pytest.mark.parametrize("op", ["sum", "mean", "len"]) +@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) +def test_groupby(df, engine, op, keys): + q = getattr(df.group_by(*keys), op)() + assert_gpu_result_equal(q, engine=engine, check_row_order=False) + + +@pytest.mark.parametrize("op", ["sum", "mean", "len", "count"]) +@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) +def test_groupby_agg(df, engine, op, keys): + q = df.group_by(*keys).agg(getattr(pl.col("x"), op)()) + assert_gpu_result_equal(q, engine=engine, check_row_order=False) + + +def test_groupby_raises(df, engine): + q = df.group_by("y").median() + with pytest.raises( + pl.exceptions.ComputeError, + match="NotImplementedError", + ): + assert_gpu_result_equal(q, engine=engine, check_row_order=False) From b1224a03c75d07f4734daf86b574b9d074228ffe Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 4 Dec 2024 14:23:19 -0800 Subject: [PATCH 02/30] remove GroupbyTree --- .../cudf_polars/experimental/groupby.py | 25 +++++++++++++------ .../cudf_polars/experimental/parallel.py | 3 +-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 2f151dc8886..1c2c367631a 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -23,10 +23,6 @@ from cudf_polars.experimental.parallel import LowerIRTransformer -class GroupByTree(GroupBy): - """Groupby tree-reduction operation.""" - - _GB_AGG_SUPPORTED = ("sum", "count", "mean") @@ -130,7 +126,7 @@ def _( child_count = partition_info[children[0]].count partition_info[gb_pwise] = PartitionInfo(count=child_count) - gb_tree = GroupByTree( + gb_tree = GroupBy( ir.schema, ir.keys, agg_requests_tree, @@ -174,19 +170,32 @@ def _tree_node(do_evaluate, batch, *args): return do_evaluate(*args, _concat(batch)) -@generate_ir_tasks.register(GroupByTree) +@generate_ir_tasks.register(GroupBy) def _( - ir: GroupByTree, partition_info: MutableMapping[IR, PartitionInfo] + ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo] ) -> MutableMapping[Any, Any]: child = ir.children[0] child_count = partition_info[child].count child_name = get_key_name(child) - name = get_key_name(ir) + output_count = partition_info[ir].count + + if output_count == child_count: + return { + key: ( + ir.do_evaluate, + *ir._non_child_args, + (child_name, i), + ) + for i, key in enumerate(partition_info[ir].keys(ir)) + } + elif output_count != 1: # pragma: no cover + raise ValueError(f"Expected single partition, got {output_count}") # Simple tree reduction. 
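+    # (Only reached when the groupby must reduce to a single output
+    # partition; the partition-wise case returns above.)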
j = 0 graph: MutableMapping[Any, Any] = {} split_every = 32 + name = get_key_name(ir) keys: list[Any] = [(child_name, i) for i in range(child_count)] while len(keys) > split_every: new_keys: list[Any] = [] diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 45b76419274..b48aa1984de 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -11,7 +11,7 @@ import cudf_polars.experimental.groupby import cudf_polars.experimental.io # noqa: F401 -from cudf_polars.dsl.ir import IR, Cache, GroupBy, Projection, Union +from cudf_polars.dsl.ir import IR, Cache, Projection, Union from cudf_polars.dsl.traversal import CachingVisitor, traversal from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name from cudf_polars.experimental.dispatch import ( @@ -244,6 +244,5 @@ def _generate_ir_tasks_pwise( } -generate_ir_tasks.register(GroupBy, _generate_ir_tasks_pwise) generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise) generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise) From 385f03af6ae7acb5454c8d15501a002a049eb58e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 5 Dec 2024 20:19:06 -0800 Subject: [PATCH 03/30] simplify lower --- .../cudf_polars/experimental/groupby.py | 40 ++++++++----------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 1c2c367631a..e76d9411ac6 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -4,8 +4,6 @@ from __future__ import annotations -import operator -from functools import reduce from typing import TYPE_CHECKING, Any import pylibcudf as plc @@ -28,38 +26,37 @@ def _single_fallback( ir: IR, - children: tuple[IR], + child: IR, partition_info: MutableMapping[IR, PartitionInfo], unsupported_agg: Expr | None = None, ): - if any(partition_info[child].count > 1 for child in children): # pragma: no cover + if partition_info[child].count > 1: msg = f"Class {type(ir)} does not support multiple partitions." if unsupported_agg: msg = msg[:-1] + f" with {unsupported_agg} expression." raise NotImplementedError(msg) - - new_node = ir.reconstruct(children) - partition_info[new_node] = PartitionInfo(count=1) - return new_node, partition_info + else: # pragma: no cover + new_node = ir.reconstruct([child]) + partition_info[new_node] = PartitionInfo(count=1) + return new_node, partition_info @lower_ir_node.register(GroupBy) def _( ir: GroupBy, rec: LowerIRTransformer ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - # Lower children - children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) - partition_info = reduce(operator.or_, _partition_info) + (child,) = ir.children + child, partition_info = rec(child) - if partition_info[children[0]].count == 1: + if partition_info[child].count == 1: # Single partition - return _single_fallback(ir, children, partition_info) + return _single_fallback(ir, child, partition_info) # Check that we are grouping on element-wise # keys (is this already guaranteed?) 
for ne in ir.keys: if not isinstance(ne.value, Col): # pragma: no cover - return _single_fallback(ir, children, partition_info) + return _single_fallback(ir, child, partition_info) name_map: MutableMapping[str, Any] = {} agg_tree: Cast | Agg | None = None @@ -84,7 +81,7 @@ def _( ) elif isinstance(agg, Agg): if agg.name not in _GB_AGG_SUPPORTED: - return _single_fallback(ir, children, partition_info, agg) + return _single_fallback(ir, child, partition_info, agg) if agg.name in ("sum", "count"): agg_requests_pwise.append(ne) @@ -106,14 +103,11 @@ def _( agg_pwise = Agg(dtype, sub, agg.options, *agg.children) agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise)) # Tree - child = Col(dtype, tmp_name) - agg_tree = Agg(dtype, "sum", agg.options, child) + agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name)) agg_requests_tree.append(NamedExpr(tmp_name, agg_tree)) else: # Unsupported - return _single_fallback( - ir, children, partition_info, agg - ) # pragma: no cover + return _single_fallback(ir, child, partition_info, agg) # pragma: no cover gb_pwise = GroupBy( ir.schema, @@ -121,9 +115,9 @@ def _( agg_requests_pwise, ir.maintain_order, ir.options, - *children, + child, ) - child_count = partition_info[children[0]].count + child_count = partition_info[child].count partition_info[gb_pwise] = PartitionInfo(count=child_count) gb_tree = GroupBy( @@ -174,7 +168,7 @@ def _tree_node(do_evaluate, batch, *args): def _( ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo] ) -> MutableMapping[Any, Any]: - child = ir.children[0] + (child,) = ir.children child_count = partition_info[child].count child_name = get_key_name(child) output_count = partition_info[ir].count From 3f04eca242ff8faca23ff3e463cf63aa0580dec6 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 19 Dec 2024 10:37:07 -0800 Subject: [PATCH 04/30] cleanup --- .../cudf_polars/experimental/groupby.py | 59 ++++++++----------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index e76d9411ac6..18404f34166 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -10,6 +10,7 @@ from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr from cudf_polars.dsl.ir import GroupBy, Select +from cudf_polars.dsl.traversal import traversal from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node @@ -21,42 +22,26 @@ from cudf_polars.experimental.parallel import LowerIRTransformer +# Supported multi-partition aggregations _GB_AGG_SUPPORTED = ("sum", "count", "mean") -def _single_fallback( - ir: IR, - child: IR, - partition_info: MutableMapping[IR, PartitionInfo], - unsupported_agg: Expr | None = None, -): - if partition_info[child].count > 1: - msg = f"Class {type(ir)} does not support multiple partitions." - if unsupported_agg: - msg = msg[:-1] + f" with {unsupported_agg} expression." 
- raise NotImplementedError(msg) - else: # pragma: no cover - new_node = ir.reconstruct([child]) - partition_info[new_node] = PartitionInfo(count=1) - return new_node, partition_info - - @lower_ir_node.register(GroupBy) def _( ir: GroupBy, rec: LowerIRTransformer ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - (child,) = ir.children - child, partition_info = rec(child) + # Extract child partitioning + child, partition_info = rec(ir.children[0]) + # Handle single-partition case if partition_info[child].count == 1: - # Single partition - return _single_fallback(ir, child, partition_info) + single_part_node = ir.reconstruct([child]) + partition_info[single_part_node] = partition_info[child] + return single_part_node, partition_info - # Check that we are grouping on element-wise - # keys (is this already guaranteed?) - for ne in ir.keys: - if not isinstance(ne.value, Col): # pragma: no cover - return _single_fallback(ir, child, partition_info) + # Check group-by keys + if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])): + raise NotImplementedError(f"GroupBy {ir} does not support multiple partitions.") name_map: MutableMapping[str, Any] = {} agg_tree: Cast | Agg | None = None @@ -81,7 +66,10 @@ def _( ) elif isinstance(agg, Agg): if agg.name not in _GB_AGG_SUPPORTED: - return _single_fallback(ir, child, partition_info, agg) + raise NotImplementedError( + f"GroupBy {ir} does not support multiple partitions " + f"with an {agg} expression." + ) if agg.name in ("sum", "count"): agg_requests_pwise.append(ne) @@ -106,8 +94,11 @@ def _( agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name)) agg_requests_tree.append(NamedExpr(tmp_name, agg_tree)) else: - # Unsupported - return _single_fallback(ir, child, partition_info, agg) # pragma: no cover + # Unsupported expression + raise NotImplementedError( + f"GroupBy {ir} does not support multiple partitions " + f"with an {agg} expression." + ) gb_pwise = GroupBy( ir.schema, @@ -185,16 +176,16 @@ def _( elif output_count != 1: # pragma: no cover raise ValueError(f"Expected single partition, got {output_count}") - # Simple tree reduction. + # Simple N-ary tree reduction j = 0 graph: MutableMapping[Any, Any] = {} - split_every = 32 + n_ary = 32 # TODO: Make this configurable name = get_key_name(ir) keys: list[Any] = [(child_name, i) for i in range(child_count)] - while len(keys) > split_every: + while len(keys) > n_ary: new_keys: list[Any] = [] - for i, k in enumerate(range(0, len(keys), split_every)): - batch = keys[k : k + split_every] + for i, k in enumerate(range(0, len(keys), n_ary)): + batch = keys[k : k + n_ary] graph[(name, j, i)] = ( _tree_node, ir.do_evaluate, From e090de53781112bc341c5df880fb9abc7970ffbf Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 19 Dec 2024 10:39:06 -0800 Subject: [PATCH 05/30] no cover --- python/cudf_polars/cudf_polars/experimental/groupby.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 18404f34166..eb55ca369a0 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -41,7 +41,9 @@ def _( # Check group-by keys if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])): - raise NotImplementedError(f"GroupBy {ir} does not support multiple partitions.") + raise NotImplementedError( + f"GroupBy {ir} does not support multiple partitions." 
+ ) # pragma: no cover name_map: MutableMapping[str, Any] = {} agg_tree: Cast | Agg | None = None @@ -98,7 +100,7 @@ def _( raise NotImplementedError( f"GroupBy {ir} does not support multiple partitions " f"with an {agg} expression." - ) + ) # pragma: no cover gb_pwise = GroupBy( ir.schema, From 24b88f2deb402ac74eb044cd70a29a6891cbffb7 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 19 Dec 2024 10:51:03 -0800 Subject: [PATCH 06/30] tweak error message --- .../cudf_polars/cudf_polars/experimental/groupby.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index eb55ca369a0..d26bdada3ee 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -42,7 +42,8 @@ def _( # Check group-by keys if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])): raise NotImplementedError( - f"GroupBy {ir} does not support multiple partitions." + "GroupBy does not support multiple partitions " + f"for these keys:\n{ir.keys}" ) # pragma: no cover name_map: MutableMapping[str, Any] = {} @@ -69,8 +70,8 @@ def _( elif isinstance(agg, Agg): if agg.name not in _GB_AGG_SUPPORTED: raise NotImplementedError( - f"GroupBy {ir} does not support multiple partitions " - f"with an {agg} expression." + "GroupBy does not support multiple partitions " + f"for this expression:\n{agg}" ) if agg.name in ("sum", "count"): @@ -98,8 +99,8 @@ def _( else: # Unsupported expression raise NotImplementedError( - f"GroupBy {ir} does not support multiple partitions " - f"with an {agg} expression." + "GroupBy does not support multiple partitions " + f"for this expression:\n{agg}" ) # pragma: no cover gb_pwise = GroupBy( From 69f63364c3ba219271fbf23c443585892c99ed67 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 9 Jan 2025 11:46:47 -0800 Subject: [PATCH 07/30] update copyright dates --- python/cudf_polars/cudf_polars/experimental/groupby.py | 2 +- python/cudf_polars/cudf_polars/experimental/parallel.py | 2 +- python/cudf_polars/tests/experimental/test_groupby.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index d26bdada3ee..496e3d5fb4b 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 """Parallel GroupBy Logic.""" diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 8852f3e1f60..c4bc319f880 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. 
# SPDX-License-Identifier: Apache-2.0 """Multi-partition Dask execution.""" diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py index 6c4c62b0392..150344b0017 100644 --- a/python/cudf_polars/tests/experimental/test_groupby.py +++ b/python/cudf_polars/tests/experimental/test_groupby.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations From 22cebeb4d285f2426e6e3937be61f972603b6e3e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Sat, 11 Jan 2025 05:50:36 -0800 Subject: [PATCH 08/30] add test coverage for single-partition --- .../tests/experimental/test_groupby.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py index 150344b0017..b17b8ac885e 100644 --- a/python/cudf_polars/tests/experimental/test_groupby.py +++ b/python/cudf_polars/tests/experimental/test_groupby.py @@ -37,6 +37,21 @@ def test_groupby(df, engine, op, keys): assert_gpu_result_equal(q, engine=engine, check_row_order=False) +@pytest.mark.parametrize("op", ["sum", "mean", "len"]) +@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) +def test_groupby_single_partitions(df, op, keys): + q = getattr(df.group_by(*keys), op)() + assert_gpu_result_equal( + q, + engine=pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": 1e9}, + ), + check_row_order=False, + ) + + @pytest.mark.parametrize("op", ["sum", "mean", "len", "count"]) @pytest.mark.parametrize("keys", [("y",), ("y", "z")]) def test_groupby_agg(df, engine, op, keys): From b8a20e6f79c89f1ed08bbf82adb5f6939d994fd8 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 25 Feb 2025 08:22:04 -0800 Subject: [PATCH 09/30] formatting --- python/cudf_polars/cudf_polars/experimental/groupby.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 496e3d5fb4b..42b374c97c7 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -42,8 +42,7 @@ def _( # Check group-by keys if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])): raise NotImplementedError( - "GroupBy does not support multiple partitions " - f"for these keys:\n{ir.keys}" + f"GroupBy does not support multiple partitions for keys:\n{ir.keys}" ) # pragma: no cover name_map: MutableMapping[str, Any] = {} From 7d18e7bcde71357fe61c2ce18bce346be8ff8ba1 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 27 Feb 2025 12:14:59 -0800 Subject: [PATCH 10/30] add shuffle-based groupby --- python/cudf_polars/cudf_polars/callback.py | 4 +- python/cudf_polars/cudf_polars/dsl/ir.py | 26 +++++- .../cudf_polars/cudf_polars/dsl/translate.py | 1 + .../cudf_polars/experimental/groupby.py | 92 +++++++++++++++---- .../tests/experimental/test_groupby.py | 18 ++++ 5 files changed, 120 insertions(+), 21 deletions(-) diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 074096446fd..3536c9345dc 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -1,4 +1,4 @@ -# 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 """Callback for the polars collect function to execute on device.""" @@ -233,6 +233,8 @@ def validate_config_options(config: dict) -> None: unsupported = config.get("executor_options", {}).keys() - { "max_rows_per_partition", "parquet_blocksize", + "cardinality_factor", + "groupby_n_ary", } else: unsupported = config.get("executor_options", {}).keys() diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 603f51e9d40..ee432583607 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -855,11 +855,19 @@ def __init__(self, polars_groupby_options: Any): __slots__ = ( "agg_infos", "agg_requests", + "config_options", "keys", "maintain_order", "options", ) - _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options") + _non_child = ( + "schema", + "keys", + "agg_requests", + "maintain_order", + "options", + "config_options", + ) keys: tuple[expr.NamedExpr, ...] """Grouping keys.""" agg_requests: tuple[expr.NamedExpr, ...] @@ -868,6 +876,8 @@ def __init__(self, polars_groupby_options: Any): """Preserve order in groupby.""" options: GroupbyOptions """Arbitrary options.""" + config_options: dict[str, Any] + """GPU-specific configuration options""" def __init__( self, @@ -876,6 +886,7 @@ def __init__( agg_requests: Sequence[expr.NamedExpr], maintain_order: bool, # noqa: FBT001 options: Any, + config_options: dict[str, Any], df: IR, ): self.schema = schema @@ -883,6 +894,7 @@ def __init__( self.agg_requests = tuple(agg_requests) self.maintain_order = maintain_order self.options = self.GroupbyOptions(options) + self.config_options = config_options self.children = (df,) if self.options.rolling: raise NotImplementedError( @@ -900,6 +912,18 @@ def __init__( self.AggInfos(self.agg_requests), ) + def get_hashable(self) -> Hashable: + """Hashable representation of the node.""" + return ( + type(self), + tuple(self.schema.items()), + self.keys, + self.maintain_order, + self.options, + json.dumps(self.config_options), + self.children, + ) + @staticmethod def check_agg(agg: expr.Expr) -> int: """ diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 369328d3a8c..ba97a7eccd4 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -288,6 +288,7 @@ def _( aggs, node.maintain_order, node.options, + translator.config.config.copy(), inp, ) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 42b374c97c7..16e38edbb8d 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -8,11 +8,12 @@ import pylibcudf as plc -from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr +from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr, UnaryFunction from cudf_polars.dsl.ir import GroupBy, Select from cudf_polars.dsl.traversal import traversal from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node +from cudf_polars.experimental.shuffle import Shuffle if TYPE_CHECKING: from collections.abc import MutableMapping @@ -45,10 
+46,33 @@ def _( f"GroupBy does not support multiple partitions for keys:\n{ir.keys}" ) # pragma: no cover + # Check if we are dealing with any high-cardinality columns + post_aggregation_count = 1 # Default tree reduction + groupby_key_columns = [ne.name for ne in ir.keys] + cardinality_factor = { + c: min(f, 1.0) + for c, f in ir.config_options.get("executor_options", {}) + .get("cardinality_factor", {}) + .items() + if c in groupby_key_columns + } + if cardinality_factor: + # The `cardinality_factor` dictionary can be used + # to specify a mapping between column names and + # cardinality "factors". Each factor estimates the + # fractional number of unique values in the column. + # Each value should be in the range (0, 1]. + child_count = partition_info[child].count + post_aggregation_count = max( + int(max(cardinality_factor.values()) * child_count), + 1, + ) + name_map: MutableMapping[str, Any] = {} agg_tree: Cast | Agg | None = None agg_requests_pwise = [] # Partition-wise requests agg_requests_tree = [] # Tree-node requests + unary_ops: dict[str, dict[str, Any]] = {} for ne in ir.agg_requests: name = ne.name @@ -66,7 +90,16 @@ def _( ), ) ) - elif isinstance(agg, Agg): + elif isinstance(agg, (Agg, UnaryFunction)): + if ( + isinstance(agg, UnaryFunction) + and agg.is_pointwise + and isinstance(agg.children[0], Agg) + ): + # TODO: Handle sequential unary ops + unary_ops[name] = {"name": agg.name, "options": agg.options} + agg = agg.children[0] + if agg.name not in _GB_AGG_SUPPORTED: raise NotImplementedError( "GroupBy does not support multiple partitions " @@ -102,54 +135,75 @@ def _( f"for this expression:\n{agg}" ) # pragma: no cover + # Partition-wise groupby operation gb_pwise = GroupBy( ir.schema, ir.keys, agg_requests_pwise, ir.maintain_order, ir.options, + ir.config_options, child, ) child_count = partition_info[child].count partition_info[gb_pwise] = PartitionInfo(count=child_count) - gb_tree = GroupBy( + # Add Shuffle node if necessary + gb_inter: GroupBy | Shuffle = gb_pwise + if post_aggregation_count > 1: + shuffle_options: dict[str, Any] = {} + gb_inter = Shuffle( + ir.schema, + ir.keys, + shuffle_options, + gb_pwise, + ) + partition_info[gb_inter] = PartitionInfo(count=post_aggregation_count) + + # Tree reduction if post_aggregation_count==1 + # (Otherwise, this is another partition-wise op) + gb_reduce = GroupBy( ir.schema, ir.keys, agg_requests_tree, ir.maintain_order, ir.options, - gb_pwise, + ir.config_options, + gb_inter, ) - partition_info[gb_tree] = PartitionInfo(count=1) + partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count) schema = ir.schema output_exprs = [] + col_expr: Col | BinOp | UnaryFunction for name, dtype in schema.items(): agg_mapping = name_map.get(name, None) if agg_mapping is None: - output_exprs.append(NamedExpr(name, Col(dtype, name))) + col_expr = Col(dtype, name) elif "mean" in agg_mapping: mean_cols = agg_mapping["mean"] - output_exprs.append( - NamedExpr( - name, - BinOp( - dtype, - plc.binaryop.BinaryOperator.DIV, - Col(dtype, mean_cols["sum"]), - Col(dtype, mean_cols["count"]), - ), - ) + col_expr = BinOp( + dtype, + plc.binaryop.BinaryOperator.DIV, + Col(dtype, mean_cols["sum"]), + Col(dtype, mean_cols["count"]), + ) + if name in unary_ops: + col_expr = UnaryFunction( + dtype, + unary_ops[name]["name"], + unary_ops[name]["options"], + col_expr, ) + output_exprs.append(NamedExpr(name, col_expr)) should_broadcast: bool = False new_node = Select( schema, output_exprs, should_broadcast, - gb_tree, + gb_reduce, ) - 
partition_info[new_node] = PartitionInfo(count=1) + partition_info[new_node] = PartitionInfo(count=post_aggregation_count) return new_node, partition_info @@ -180,8 +234,8 @@ def _( # Simple N-ary tree reduction j = 0 + n_ary = ir.config_options.get("executor_options", {}).get("groupby_n_ary", 32) graph: MutableMapping[Any, Any] = {} - n_ary = 32 # TODO: Make this configurable name = get_key_name(ir) keys: list[Any] = [(child_name, i) for i in range(child_count)] while len(keys) > n_ary: diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py index b17b8ac885e..07a6f481eaa 100644 --- a/python/cudf_polars/tests/experimental/test_groupby.py +++ b/python/cudf_polars/tests/experimental/test_groupby.py @@ -59,6 +59,24 @@ def test_groupby_agg(df, engine, op, keys): assert_gpu_result_equal(q, engine=engine, check_row_order=False) +@pytest.mark.parametrize("op", ["sum", "mean", "len", "count"]) +@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) +def test_groupby_agg_config_options(df, op, keys): + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={ + "max_rows_per_partition": 4, + # Trigger shuffle-based groupby + "cardinality_factor": {"z": 0.5}, + # Check that we can change the n-ary factor + "groupby_n_ary": 8, + }, + ) + q = df.group_by(*keys).agg(getattr(pl.col("x"), op)()) + assert_gpu_result_equal(q, engine=engine, check_row_order=False) + + def test_groupby_raises(df, engine): q = df.group_by("y").median() with pytest.raises( From ee47bd9f362cc4e1366f74e6e385bba9e674b957 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 28 Feb 2025 02:06:32 -0800 Subject: [PATCH 11/30] Add `pylibcudf.gpumemoryview` support for `len()`/`nbytes` --- python/pylibcudf/pylibcudf/gpumemoryview.pyx | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/pylibcudf/pylibcudf/gpumemoryview.pyx b/python/pylibcudf/pylibcudf/gpumemoryview.pyx index 41316eddb60..dad91361dd2 100644 --- a/python/pylibcudf/pylibcudf/gpumemoryview.pyx +++ b/python/pylibcudf/pylibcudf/gpumemoryview.pyx @@ -1,4 +1,7 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. + +import functools +import operator __all__ = ["gpumemoryview"] @@ -27,4 +30,19 @@ cdef class gpumemoryview: def __cuda_array_interface__(self): return self.obj.__cuda_array_interface__ + def __len__(self): + return self.obj.__cuda_array_interface["shape"][0] + + @property + def nbytes(self): + cai = self.obj.__cuda_array_interface__ + shape, typestr = cai["shape"], cai["typestr"] + + # Get element size from typestr, format is two character specifying + # the type and the latter part is the number of bytes. 
+        # E.g., '<f4' is a 32-bit (4-byte) float.
+        itemsize = int(typestr[2:])
+
+        return functools.reduce(operator.mul, shape) * itemsize

From: rjzamora
Date: Fri, 28 Feb 2025 06:38:14 -0800
Subject: [PATCH 12/30] improve test coverage

---
 python/cudf_polars/tests/experimental/test_groupby.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py
index 07a6f481eaa..11c5c432b11 100644
--- a/python/cudf_polars/tests/experimental/test_groupby.py
+++ b/python/cudf_polars/tests/experimental/test_groupby.py
@@ -73,7 +73,10 @@ def test_groupby_agg_config_options(df, op, keys):
             "groupby_n_ary": 8,
         },
     )
-    q = df.group_by(*keys).agg(getattr(pl.col("x"), op)())
+    agg = getattr(pl.col("x"), op)()
+    if op in ("sum", "mean"):
+        agg = agg.round(2)  # Unary test coverage
+    q = df.group_by(*keys).agg(agg)
     assert_gpu_result_equal(q, engine=engine, check_row_order=False)

From 33bf65c55ced04d54f15bfcc30bc96e7ea537c78 Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev
Date: Fri, 28 Feb 2025 06:24:08 -0800
Subject: [PATCH 13/30] Add gpumemoryview tests for `len()`/`nbytes`

---
 .../pylibcudf/tests/test_gpumemoryview.py     | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py

diff --git a/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py b/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py
new file mode 100644
index 00000000000..ddf165cc9f7
--- /dev/null
+++ b/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
+import itertools
+
+import numpy as np
+import pytest
+
+import rmm
+
+import pylibcudf as plc
+
+DTYPES = [
+    "u1",
+    "i2",
+    "f4",
+    "f8",
+    "f16",
+]
+SIZES = [
+    0,
+    1,
+    1000,
+    1024,
+    10000,
+]
+
+
+@pytest.fixture(params=tuple(itertools.product(SIZES, DTYPES)), ids=repr)
+def np_array(request):
+    size, dtype = request.param
+    return np.empty((size,), dtype=dtype)
+
+
+def test_len_nbytes(np_array):
+    buf = rmm.DeviceBuffer(
+        ptr=np_array.__array_interface__["data"][0], size=np_array.nbytes
+    )
+    gpumemview = plc.gpumemoryview(buf)
+
+    np_array_view = np_array.view("u1")
+
+    assert len(gpumemview) == len(np_array_view)
+    assert gpumemview.nbytes == np_array.nbytes

From ac73babc3eab07198dd4e41256da4ad71cc45651 Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev
Date: Fri, 28 Feb 2025 07:03:29 -0800
Subject: [PATCH 14/30] Add `gpumemoryview.__cuda_array_interface__` tests

---
 .../pylibcudf/tests/test_gpumemoryview.py | 17 ++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py b/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py
index ddf165cc9f7..187857c935a 100644
--- a/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py
+++ b/python/pylibcudf/pylibcudf/tests/test_gpumemoryview.py
@@ -31,7 +31,22 @@ def np_array(request):
     return np.empty((size,), dtype=dtype)
 
 
-def test_len_nbytes(np_array):
+def test_cuda_array_interface(np_array):
+    buf = rmm.DeviceBuffer(
+        ptr=np_array.__array_interface__["data"][0], size=np_array.nbytes
+    )
+    gpumemview = plc.gpumemoryview(buf)
+
+    np_array_view = np_array.view("u1")
+
+    ai = np_array_view.__array_interface__
+    cai = gpumemview.__cuda_array_interface__
+    assert cai["shape"] == ai["shape"]
+    assert cai["strides"] == ai["strides"]
+    assert cai["typestr"] == ai["typestr"]
+
+
+def test_len(np_array):
     buf = rmm.DeviceBuffer(
         ptr=np_array.__array_interface__["data"][0], size=np_array.nbytes
     )

From 41844b7da4c6a0cd8fc7459087b9231e6a9a20e4 Mon Sep 17
00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 28 Feb 2025 11:23:45 -0800 Subject: [PATCH 15/30] Update stubs --- python/pylibcudf/pylibcudf/gpumemoryview.pyi | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pylibcudf/pylibcudf/gpumemoryview.pyi b/python/pylibcudf/pylibcudf/gpumemoryview.pyi index 50f1f39a515..236ff6e56a6 100644 --- a/python/pylibcudf/pylibcudf/gpumemoryview.pyi +++ b/python/pylibcudf/pylibcudf/gpumemoryview.pyi @@ -7,3 +7,6 @@ class gpumemoryview: def __init__(self, data: Any): ... @property def __cuda_array_interface__(self) -> Mapping[str, Any]: ... + def __len__(self) -> int: ... + @property + def nbytes(self) -> int: ... From 6dfb397a2f1886880fda7cf6111eef6101b2a7a8 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 28 Feb 2025 11:58:02 -0800 Subject: [PATCH 16/30] add ConfigOptions class --- python/cudf_polars/cudf_polars/callback.py | 41 +----- python/cudf_polars/cudf_polars/dsl/ir.py | 28 ++-- .../cudf_polars/cudf_polars/dsl/translate.py | 12 +- .../cudf_polars/experimental/io.py | 19 +-- .../cudf_polars/cudf_polars/utils/config.py | 129 ++++++++++++++++++ 5 files changed, 161 insertions(+), 68 deletions(-) create mode 100644 python/cudf_polars/cudf_polars/utils/config.py diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 074096446fd..1eb17d20806 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 """Callback for the polars collect function to execute on device.""" @@ -202,44 +202,6 @@ def _callback( raise ValueError(f"Unknown executor '{executor}'") -def validate_config_options(config: dict) -> None: - """ - Validate the configuration options for the GPU engine. - - Parameters - ---------- - config - Configuration options to validate. - - Raises - ------ - ValueError - If the configuration contains unsupported options. - """ - if unsupported := ( - config.keys() - - {"raise_on_fail", "parquet_options", "executor", "executor_options"} - ): - raise ValueError( - f"Engine configuration contains unsupported settings: {unsupported}" - ) - assert {"chunked", "chunk_read_limit", "pass_read_limit"}.issuperset( - config.get("parquet_options", {}) - ) - - # Validate executor_options - executor = config.get("executor", "pylibcudf") - if executor == "dask-experimental": - unsupported = config.get("executor_options", {}).keys() - { - "max_rows_per_partition", - "parquet_blocksize", - } - else: - unsupported = config.get("executor_options", {}).keys() - if unsupported: - raise ValueError(f"Unsupported executor_options for {executor}: {unsupported}") - - def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None: """ A post optimization callback that attempts to execute the plan with cudf. 
@@ -267,7 +229,6 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None: memory_resource = config.memory_resource raise_on_fail = config.config.get("raise_on_fail", False) executor = config.config.get("executor", None) - validate_config_options(config.config) with nvtx.annotate(message="ConvertIR", domain="cudf_polars"): translator = Translator(nt, config) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 603f51e9d40..4a71f06bab4 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -39,6 +39,7 @@ from polars.polars import _expr_nodes as pl_expr from cudf_polars.typing import Schema + from cudf_polars.utils.config import ConfigOptions __all__ = [ @@ -284,7 +285,7 @@ class Scan(IR): """Reader-specific options, as dictionary.""" cloud_options: dict[str, Any] | None """Cloud-related authentication options, currently ignored.""" - config_options: dict[str, Any] + config_options: ConfigOptions """GPU-specific configuration options""" paths: list[str] """List of paths to read from.""" @@ -308,7 +309,7 @@ def __init__( typ: str, reader_options: dict[str, Any], cloud_options: dict[str, Any] | None, - config_options: dict[str, Any], + config_options: ConfigOptions, paths: list[str], with_columns: list[str] | None, skip_rows: int, @@ -413,7 +414,7 @@ def get_hashable(self) -> Hashable: self.typ, json.dumps(self.reader_options), json.dumps(self.cloud_options), - json.dumps(self.config_options), + self.config_options, tuple(self.paths), tuple(self.with_columns) if self.with_columns is not None else None, self.skip_rows, @@ -428,7 +429,7 @@ def do_evaluate( schema: Schema, typ: str, reader_options: dict[str, Any], - config_options: dict[str, Any], + config_options: ConfigOptions, paths: list[str], with_columns: list[str] | None, skip_rows: int, @@ -516,8 +517,7 @@ def do_evaluate( colnames[0], ) elif typ == "parquet": - parquet_options = config_options.get("parquet_options", {}) - if parquet_options.get("chunked", True): + if config_options.get("parquet_options.chunked", default=True): options = plc.io.parquet.ParquetReaderOptions.builder( plc.io.SourceInfo(paths) ).build() @@ -534,11 +534,13 @@ def do_evaluate( options.set_columns(with_columns) reader = plc.io.parquet.ChunkedParquetReader( options, - chunk_read_limit=parquet_options.get( - "chunk_read_limit", cls.PARQUET_DEFAULT_CHUNK_SIZE + chunk_read_limit=config_options.get( + "parquet_options.chunk_read_limit", + default=cls.PARQUET_DEFAULT_CHUNK_SIZE, ), - pass_read_limit=parquet_options.get( - "pass_read_limit", cls.PARQUET_DEFAULT_PASS_LIMIT + pass_read_limit=config_options.get( + "parquet_options.pass_read_limit", + default=cls.PARQUET_DEFAULT_PASS_LIMIT, ), ) chk = reader.read_chunk() @@ -702,7 +704,7 @@ class DataFrameScan(IR): """Polars LazyFrame object.""" projection: tuple[str, ...] 
| None """List of columns to project out.""" - config_options: dict[str, Any] + config_options: ConfigOptions """GPU-specific configuration options""" def __init__( @@ -710,7 +712,7 @@ def __init__( schema: Schema, df: Any, projection: Sequence[str] | None, - config_options: dict[str, Any], + config_options: ConfigOptions, ): self.schema = schema self.df = df @@ -736,7 +738,7 @@ def get_hashable(self) -> Hashable: schema_hash, id(self.df), self.projection, - json.dumps(self.config_options), + self.config_options, ) @classmethod diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 369328d3a8c..9ca0ea6a355 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -23,7 +23,7 @@ from cudf_polars.dsl import expr, ir from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser -from cudf_polars.utils import dtypes, sorting +from cudf_polars.utils import config, dtypes, sorting if TYPE_CHECKING: from polars import GPUEngine @@ -41,13 +41,13 @@ class Translator: ---------- visitor Polars NodeTraverser object - config + engine GPU engine configuration. """ - def __init__(self, visitor: NodeTraverser, config: GPUEngine): + def __init__(self, visitor: NodeTraverser, engine: GPUEngine): self.visitor = visitor - self.config = config + self.config_options = config.ConfigOptions(engine.config.copy()) self.errors: list[Exception] = [] def translate_ir(self, *, n: int | None = None) -> ir.IR: @@ -233,7 +233,7 @@ def _( typ, reader_options, cloud_options, - translator.config.config.copy(), + translator.config_options, node.paths, with_columns, skip_rows, @@ -260,7 +260,7 @@ def _( schema, node.df, node.projection, - translator.config.config.copy(), + translator.config_options, ) diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py index ba4432ecdea..d61cad50685 100644 --- a/python/cudf_polars/cudf_polars/experimental/io.py +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -22,14 +22,16 @@ from cudf_polars.dsl.expr import NamedExpr from cudf_polars.experimental.dispatch import LowerIRTransformer from cudf_polars.typing import Schema + from cudf_polars.utils.config import ConfigOptions @lower_ir_node.register(DataFrameScan) def _( ir: DataFrameScan, rec: LowerIRTransformer ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - rows_per_partition = ir.config_options.get("executor_options", {}).get( - "max_rows_per_partition", 1_000_000 + rows_per_partition = ir.config_options.get( + "executor_options.max_rows_per_partition", + default=1_000_000, ) nrows = max(ir.df.shape()[0], 1) @@ -91,8 +93,10 @@ def from_scan(ir: Scan) -> ScanPartitionPlan: """Extract the partitioning plan of a Scan operation.""" if ir.typ == "parquet": # TODO: Use system info to set default blocksize - parallel_options = ir.config_options.get("executor_options", {}) - blocksize: int = parallel_options.get("parquet_blocksize", 1024**3) + blocksize: int = ir.config_options.get( + "executor_options.parquet_blocksize", + default=1024**3, + ) stats = _sample_pq_statistics(ir) file_size = sum(float(stats[column]) for column in ir.schema) if file_size > 0: @@ -168,7 +172,7 @@ def do_evaluate( schema: Schema, typ: str, reader_options: dict[str, Any], - config_options: dict[str, Any], + config_options: ConfigOptions, paths: list[str], with_columns: list[str] | None, skip_rows: int, @@ -271,10 +275,7 @@ def _( if plan.flavor 
== ScanPartitionFlavor.SPLIT_FILES: # Disable chunked reader when splitting files config_options = ir.config_options.copy() - config_options["parquet_options"] = config_options.get( - "parquet_options", {} - ).copy() - config_options["parquet_options"]["chunked"] = False + config_options.set(name="parquet_options.chunked", value=False) slices: list[SplitScan] = [] for path in paths: diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py new file mode 100644 index 00000000000..898a06e8461 --- /dev/null +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Config utilities.""" + +from __future__ import annotations + +import copy +import json +from typing import Any, Self + +__all__ = ["ConfigOptions"] + + +class ConfigOptions: + """ + GPUEngine configuration-option manager. + + This is a conveniecne class to help manage the nested + dictionary of user-accessible `GPUEngine` options. + """ + + config_options: dict[str, Any] + """The underlying (nested) config-option dictionary.""" + + def __init__(self, options: dict[str, Any]): + self.validate(options) + self.config_options = options + + def copy(self) -> Self: + """Return a deep ConfigOptions copy.""" + return type(self)(copy.deepcopy(self.config_options.copy())) + + def set(self, name: str, value: Any) -> None: + """ + Set a user config option. + + Nested dictionary keys should be separated by periods. + For example:: + + >>> config_options.set("parquet_options.chunked", False) + + Parameters + ---------- + name + Period-separated config name. + value + New confiv value. + """ + options: dict[str, Any] = self.config_options + keys: list[str] = name.split(".") + for k in keys[:-1]: + assert isinstance(options, dict) + if k not in options: + options[k] = {} + options = options[k] + options[keys[-1]] = value + + def get(self, name: str, *, default: Any = None) -> Any: + """ + Get a user config option. + + Nested dictionary keys should be separated by periods. + For example:: + + >>> chunked = config_options.get("parquet_options.chunked") + + Parameters + ---------- + name + Period-separated config name. + default + Default return value. + + Returns + ------- + The user-specified config value, or `default` + if the config is not found. + """ + options: dict[str, Any] = self.config_options + keys: list[str] = name.split(".") + for k in keys[:-1]: + assert isinstance(options, dict) + options = options.get(k, {}) + return options.get(keys[-1], default) + + def __hash__(self) -> int: + """Hash a ConfigOptions object.""" + return hash(json.dumps(self.config_options)) + + @staticmethod + def validate(config: dict) -> None: + """ + Validate a configuration-option dictionary. + + Parameters + ---------- + config + GPUEngine configuration options to validate. + + Raises + ------ + ValueError + If the configuration contains unsupported options. 
+ """ + if unsupported := ( + config.keys() + - {"raise_on_fail", "parquet_options", "executor", "executor_options"} + ): + raise ValueError( + f"Engine configuration contains unsupported settings: {unsupported}" + ) + assert {"chunked", "chunk_read_limit", "pass_read_limit"}.issuperset( + config.get("parquet_options", {}) + ) + + # Validate executor_options + executor = config.get("executor", "pylibcudf") + if executor == "dask-experimental": + unsupported = config.get("executor_options", {}).keys() - { + "max_rows_per_partition", + "parquet_blocksize", + } + else: + unsupported = config.get("executor_options", {}).keys() + if unsupported: + raise ValueError( + f"Unsupported executor_options for {executor}: {unsupported}" + ) From 3f00203e48847dbed9e8b7e0cb685fff1f2036a2 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 28 Feb 2025 12:07:10 -0800 Subject: [PATCH 17/30] Fix typo in `__cuda_array_interface__` name --- python/pylibcudf/pylibcudf/gpumemoryview.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibcudf/pylibcudf/gpumemoryview.pyx b/python/pylibcudf/pylibcudf/gpumemoryview.pyx index dad91361dd2..954d35a6ce3 100644 --- a/python/pylibcudf/pylibcudf/gpumemoryview.pyx +++ b/python/pylibcudf/pylibcudf/gpumemoryview.pyx @@ -31,7 +31,7 @@ cdef class gpumemoryview: return self.obj.__cuda_array_interface__ def __len__(self): - return self.obj.__cuda_array_interface["shape"][0] + return self.obj.__cuda_array_interface__["shape"][0] @property def nbytes(self): From 0a13145e9ed7535ffc9f491d4270107a422628f1 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 3 Mar 2025 07:35:51 -0800 Subject: [PATCH 18/30] check for periods --- python/cudf_polars/cudf_polars/utils/config.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index 898a06e8461..cd022b5adf8 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -16,7 +16,7 @@ class ConfigOptions: """ GPUEngine configuration-option manager. - This is a conveniecne class to help manage the nested + This is a convenience class to help manage the nested dictionary of user-accessible `GPUEngine` options. """ @@ -45,7 +45,7 @@ def set(self, name: str, value: Any) -> None: name Period-separated config name. value - New confiv value. + New config value. """ options: dict[str, Any] = self.config_options keys: list[str] = name.split(".") @@ -127,3 +127,16 @@ def validate(config: dict) -> None: raise ValueError( f"Unsupported executor_options for {executor}: {unsupported}" ) + + # Check that keys are free of periods + def _find_periods(options: dict) -> None: + assert isinstance(options, dict) + for key, val in options.items(): + if isinstance(val, dict): + _find_periods(val) + if "." 
in key: + raise ValueError( + f"Configuration key cannot contain a period: {key}" + ) + + _find_periods(config) From a977df6192d91ef38e9697ce0898e78c5477ef54 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 3 Mar 2025 07:41:50 -0800 Subject: [PATCH 19/30] roll back unnecessary change --- python/cudf_polars/cudf_polars/utils/config.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index cd022b5adf8..484f443de9b 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -127,16 +127,3 @@ def validate(config: dict) -> None: raise ValueError( f"Unsupported executor_options for {executor}: {unsupported}" ) - - # Check that keys are free of periods - def _find_periods(options: dict) -> None: - assert isinstance(options, dict) - for key, val in options.items(): - if isinstance(val, dict): - _find_periods(val) - if "." in key: - raise ValueError( - f"Configuration key cannot contain a period: {key}" - ) - - _find_periods(config) From e445e370adb6c279abeae7a48fb5eb6383212953 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 4 Mar 2025 11:46:02 -0800 Subject: [PATCH 20/30] remove copy API and make ConfigOptions immutable --- .../cudf_polars/experimental/io.py | 5 ++-- .../cudf_polars/cudf_polars/utils/config.py | 25 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py index d61cad50685..a63817fbde0 100644 --- a/python/cudf_polars/cudf_polars/experimental/io.py +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -274,8 +274,9 @@ def _( paths = list(ir.paths) if plan.flavor == ScanPartitionFlavor.SPLIT_FILES: # Disable chunked reader when splitting files - config_options = ir.config_options.copy() - config_options.set(name="parquet_options.chunked", value=False) + config_options = ir.config_options.set( + name="parquet_options.chunked", value=False + ) slices: list[SplitScan] = [] for path in paths: diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index 484f443de9b..54e6dff920b 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -20,6 +20,8 @@ class ConfigOptions: dictionary of user-accessible `GPUEngine` options. """ + __slots__ = ("_hash_value", "config_options") + _hash_value: int config_options: dict[str, Any] """The underlying (nested) config-option dictionary.""" @@ -27,18 +29,14 @@ def __init__(self, options: dict[str, Any]): self.validate(options) self.config_options = options - def copy(self) -> Self: - """Return a deep ConfigOptions copy.""" - return type(self)(copy.deepcopy(self.config_options.copy())) - - def set(self, name: str, value: Any) -> None: + def set(self, name: str, value: Any) -> Self: """ Set a user config option. Nested dictionary keys should be separated by periods. For example:: - >>> config_options.set("parquet_options.chunked", False) + >>> options = options.set("parquet_options.chunked", False) Parameters ---------- @@ -47,14 +45,15 @@ def set(self, name: str, value: Any) -> None: value New config value. 
""" - options: dict[str, Any] = self.config_options - keys: list[str] = name.split(".") + options = config_options = copy.deepcopy(self.config_options) + keys = name.split(".") for k in keys[:-1]: assert isinstance(options, dict) if k not in options: options[k] = {} options = options[k] options[keys[-1]] = value + return type(self)(config_options) def get(self, name: str, *, default: Any = None) -> Any: """ @@ -77,8 +76,8 @@ def get(self, name: str, *, default: Any = None) -> Any: The user-specified config value, or `default` if the config is not found. """ - options: dict[str, Any] = self.config_options - keys: list[str] = name.split(".") + options = self.config_options + keys = name.split(".") for k in keys[:-1]: assert isinstance(options, dict) options = options.get(k, {}) @@ -86,7 +85,11 @@ def get(self, name: str, *, default: Any = None) -> Any: def __hash__(self) -> int: """Hash a ConfigOptions object.""" - return hash(json.dumps(self.config_options)) + try: + return self._hash_value + except AttributeError: + self._hash_value = hash(json.dumps(self.config_options)) + return self._hash_value @staticmethod def validate(config: dict) -> None: From e9abe335eba0c7f35bd68873f73a6d6d2124a4ff Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 4 Mar 2025 13:10:23 -0800 Subject: [PATCH 21/30] use typing_extensions for older python versions --- python/cudf_polars/cudf_polars/utils/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index 54e6dff920b..41fb8836959 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -4,10 +4,11 @@ """Config utilities.""" from __future__ import annotations +from typing_extensions import Self import copy import json -from typing import Any, Self +from typing import Any __all__ = ["ConfigOptions"] From eb7a79a048a76fe0cc95d2f6331042a7fcb2a6e9 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 4 Mar 2025 13:11:56 -0800 Subject: [PATCH 22/30] formatting --- python/cudf_polars/cudf_polars/utils/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index 41fb8836959..d7ea0961689 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -4,11 +4,13 @@ """Config utilities.""" from __future__ import annotations -from typing_extensions import Self import copy import json -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from typing_extensions import Self __all__ = ["ConfigOptions"] From 385c68a8bca680a42b3211761a4044826219ae2e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 4 Mar 2025 14:48:01 -0800 Subject: [PATCH 23/30] break out the decomposition of a single groupby request into a stand-alone function --- .../cudf_polars/experimental/groupby.py | 212 ++++++++++-------- 1 file changed, 119 insertions(+), 93 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 16e38edbb8d..83334bda713 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -4,6 +4,8 @@ from __future__ import annotations +import itertools +from functools import partial from typing import TYPE_CHECKING, Any import pylibcudf as plc @@ -21,12 
+23,110 @@ from cudf_polars.dsl.expr import Expr from cudf_polars.dsl.ir import IR from cudf_polars.experimental.parallel import LowerIRTransformer + from cudf_polars.typing import Schema # Supported multi-partition aggregations _GB_AGG_SUPPORTED = ("sum", "count", "mean") +def decompose_groupby_reduction( + schema: Schema, + request: NamedExpr, +) -> tuple[list[NamedExpr], list[NamedExpr], list[NamedExpr]]: + """Decompose a groupby request.""" + complex_expr_map: MutableMapping[str, Any] = {} + piecewise_exprs: list[NamedExpr] = [] + reduction_exprs: list[NamedExpr] = [] + selection_exprs: list[NamedExpr] = [] + unary_op: dict[str, Any] = {} + + name = request.name + agg: Expr = request.value + dtype = agg.dtype + agg = agg.children[0] if isinstance(agg, Cast) else agg + + if isinstance(agg, Len): + piecewise_exprs.append(request) + reduction_exprs.append( + NamedExpr( + name, + Cast( + dtype, + Agg(dtype, "sum", None, Col(dtype, name)), + ), + ) + ) + elif isinstance(agg, (Agg, UnaryFunction)): + if ( + isinstance(agg, UnaryFunction) + and agg.is_pointwise + and isinstance(agg.children[0], Agg) + ): + # TODO: Handle sequential unary ops + unary_op = {"name": agg.name, "options": agg.options} + agg = agg.children[0] + + if agg.name not in _GB_AGG_SUPPORTED: + raise NotImplementedError( + "GroupBy does not support multiple partitions " + f"for this expression:\n{agg}" + ) + + if agg.name in ("sum", "count"): + piecewise_exprs.append(request) + reduction_exprs.append( + NamedExpr( + name, + Cast( + dtype, + Agg(dtype, "sum", agg.options, Col(dtype, name)), + ), + ) + ) + elif agg.name == "mean": + complex_expr_map[name] = {"mean": {}} + for sub in ["sum", "count"]: + # Partwise + tmp_name = f"{name}__{sub}" + complex_expr_map[name]["mean"][sub] = tmp_name + agg_pwise = Agg(dtype, sub, agg.options, *agg.children) + piecewise_exprs.append(NamedExpr(tmp_name, agg_pwise)) + # Tree + agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name)) + reduction_exprs.append(NamedExpr(tmp_name, agg_tree)) + else: + # Unsupported expression + raise NotImplementedError( + f"GroupBy does not support multiple partitions for this expression:\n{agg}" + ) # pragma: no cover + + # Construct final selection expressions + col_expr: Col | BinOp | UnaryFunction + dtype = schema[name] + complex_expr = complex_expr_map.get(name, None) + if complex_expr is None: + col_expr = Col(dtype, name) + elif "mean" in complex_expr: + mean_cols = complex_expr["mean"] + col_expr = BinOp( + dtype, + plc.binaryop.BinaryOperator.DIV, + Col(dtype, mean_cols["sum"]), + Col(dtype, mean_cols["count"]), + ) + if unary_op: + col_expr = UnaryFunction( + dtype, + unary_op["name"], + unary_op["options"], + col_expr, + ) + selection_exprs.append(NamedExpr(name, col_expr)) + + return piecewise_exprs, reduction_exprs, selection_exprs + + @lower_ir_node.register(GroupBy) def _( ir: GroupBy, rec: LowerIRTransformer @@ -68,78 +168,22 @@ def _( 1, ) - name_map: MutableMapping[str, Any] = {} - agg_tree: Cast | Agg | None = None - agg_requests_pwise = [] # Partition-wise requests - agg_requests_tree = [] # Tree-node requests - unary_ops: dict[str, dict[str, Any]] = {} - - for ne in ir.agg_requests: - name = ne.name - agg: Expr = ne.value - dtype = agg.dtype - agg = agg.children[0] if isinstance(agg, Cast) else agg - if isinstance(agg, Len): - agg_requests_pwise.append(ne) - agg_requests_tree.append( - NamedExpr( - name, - Cast( - dtype, - Agg(dtype, "sum", None, Col(dtype, name)), - ), - ) - ) - elif isinstance(agg, (Agg, UnaryFunction)): - if ( 
- isinstance(agg, UnaryFunction) - and agg.is_pointwise - and isinstance(agg.children[0], Agg) - ): - # TODO: Handle sequential unary ops - unary_ops[name] = {"name": agg.name, "options": agg.options} - agg = agg.children[0] - - if agg.name not in _GB_AGG_SUPPORTED: - raise NotImplementedError( - "GroupBy does not support multiple partitions " - f"for this expression:\n{agg}" - ) - - if agg.name in ("sum", "count"): - agg_requests_pwise.append(ne) - agg_requests_tree.append( - NamedExpr( - name, - Cast( - dtype, - Agg(dtype, "sum", agg.options, Col(dtype, name)), - ), - ) - ) - elif agg.name == "mean": - name_map[name] = {agg.name: {}} - for sub in ["sum", "count"]: - # Partwise - tmp_name = f"{name}__{sub}" - name_map[name][agg.name][sub] = tmp_name - agg_pwise = Agg(dtype, sub, agg.options, *agg.children) - agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise)) - # Tree - agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name)) - agg_requests_tree.append(NamedExpr(tmp_name, agg_tree)) - else: - # Unsupported expression - raise NotImplementedError( - "GroupBy does not support multiple partitions " - f"for this expression:\n{agg}" - ) # pragma: no cover + piecewise_exprs, reduction_exprs, selection_exprs = ( + list(itertools.chain.from_iterable(x)) + for x in zip( + *map( + partial(decompose_groupby_reduction, ir.schema), + ir.agg_requests, + ), + strict=False, + ) + ) # Partition-wise groupby operation gb_pwise = GroupBy( ir.schema, ir.keys, - agg_requests_pwise, + piecewise_exprs, ir.maintain_order, ir.options, ir.config_options, @@ -165,7 +209,7 @@ def _( gb_reduce = GroupBy( ir.schema, ir.keys, - agg_requests_tree, + reduction_exprs, ir.maintain_order, ir.options, ir.config_options, @@ -173,33 +217,15 @@ def _( ) partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count) - schema = ir.schema - output_exprs = [] - col_expr: Col | BinOp | UnaryFunction - for name, dtype in schema.items(): - agg_mapping = name_map.get(name, None) - if agg_mapping is None: - col_expr = Col(dtype, name) - elif "mean" in agg_mapping: - mean_cols = agg_mapping["mean"] - col_expr = BinOp( - dtype, - plc.binaryop.BinaryOperator.DIV, - Col(dtype, mean_cols["sum"]), - Col(dtype, mean_cols["count"]), - ) - if name in unary_ops: - col_expr = UnaryFunction( - dtype, - unary_ops[name]["name"], - unary_ops[name]["options"], - col_expr, - ) - output_exprs.append(NamedExpr(name, col_expr)) should_broadcast: bool = False + aggregated = {ne.name: ne for ne in selection_exprs} new_node = Select( - schema, - output_exprs, + ir.schema, + [ + # Select the aggregated data or the original column + aggregated.get(name, NamedExpr(name, Col(dtype, name))) + for name, dtype in ir.schema.items() + ], should_broadcast, gb_reduce, ) From fe3ca7cfbc3fb37d2b8b6515e0d4012f1d7c0bd2 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 4 Mar 2025 14:49:51 -0800 Subject: [PATCH 24/30] break out the decomposition of a single groupby request into a stand-alone function --- .../cudf_polars/experimental/groupby.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 83334bda713..7f9ed01c55a 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -34,7 +34,23 @@ def decompose_groupby_reduction( schema: Schema, request: NamedExpr, ) -> tuple[list[NamedExpr], list[NamedExpr], list[NamedExpr]]: - 
"""Decompose a groupby request.""" + """ + Decompose a groupby-aggregation request. + + Parameters + ---------- + schema + Output schema. + request + The `NamedExpr` representing the aggregation logic for a + single column. + + Returns + ------- + Tuple containing a list of `NamedExpr` for each of the + three parallel-aggregation phases: + (1) Piecewise, (2) reduction, and (3) selection + """ complex_expr_map: MutableMapping[str, Any] = {} piecewise_exprs: list[NamedExpr] = [] reduction_exprs: list[NamedExpr] = [] @@ -168,6 +184,7 @@ def _( 1, ) + # Decompose the aggregation requests into three distinct phases piecewise_exprs, reduction_exprs, selection_exprs = ( list(itertools.chain.from_iterable(x)) for x in zip( @@ -217,6 +234,7 @@ def _( ) partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count) + # Final Select phase should_broadcast: bool = False aggregated = {ne.name: ne for ne in selection_exprs} new_node = Select( From f959988db02a24503b4ba5aa620d5b457b11cc97 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 11 Mar 2025 07:52:59 -0700 Subject: [PATCH 25/30] address schema and maintain_order issues --- .../cudf_polars/experimental/groupby.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 7f9ed01c55a..bc1eba35498 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -197,8 +197,11 @@ def _( ) # Partition-wise groupby operation + pwise_schema = {k.name: k.value.dtype for k in ir.keys} | { + k.name: k.value.dtype for k in piecewise_exprs + } gb_pwise = GroupBy( - ir.schema, + pwise_schema, ir.keys, piecewise_exprs, ir.maintain_order, @@ -212,9 +215,14 @@ def _( # Add Shuffle node if necessary gb_inter: GroupBy | Shuffle = gb_pwise if post_aggregation_count > 1: + if ir.maintain_order: # pragma: no cover + raise NotImplementedError( + "maintain_order not supported for multiple output partitions." 
+            )
+
         shuffle_options: dict[str, Any] = {}
         gb_inter = Shuffle(
-            ir.schema,
+            pwise_schema,
             ir.keys,
             shuffle_options,
             gb_pwise,
@@ -224,7 +232,8 @@ def _(
     # Tree reduction if post_aggregation_count==1
     # (Otherwise, this is another partition-wise op)
     gb_reduce = GroupBy(
-        ir.schema,
+        {k.name: k.value.dtype for k in ir.keys}
+        | {k.name: k.value.dtype for k in reduction_exprs},
         ir.keys,
         reduction_exprs,
         ir.maintain_order,
@@ -235,7 +244,6 @@ def _(
     partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count)

     # Final Select phase
-    should_broadcast: bool = False
     aggregated = {ne.name: ne for ne in selection_exprs}
     new_node = Select(
         ir.schema,
@@ -244,7 +252,7 @@ def _(
             aggregated.get(name, NamedExpr(name, Col(dtype, name)))
             for name, dtype in ir.schema.items()
         ],
-        should_broadcast,
+        False,  # noqa: FBT003
         gb_reduce,
     )
     partition_info[new_node] = PartitionInfo(count=post_aggregation_count)

From bca71d6e3b98d5e64e0f9b807b154bfc1070d360 Mon Sep 17 00:00:00 2001
From: rjzamora
Date: Tue, 11 Mar 2025 09:56:57 -0700
Subject: [PATCH 26/30] use Lawrence's suggestions

---
 .../cudf_polars/experimental/groupby.py       | 192 ++++++++----------
 .../tests/experimental/test_groupby.py        |   9 +-
 2 files changed, 97 insertions(+), 104 deletions(-)

diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py
index bc1eba35498..e89a87fbc75 100644
--- a/python/cudf_polars/cudf_polars/experimental/groupby.py
+++ b/python/cudf_polars/cudf_polars/experimental/groupby.py
@@ -5,7 +5,6 @@
 from __future__ import annotations

 import itertools
-from functools import partial
 from typing import TYPE_CHECKING, Any

 import pylibcudf as plc
@@ -23,7 +22,7 @@
     from cudf_polars.dsl.expr import Expr
     from cudf_polars.dsl.ir import IR
     from cudf_polars.experimental.parallel import LowerIRTransformer
-    from cudf_polars.typing import Schema


 # Supported multi-partition aggregations
 _GB_AGG_SUPPORTED = ("sum", "count", "mean")


-def decompose_groupby_reduction(
-    schema: Schema,
-    request: NamedExpr,
+def combine(
+    *decompositions: tuple[NamedExpr, list[NamedExpr], list[NamedExpr]],
 ) -> tuple[list[NamedExpr], list[NamedExpr], list[NamedExpr]]:
     """
-    Decompose a groupby-aggregation request.
+    Combine multiple groupby-aggregation decompositions.

     Parameters
     ----------
-    schema
-        Output schema.
-    request
-        The `NamedExpr` representing the aggregation logic for a
-        single column.
+    decompositions
+        Packed sequence of `decompose` results.

     Returns
     -------
-    Tuple containing a list of `NamedExpr` for each of the
-    three parallel-aggregation phases:
-    (1) Piecewise, (2) reduction, and (3) selection
+    Unified groupby-aggregation decomposition.
""" - complex_expr_map: MutableMapping[str, Any] = {} - piecewise_exprs: list[NamedExpr] = [] - reduction_exprs: list[NamedExpr] = [] - selection_exprs: list[NamedExpr] = [] - unary_op: dict[str, Any] = {} - - name = request.name - agg: Expr = request.value - dtype = agg.dtype - agg = agg.children[0] if isinstance(agg, Cast) else agg - - if isinstance(agg, Len): - piecewise_exprs.append(request) - reduction_exprs.append( + selections, aggregations, reductions = zip(*decompositions, strict=False) + assert all(isinstance(ne, NamedExpr) for ne in selections) + return ( + list(selections), + list(itertools.chain.from_iterable(aggregations)), + list(itertools.chain.from_iterable(reductions)), + ) + + +def decompose( + name: str, expr: Expr +) -> tuple[NamedExpr, list[NamedExpr], list[NamedExpr]]: + """ + Decompose a groupby-aggregation expression. + + Parameters + ---------- + name + Output schema name. + expr + The aggregation expression for a single column. + + Returns + ------- + Tuple containing the `NamedExpr`s for each of the + three parallel-aggregation phases: (1) selection, + (2) initial aggregation, and (3) reduction. + """ + dtype = expr.dtype + expr = expr.children[0] if isinstance(expr, Cast) else expr + + unary_op: list[Any] = [] + if isinstance(expr, UnaryFunction) and expr.is_pointwise: + # TODO: Handle multiple/sequential unary ops + unary_op = [expr.name, expr.options] + expr = expr.children[0] + + def _wrap_unary(select): + # Helper function to wrap the final selection + # in a UnaryFunction (when necessary) + if unary_op: + return UnaryFunction(select.dtype, *unary_op, select) + return select + + if isinstance(expr, Len): + selection = NamedExpr(name, _wrap_unary(Col(dtype, name))) + aggregation = [NamedExpr(name, expr)] + reduction = [ NamedExpr( name, - Cast( - dtype, - Agg(dtype, "sum", None, Col(dtype, name)), - ), - ) - ) - elif isinstance(agg, (Agg, UnaryFunction)): - if ( - isinstance(agg, UnaryFunction) - and agg.is_pointwise - and isinstance(agg.children[0], Agg) - ): - # TODO: Handle sequential unary ops - unary_op = {"name": agg.name, "options": agg.options} - agg = agg.children[0] - - if agg.name not in _GB_AGG_SUPPORTED: - raise NotImplementedError( - "GroupBy does not support multiple partitions " - f"for this expression:\n{agg}" + Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))), ) - - if agg.name in ("sum", "count"): - piecewise_exprs.append(request) - reduction_exprs.append( + ] + return selection, aggregation, reduction + if isinstance(expr, Agg): + if expr.name in ("sum", "count"): + selection = NamedExpr(name, _wrap_unary(Col(dtype, name))) + aggregation = [NamedExpr(name, expr)] + reduction = [ NamedExpr( name, - Cast( - dtype, - Agg(dtype, "sum", agg.options, Col(dtype, name)), - ), + Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))), ) + ] + return selection, aggregation, reduction + elif expr.name == "mean": + (child,) = expr.children + (sum, count), aggregations, reductions = combine( + decompose(f"{name}__mean_sum", Agg(dtype, "sum", None, child)), + decompose(f"{name}__mean_count", Len(dtype)), + ) + selection = NamedExpr( + name, + _wrap_unary( + BinOp( + dtype, plc.binaryop.BinaryOperator.DIV, sum.value, count.value + ) + ), + ) + return selection, aggregations, reductions + else: + raise NotImplementedError( + "GroupBy does not support multiple partitions " + f"for this aggregation type:\n{type(expr)}\n" + f"Only {_GB_AGG_SUPPORTED} are supported." 
) - elif agg.name == "mean": - complex_expr_map[name] = {"mean": {}} - for sub in ["sum", "count"]: - # Partwise - tmp_name = f"{name}__{sub}" - complex_expr_map[name]["mean"][sub] = tmp_name - agg_pwise = Agg(dtype, sub, agg.options, *agg.children) - piecewise_exprs.append(NamedExpr(tmp_name, agg_pwise)) - # Tree - agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name)) - reduction_exprs.append(NamedExpr(tmp_name, agg_tree)) - else: + else: # pragma: no cover # Unsupported expression raise NotImplementedError( - f"GroupBy does not support multiple partitions for this expression:\n{agg}" - ) # pragma: no cover - - # Construct final selection expressions - col_expr: Col | BinOp | UnaryFunction - dtype = schema[name] - complex_expr = complex_expr_map.get(name, None) - if complex_expr is None: - col_expr = Col(dtype, name) - elif "mean" in complex_expr: - mean_cols = complex_expr["mean"] - col_expr = BinOp( - dtype, - plc.binaryop.BinaryOperator.DIV, - Col(dtype, mean_cols["sum"]), - Col(dtype, mean_cols["count"]), + f"GroupBy does not support multiple partitions for this expression:\n{expr}" ) - if unary_op: - col_expr = UnaryFunction( - dtype, - unary_op["name"], - unary_op["options"], - col_expr, - ) - selection_exprs.append(NamedExpr(name, col_expr)) - - return piecewise_exprs, reduction_exprs, selection_exprs @lower_ir_node.register(GroupBy) @@ -185,15 +178,8 @@ def _( ) # Decompose the aggregation requests into three distinct phases - piecewise_exprs, reduction_exprs, selection_exprs = ( - list(itertools.chain.from_iterable(x)) - for x in zip( - *map( - partial(decompose_groupby_reduction, ir.schema), - ir.agg_requests, - ), - strict=False, - ) + selection_exprs, piecewise_exprs, reduction_exprs = combine( + *(decompose(agg.name, agg.value) for agg in ir.agg_requests) ) # Partition-wise groupby operation diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py index 11c5c432b11..6fec8db167f 100644 --- a/python/cudf_polars/tests/experimental/test_groupby.py +++ b/python/cudf_polars/tests/experimental/test_groupby.py @@ -34,7 +34,14 @@ def df(): @pytest.mark.parametrize("keys", [("y",), ("y", "z")]) def test_groupby(df, engine, op, keys): q = getattr(df.group_by(*keys), op)() - assert_gpu_result_equal(q, engine=engine, check_row_order=False) + + from cudf_polars import Translator + from cudf_polars.experimental.parallel import evaluate_dask + + ir = Translator(q._ldf.visit(), engine).translate_ir() + evaluate_dask(ir) + + # assert_gpu_result_equal(q, engine=engine, check_row_order=False) @pytest.mark.parametrize("op", ["sum", "mean", "len"]) From 7f90ded1d84f9af05ab9901d622d1a1bd96e43cd Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 11 Mar 2025 11:28:17 -0700 Subject: [PATCH 27/30] address small code-review comments --- python/cudf_polars/cudf_polars/experimental/groupby.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index e89a87fbc75..bc17fb2c86c 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -43,7 +43,7 @@ def combine( ------- Unified groupby-aggregation decomposition. 
""" - selections, aggregations, reductions = zip(*decompositions, strict=False) + selections, aggregations, reductions = zip(*decompositions, strict=True) assert all(isinstance(ne, NamedExpr) for ne in selections) return ( list(selections), @@ -93,6 +93,8 @@ def _wrap_unary(select): reduction = [ NamedExpr( name, + # Sum reduction may require casting. + # Do it for all cases to be safe (for now) Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))), ) ] @@ -104,6 +106,8 @@ def _wrap_unary(select): reduction = [ NamedExpr( name, + # Sum reduction may require casting. + # Do it for all cases to be safe (for now) Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))), ) ] @@ -111,6 +115,8 @@ def _wrap_unary(select): elif expr.name == "mean": (child,) = expr.children (sum, count), aggregations, reductions = combine( + # TODO: Avoid possibility of a name collision + # (even though the likelihood is small) decompose(f"{name}__mean_sum", Agg(dtype, "sum", None, child)), decompose(f"{name}__mean_count", Len(dtype)), ) From e6f32d7a900543fc3b79cedafe0b341f95232d2e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 13 Mar 2025 07:47:40 -0700 Subject: [PATCH 28/30] add back copy on init --- python/cudf_polars/cudf_polars/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index e445d7e1caf..bbbaef5b696 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -30,7 +30,7 @@ class ConfigOptions: def __init__(self, options: dict[str, Any]): self.validate(options) - self.config_options = options + self.config_options = copy.deepcopy(options) def set(self, name: str, value: Any) -> Self: """ From b3e77fed525b534196ff0482ae117114436af146 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 13 Mar 2025 07:49:43 -0700 Subject: [PATCH 29/30] fix --- python/cudf_polars/cudf_polars/dsl/translate.py | 3 ++- python/cudf_polars/cudf_polars/utils/config.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 86d1ac159ad..3c4c3bb3e0c 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy import functools import json from contextlib import AbstractContextManager, nullcontext @@ -47,7 +48,7 @@ class Translator: def __init__(self, visitor: NodeTraverser, engine: GPUEngine): self.visitor = visitor - self.config_options = config.ConfigOptions(engine.config.copy()) + self.config_options = config.ConfigOptions(copy.deepcopy(engine.config)) self.errors: list[Exception] = [] def translate_ir(self, *, n: int | None = None) -> ir.IR: diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index bbbaef5b696..e445d7e1caf 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -30,7 +30,7 @@ class ConfigOptions: def __init__(self, options: dict[str, Any]): self.validate(options) - self.config_options = copy.deepcopy(options) + self.config_options = options def set(self, name: str, value: Any) -> Self: """ From fb6f1da4d1d085a018ec75bf21ce54d2deb5e0a7 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 13 Mar 2025 07:57:20 -0700 Subject: [PATCH 30/30] add back missing test --- 
python/cudf_polars/tests/experimental/test_groupby.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python/cudf_polars/tests/experimental/test_groupby.py b/python/cudf_polars/tests/experimental/test_groupby.py index ddd2f762d7d..c28584aaefc 100644 --- a/python/cudf_polars/tests/experimental/test_groupby.py +++ b/python/cudf_polars/tests/experimental/test_groupby.py @@ -34,14 +34,7 @@ def df(): @pytest.mark.parametrize("keys", [("y",), ("y", "z")]) def test_groupby(df, engine, op, keys): q = getattr(df.group_by(*keys), op)() - - from cudf_polars import Translator - from cudf_polars.experimental.parallel import evaluate_dask - - ir = Translator(q._ldf.visit(), engine).translate_ir() - evaluate_dask(ir) - - # assert_gpu_result_equal(q, engine=engine, check_row_order=False) + assert_gpu_result_equal(q, engine=engine, check_row_order=False) @pytest.mark.parametrize("op", ["sum", "mean", "len"])
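
A minimal end-to-end sketch of the behavior this series adds follows. It is
not part of any patch above: it assumes a CUDA-capable environment with
cudf-polars and Dask installed, and the only knobs it touches are the
"dask-experimental" executor and the "max_rows_per_partition" executor option
validated in cudf_polars/utils/config.py earlier in the series. The
aggregations chosen (sum, mean, len) are exactly those supported by
_GB_AGG_SUPPORTED.

    # Smoke-test sketch (assumptions: GPU + Dask available; executor and
    # option names as validated in cudf_polars/utils/config.py).
    import polars as pl

    engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options={"max_rows_per_partition": 4},
    )

    df = pl.LazyFrame({"x": list(range(30)), "y": [1, 2, 3] * 10})
    q = df.group_by("y").agg(
        pl.col("x").sum(),   # piecewise sum, then a tree "sum" reduction
        pl.col("x").mean(),  # decomposed into sum/count, then a dividing Select
        pl.len(),            # handled like sum: partial lengths are summed
    )
    print(q.collect(engine=engine))

With max_rows_per_partition=4, the 30-row input is split across several
partitions, so the query exercises the decompose/combine lowering from
PATCH 26 rather than the single-partition fallback.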