-
Notifications
You must be signed in to change notification settings - Fork 921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prevent pylibcudf serialization in cudf-polars
#17449
base: branch-25.02
Are you sure you want to change the base?
Changes from 4 commits
eb4a2ff
165e68c
99a5d12
3e14ec9
2ba7ed1
4eba56c
ef36d42
5635388
bba8b3f
2d37c08
54a9cd6
faf42d0
7183bc8
bfaf41e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
|
||
|
||
class Agg(Expr): | ||
__slots__ = ("name", "options", "op", "request") | ||
__slots__ = ("name", "options", "request") | ||
_non_child = ("dtype", "name", "options") | ||
|
||
def __init__( | ||
|
@@ -45,54 +45,11 @@ def __init__( | |
raise NotImplementedError( | ||
f"Unsupported aggregation {name=}" | ||
) # pragma: no cover; all valid aggs are supported | ||
# TODO: nan handling in groupby case | ||
if name == "min": | ||
req = plc.aggregation.min() | ||
elif name == "max": | ||
req = plc.aggregation.max() | ||
elif name == "median": | ||
req = plc.aggregation.median() | ||
elif name == "n_unique": | ||
# TODO: datatype of result | ||
req = plc.aggregation.nunique(null_handling=plc.types.NullPolicy.INCLUDE) | ||
elif name == "first" or name == "last": | ||
req = None | ||
elif name == "mean": | ||
req = plc.aggregation.mean() | ||
elif name == "sum": | ||
req = plc.aggregation.sum() | ||
elif name == "std": | ||
# TODO: handle nans | ||
req = plc.aggregation.std(ddof=options) | ||
elif name == "var": | ||
# TODO: handle nans | ||
req = plc.aggregation.variance(ddof=options) | ||
elif name == "count": | ||
req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE) | ||
elif name == "quantile": | ||
_, quantile = self.children | ||
if not isinstance(quantile, Literal): | ||
raise NotImplementedError("Only support literal quantile values") | ||
req = plc.aggregation.quantile( | ||
quantiles=[quantile.value.as_py()], interp=Agg.interp_mapping[options] | ||
) | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, {name=} is incorrectly listed in _SUPPORTED" | ||
) # pragma: no cover | ||
self.request = req | ||
op = getattr(self, f"_{name}", None) | ||
if op is None: | ||
op = partial(self._reduce, request=req) | ||
elif name in {"min", "max"}: | ||
op = partial(op, propagate_nans=options) | ||
elif name in {"count", "first", "last"}: | ||
pass | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, supported agg {name=} has no implementation" | ||
) # pragma: no cover | ||
self.op = op | ||
self.request = None | ||
|
||
_SUPPORTED: ClassVar[frozenset[str]] = frozenset( | ||
[ | ||
|
@@ -119,6 +76,46 @@ def __init__( | |
"linear": plc.types.Interpolation.LINEAR, | ||
} | ||
|
||
def _fill_request(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Maybe we can just define:
```python
@property
def request(self):
    ...
```
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Good idea. Done in faf42d0. |
||
if self.request is None: | ||
# TODO: nan handling in groupby case | ||
if self.name == "min": | ||
req = plc.aggregation.min() | ||
elif self.name == "max": | ||
req = plc.aggregation.max() | ||
elif self.name == "median": | ||
req = plc.aggregation.median() | ||
elif self.name == "n_unique": | ||
# TODO: datatype of result | ||
req = plc.aggregation.nunique( | ||
null_handling=plc.types.NullPolicy.INCLUDE | ||
) | ||
elif self.name == "first" or self.name == "last": | ||
req = None | ||
elif self.name == "mean": | ||
req = plc.aggregation.mean() | ||
elif self.name == "sum": | ||
req = plc.aggregation.sum() | ||
elif self.name == "std": | ||
# TODO: handle nans | ||
req = plc.aggregation.std(ddof=self.options) | ||
elif self.name == "var": | ||
# TODO: handle nans | ||
req = plc.aggregation.variance(ddof=self.options) | ||
elif self.name == "count": | ||
req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE) | ||
elif self.name == "quantile": | ||
_, quantile = self.children | ||
req = plc.aggregation.quantile( | ||
quantiles=[quantile.value.as_py()], | ||
interp=Agg.interp_mapping[self.options], | ||
) | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, {self.name=} is incorrectly listed in _SUPPORTED" | ||
) # pragma: no cover | ||
self.request = req | ||
|
||
def collect_agg(self, *, depth: int) -> AggInfo: | ||
"""Collect information about aggregations in groupbys.""" | ||
if depth >= 1: | ||
|
@@ -129,6 +126,7 @@ def collect_agg(self, *, depth: int) -> AggInfo: | |
raise NotImplementedError("Nan propagation in groupby for min/max") | ||
(child,) = self.children | ||
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests | ||
self._fill_request() | ||
request = self.request | ||
# These are handled specially here because we don't set up the | ||
# request for the whole-frame agg because we can avoid a | ||
|
@@ -223,7 +221,21 @@ def do_evaluate( | |
f"Agg in context {context}" | ||
) # pragma: no cover; unreachable | ||
|
||
self._fill_request() | ||
|
||
op = getattr(self, f"_{self.name}", None) | ||
if op is None: | ||
op = partial(self._reduce, request=self.request) | ||
elif self.name in {"min", "max"}: | ||
op = partial(op, propagate_nans=self.options) | ||
elif self.name in {"count", "first", "last"}: | ||
pass | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, supported agg {self.name=} has no implementation" | ||
) # pragma: no cover | ||
|
||
# Aggregations like quantiles may have additional children that were | ||
# preprocessed into pylibcudf requests. | ||
child = self.children[0] | ||
return self.op(child.evaluate(df, context=context, mapping=mapping)) | ||
return op(child.evaluate(df, context=context, mapping=mapping)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,18 +6,22 @@ | |
|
||
from __future__ import annotations | ||
|
||
from enum import IntEnum, auto | ||
from typing import TYPE_CHECKING, ClassVar | ||
|
||
from polars.polars import _expr_nodes as pl_expr | ||
|
||
import pylibcudf as plc | ||
from pylibcudf import expressions as plc_expr | ||
|
||
from cudf_polars.containers import Column | ||
from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Mapping | ||
|
||
from typing_extensions import Self | ||
|
||
from cudf_polars.containers import DataFrame | ||
|
||
__all__ = ["BinOp"] | ||
|
@@ -27,10 +31,90 @@ class BinOp(Expr): | |
__slots__ = ("op",) | ||
_non_child = ("dtype", "op") | ||
|
||
class Operator(IntEnum): | ||
"""Internal and picklable representation of pylibcudf's `BinaryOperator`.""" | ||
|
||
ADD = auto() | ||
ATAN2 = auto() | ||
BITWISE_AND = auto() | ||
BITWISE_OR = auto() | ||
BITWISE_XOR = auto() | ||
DIV = auto() | ||
EQUAL = auto() | ||
FLOOR_DIV = auto() | ||
GENERIC_BINARY = auto() | ||
GREATER = auto() | ||
GREATER_EQUAL = auto() | ||
INT_POW = auto() | ||
INVALID_BINARY = auto() | ||
LESS = auto() | ||
LESS_EQUAL = auto() | ||
LOGICAL_AND = auto() | ||
LOGICAL_OR = auto() | ||
LOG_BASE = auto() | ||
MOD = auto() | ||
MUL = auto() | ||
NOT_EQUAL = auto() | ||
NULL_EQUALS = auto() | ||
NULL_LOGICAL_AND = auto() | ||
NULL_LOGICAL_OR = auto() | ||
NULL_MAX = auto() | ||
NULL_MIN = auto() | ||
NULL_NOT_EQUALS = auto() | ||
PMOD = auto() | ||
POW = auto() | ||
PYMOD = auto() | ||
SHIFT_LEFT = auto() | ||
SHIFT_RIGHT = auto() | ||
SHIFT_RIGHT_UNSIGNED = auto() | ||
SUB = auto() | ||
TRUE_DIV = auto() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What?
Works fine. |
||
|
||
@classmethod | ||
def from_polars(cls, obj: pl_expr.Operator) -> BinOp.Operator: | ||
"""Convert from polars' `Operator`.""" | ||
mapping: dict[pl_expr.Operator, BinOp.Operator] = { | ||
pl_expr.Operator.Eq: BinOp.Operator.EQUAL, | ||
pl_expr.Operator.EqValidity: BinOp.Operator.NULL_EQUALS, | ||
pl_expr.Operator.NotEq: BinOp.Operator.NOT_EQUAL, | ||
pl_expr.Operator.NotEqValidity: BinOp.Operator.NULL_NOT_EQUALS, | ||
pl_expr.Operator.Lt: BinOp.Operator.LESS, | ||
pl_expr.Operator.LtEq: BinOp.Operator.LESS_EQUAL, | ||
pl_expr.Operator.Gt: BinOp.Operator.GREATER, | ||
pl_expr.Operator.GtEq: BinOp.Operator.GREATER_EQUAL, | ||
pl_expr.Operator.Plus: BinOp.Operator.ADD, | ||
pl_expr.Operator.Minus: BinOp.Operator.SUB, | ||
pl_expr.Operator.Multiply: BinOp.Operator.MUL, | ||
pl_expr.Operator.Divide: BinOp.Operator.DIV, | ||
pl_expr.Operator.TrueDivide: BinOp.Operator.TRUE_DIV, | ||
pl_expr.Operator.FloorDivide: BinOp.Operator.FLOOR_DIV, | ||
pl_expr.Operator.Modulus: BinOp.Operator.PYMOD, | ||
pl_expr.Operator.And: BinOp.Operator.BITWISE_AND, | ||
pl_expr.Operator.Or: BinOp.Operator.BITWISE_OR, | ||
pl_expr.Operator.Xor: BinOp.Operator.BITWISE_XOR, | ||
pl_expr.Operator.LogicalAnd: BinOp.Operator.LOGICAL_AND, | ||
pl_expr.Operator.LogicalOr: BinOp.Operator.LOGICAL_OR, | ||
} | ||
|
||
return mapping[obj] | ||
|
||
@classmethod | ||
def to_pylibcudf(cls, obj: Self) -> plc.binaryop.BinaryOperator: | ||
"""Convert to pylibcudf's `BinaryOperator`.""" | ||
return getattr(plc.binaryop.BinaryOperator, obj.name) | ||
|
||
@classmethod | ||
def to_pylibcudf_expr(cls, obj: Self) -> plc.binaryop.BinaryOperator: | ||
"""Convert to pylibcudf's `ASTOperator`.""" | ||
if obj is BinOp.Operator.NULL_EQUALS: | ||
# Name mismatch in pylibcudf's `BinaryOperator` and `ASTOperator`. | ||
return plc_expr.ASTOperator.NULL_EQUAL | ||
return getattr(plc_expr.ASTOperator, obj.name) | ||
|
||
def __init__( | ||
self, | ||
dtype: plc.DataType, | ||
op: plc.binaryop.BinaryOperator, | ||
op: BinOp.Operator, | ||
left: Expr, | ||
right: Expr, | ||
) -> None: | ||
|
@@ -43,44 +127,19 @@ def __init__( | |
self.op = op | ||
self.children = (left, right) | ||
if not plc.binaryop.is_supported_operation( | ||
self.dtype, left.dtype, right.dtype, op | ||
self.dtype, left.dtype, right.dtype, BinOp.Operator.to_pylibcudf(op) | ||
): | ||
raise NotImplementedError( | ||
f"Operation {op.name} not supported " | ||
f"for types {left.dtype.id().name} and {right.dtype.id().name} " | ||
f"with output type {self.dtype.id().name}" | ||
) | ||
|
||
_BOOL_KLEENE_MAPPING: ClassVar[ | ||
dict[plc.binaryop.BinaryOperator, plc.binaryop.BinaryOperator] | ||
] = { | ||
plc.binaryop.BinaryOperator.BITWISE_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND, | ||
plc.binaryop.BinaryOperator.BITWISE_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR, | ||
plc.binaryop.BinaryOperator.LOGICAL_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND, | ||
plc.binaryop.BinaryOperator.LOGICAL_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR, | ||
} | ||
|
||
_MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = { | ||
pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL, | ||
pl_expr.Operator.EqValidity: plc.binaryop.BinaryOperator.NULL_EQUALS, | ||
pl_expr.Operator.NotEq: plc.binaryop.BinaryOperator.NOT_EQUAL, | ||
pl_expr.Operator.NotEqValidity: plc.binaryop.BinaryOperator.NULL_NOT_EQUALS, | ||
pl_expr.Operator.Lt: plc.binaryop.BinaryOperator.LESS, | ||
pl_expr.Operator.LtEq: plc.binaryop.BinaryOperator.LESS_EQUAL, | ||
pl_expr.Operator.Gt: plc.binaryop.BinaryOperator.GREATER, | ||
pl_expr.Operator.GtEq: plc.binaryop.BinaryOperator.GREATER_EQUAL, | ||
pl_expr.Operator.Plus: plc.binaryop.BinaryOperator.ADD, | ||
pl_expr.Operator.Minus: plc.binaryop.BinaryOperator.SUB, | ||
pl_expr.Operator.Multiply: plc.binaryop.BinaryOperator.MUL, | ||
pl_expr.Operator.Divide: plc.binaryop.BinaryOperator.DIV, | ||
pl_expr.Operator.TrueDivide: plc.binaryop.BinaryOperator.TRUE_DIV, | ||
pl_expr.Operator.FloorDivide: plc.binaryop.BinaryOperator.FLOOR_DIV, | ||
pl_expr.Operator.Modulus: plc.binaryop.BinaryOperator.PYMOD, | ||
pl_expr.Operator.And: plc.binaryop.BinaryOperator.BITWISE_AND, | ||
pl_expr.Operator.Or: plc.binaryop.BinaryOperator.BITWISE_OR, | ||
pl_expr.Operator.Xor: plc.binaryop.BinaryOperator.BITWISE_XOR, | ||
pl_expr.Operator.LogicalAnd: plc.binaryop.BinaryOperator.LOGICAL_AND, | ||
pl_expr.Operator.LogicalOr: plc.binaryop.BinaryOperator.LOGICAL_OR, | ||
_BOOL_KLEENE_MAPPING: ClassVar[dict[Operator, Operator]] = { | ||
Operator.BITWISE_AND: Operator.NULL_LOGICAL_AND, | ||
Operator.BITWISE_OR: Operator.NULL_LOGICAL_OR, | ||
Operator.LOGICAL_AND: Operator.NULL_LOGICAL_AND, | ||
Operator.LOGICAL_OR: Operator.NULL_LOGICAL_OR, | ||
} | ||
|
||
def do_evaluate( | ||
|
@@ -103,7 +162,9 @@ def do_evaluate( | |
elif right.is_scalar: | ||
rop = right.obj_scalar | ||
return Column( | ||
plc.binaryop.binary_operation(lop, rop, self.op, self.dtype), | ||
plc.binaryop.binary_operation( | ||
lop, rop, BinOp.Operator.to_pylibcudf(self.op), self.dtype | ||
), | ||
) | ||
|
||
def collect_agg(self, *, depth: int) -> AggInfo: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,7 +74,7 @@ def _validate_input(self): | |
) | ||
pattern = self.children[1].value.as_py() | ||
try: | ||
self._regex_program = plc.strings.regex_program.RegexProgram.create( | ||
plc.strings.regex_program.RegexProgram.create( | ||
pattern, | ||
flags=plc.strings.regex_flags.RegexFlags.DEFAULT, | ||
) | ||
Comment on lines -141 to 144
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I kind of hate this, because we have to do this twice now: once for control flow here, and then once to actually use the program. |
||
|
@@ -154,6 +154,12 @@ def do_evaluate( | |
) | ||
return Column(plc.strings.find.contains(column.obj, pattern)) | ||
else: | ||
assert isinstance(arg, Literal) | ||
pattern = arg.value.as_py() | ||
self._regex_program = plc.strings.regex_program.RegexProgram.create( | ||
pattern, | ||
flags=plc.strings.regex_flags.RegexFlags.DEFAULT, | ||
) | ||
return Column( | ||
plc.strings.contains.contains_re(column.obj, self._regex_program) | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to somehow validate that the `name` is supported within `__init__` so that we catch a problem at translation time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should already happen in https://github.com/rapidsai/cudf/pull/17449/files/54a9cd6b8199bc3c0b89dcfaa2bb41e87c48547e#diff-38ad8c29ff55c4194a29a45f2a003e8219f7064d0ba9d552f49a866009eaa920L45-L48, or am I missing something else?