-
Notifications
You must be signed in to change notification settings - Fork 926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prevent pylibcudf serialization in cudf-polars
#17449
base: branch-25.04
Are you sure you want to change the base?
Changes from 11 commits
eb4a2ff
165e68c
99a5d12
3e14ec9
2ba7ed1
4eba56c
ef36d42
5635388
bba8b3f
2d37c08
54a9cd6
faf42d0
7183bc8
bfaf41e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
|
||
|
||
class Agg(Expr): | ||
__slots__ = ("name", "op", "options", "request") | ||
__slots__ = ("name", "options", "request") | ||
_non_child = ("dtype", "name", "options") | ||
|
||
def __init__( | ||
|
@@ -46,58 +46,11 @@ def __init__( | |
raise NotImplementedError( | ||
f"Unsupported aggregation {name=}" | ||
) # pragma: no cover; all valid aggs are supported | ||
# TODO: nan handling in groupby case | ||
if name == "min": | ||
req = plc.aggregation.min() | ||
elif name == "max": | ||
req = plc.aggregation.max() | ||
elif name == "median": | ||
req = plc.aggregation.median() | ||
elif name == "n_unique": | ||
# TODO: datatype of result | ||
req = plc.aggregation.nunique(null_handling=plc.types.NullPolicy.INCLUDE) | ||
elif name == "first" or name == "last": | ||
req = None | ||
elif name == "mean": | ||
req = plc.aggregation.mean() | ||
elif name == "sum": | ||
req = plc.aggregation.sum() | ||
elif name == "std": | ||
# TODO: handle nans | ||
req = plc.aggregation.std(ddof=options) | ||
elif name == "var": | ||
# TODO: handle nans | ||
req = plc.aggregation.variance(ddof=options) | ||
elif name == "count": | ||
req = plc.aggregation.count( | ||
null_handling=plc.types.NullPolicy.EXCLUDE | ||
if not options | ||
else plc.types.NullPolicy.INCLUDE | ||
) | ||
elif name == "quantile": | ||
if name == "quantile": | ||
_, quantile = self.children | ||
if not isinstance(quantile, Literal): | ||
raise NotImplementedError("Only support literal quantile values") | ||
req = plc.aggregation.quantile( | ||
quantiles=[quantile.value.as_py()], interp=Agg.interp_mapping[options] | ||
) | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, {name=} is incorrectly listed in _SUPPORTED" | ||
) # pragma: no cover | ||
self.request = req | ||
op = getattr(self, f"_{name}", None) | ||
if op is None: | ||
op = partial(self._reduce, request=req) | ||
elif name in {"min", "max"}: | ||
op = partial(op, propagate_nans=options) | ||
elif name in {"count", "sum", "first", "last"}: | ||
pass | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, supported agg {name=} has no implementation" | ||
) # pragma: no cover | ||
self.op = op | ||
self.request = None | ||
|
||
_SUPPORTED: ClassVar[frozenset[str]] = frozenset( | ||
[ | ||
|
@@ -124,6 +77,46 @@ def __init__( | |
"linear": plc.types.Interpolation.LINEAR, | ||
} | ||
|
||
def _fill_request(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Maybe we can just define:
```python
@property
def request(self):
    ...
```
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Good idea. Done in faf42d0. |
||
if self.request is None: | ||
# TODO: nan handling in groupby case | ||
if self.name == "min": | ||
req = plc.aggregation.min() | ||
elif self.name == "max": | ||
req = plc.aggregation.max() | ||
elif self.name == "median": | ||
req = plc.aggregation.median() | ||
elif self.name == "n_unique": | ||
# TODO: datatype of result | ||
req = plc.aggregation.nunique( | ||
null_handling=plc.types.NullPolicy.INCLUDE | ||
) | ||
elif self.name == "first" or self.name == "last": | ||
req = None | ||
elif self.name == "mean": | ||
req = plc.aggregation.mean() | ||
elif self.name == "sum": | ||
req = plc.aggregation.sum() | ||
elif self.name == "std": | ||
# TODO: handle nans | ||
req = plc.aggregation.std(ddof=self.options) | ||
elif self.name == "var": | ||
# TODO: handle nans | ||
req = plc.aggregation.variance(ddof=self.options) | ||
elif self.name == "count": | ||
req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE) | ||
elif self.name == "quantile": | ||
_, quantile = self.children | ||
req = plc.aggregation.quantile( | ||
quantiles=[quantile.value.as_py()], | ||
interp=Agg.interp_mapping[self.options], | ||
) | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, {self.name=} is incorrectly listed in _SUPPORTED" | ||
) # pragma: no cover | ||
self.request = req | ||
|
||
def collect_agg(self, *, depth: int) -> AggInfo: | ||
"""Collect information about aggregations in groupbys.""" | ||
if depth >= 1: | ||
|
@@ -134,6 +127,7 @@ def collect_agg(self, *, depth: int) -> AggInfo: | |
raise NotImplementedError("Nan propagation in groupby for min/max") | ||
(child,) = self.children | ||
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests | ||
self._fill_request() | ||
request = self.request | ||
# These are handled specially here because we don't set up the | ||
# request for the whole-frame agg because we can avoid a | ||
|
@@ -240,7 +234,21 @@ def do_evaluate( | |
f"Agg in context {context}" | ||
) # pragma: no cover; unreachable | ||
|
||
self._fill_request() | ||
|
||
op = getattr(self, f"_{self.name}", None) | ||
if op is None: | ||
op = partial(self._reduce, request=self.request) | ||
elif self.name in {"min", "max"}: | ||
op = partial(op, propagate_nans=self.options) | ||
elif self.name in {"count", "sum", "first", "last"}: | ||
pass | ||
else: | ||
raise NotImplementedError( | ||
f"Unreachable, supported agg {self.name=} has no implementation" | ||
) # pragma: no cover | ||
|
||
# Aggregations like quantiles may have additional children that were | ||
# preprocessed into pylibcudf requests. | ||
child = self.children[0] | ||
return self.op(child.evaluate(df, context=context, mapping=mapping)) | ||
return op(child.evaluate(df, context=context, mapping=mapping)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -883,13 +883,11 @@ def __init__( | |
raise NotImplementedError("dynamic group by") | ||
if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests): | ||
raise NotImplementedError("Nested aggregations in groupby") | ||
self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] | ||
self._non_child_args = ( | ||
self.keys, | ||
self.agg_requests, | ||
maintain_order, | ||
options, | ||
self.agg_infos, | ||
Comment on lines 899 to -892
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I also found that
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Right, do you see any tests we could immediately rely upon for this case? |
||
) | ||
|
||
@staticmethod | ||
|
@@ -927,7 +925,6 @@ def do_evaluate( | |
agg_requests: Sequence[expr.NamedExpr], | ||
maintain_order: bool, # noqa: FBT001 | ||
options: Any, | ||
agg_infos: Sequence[expr.AggInfo], | ||
df: DataFrame, | ||
): | ||
"""Evaluate and return a dataframe.""" | ||
|
@@ -947,6 +944,7 @@ def do_evaluate( | |
# TODO: uniquify | ||
requests = [] | ||
replacements: list[expr.Expr] = [] | ||
agg_infos = [req.collect_agg(depth=0) for req in agg_requests] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is definitely a good way to avoid |
||
for info in agg_infos: | ||
for pre_eval, req, rep in info.requests: | ||
if pre_eval is None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to somehow validate that the `name` is supported within `__init__` so that we catch a problem at translation time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should already happen in https://github.com/rapidsai/cudf/pull/17449/files/54a9cd6b8199bc3c0b89dcfaa2bb41e87c48547e#diff-38ad8c29ff55c4194a29a45f2a003e8219f7064d0ba9d552f49a866009eaa920L45-L48, or am I missing something else?