Skip to content

Commit d38d986

Browse files
SNOW-2679277: Add support for groupby.get_group/resample/rolling in faster pandas (#3991)
1 parent c827ae5 commit d38d986

File tree

3 files changed

+280
-58
lines changed

3 files changed

+280
-58
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,9 @@
161161
- `groupby.any`
162162
- `groupby.all`
163163
- `groupby.unique`
164+
- `groupby.get_group`
165+
- `groupby.rolling`
166+
- `groupby.resample`
164167
- `to_snowflake`
165168
- `to_snowpark`
166169
- Make faster pandas disabled by default (opt-in instead of opt-out).

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5975,6 +5975,40 @@ def groupby_resample(
59755975
agg_args: Any,
59765976
agg_kwargs: dict[str, Any],
59775977
) -> "SnowflakeQueryCompiler":
5978+
"""
5979+
Wrapper around _groupby_resample_internal to be supported in faster pandas.
5980+
"""
5981+
relaxed_query_compiler = None
5982+
if self._relaxed_query_compiler is not None:
5983+
relaxed_query_compiler = (
5984+
self._relaxed_query_compiler._groupby_resample_internal(
5985+
resample_kwargs=resample_kwargs,
5986+
resample_method=resample_method,
5987+
groupby_kwargs=groupby_kwargs,
5988+
is_series=is_series,
5989+
agg_args=agg_args,
5990+
agg_kwargs=agg_kwargs,
5991+
)
5992+
)
5993+
qc = self._groupby_resample_internal(
5994+
resample_kwargs=resample_kwargs,
5995+
resample_method=resample_method,
5996+
groupby_kwargs=groupby_kwargs,
5997+
is_series=is_series,
5998+
agg_args=agg_args,
5999+
agg_kwargs=agg_kwargs,
6000+
)
6001+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
6002+
6003+
def _groupby_resample_internal(
6004+
self,
6005+
resample_kwargs: dict[str, Any],
6006+
resample_method: AggFuncType,
6007+
groupby_kwargs: dict[str, Any],
6008+
is_series: bool,
6009+
agg_args: Any,
6010+
agg_kwargs: dict[str, Any],
6011+
) -> "SnowflakeQueryCompiler":
59786012

59796013
validate_groupby_resample_supported_by_snowflake(resample_kwargs)
59806014
level = groupby_kwargs.get("level", None)
@@ -6128,6 +6162,40 @@ def groupby_rolling(
61286162
is_series: bool,
61296163
agg_args: Any,
61306164
agg_kwargs: dict[str, Any],
6165+
) -> "SnowflakeQueryCompiler":
6166+
"""
6167+
Wrapper around _groupby_rolling_internal to be supported in faster pandas.
6168+
"""
6169+
relaxed_query_compiler = None
6170+
if self._relaxed_query_compiler is not None:
6171+
relaxed_query_compiler = (
6172+
self._relaxed_query_compiler._groupby_rolling_internal(
6173+
rolling_kwargs=rolling_kwargs,
6174+
rolling_method=rolling_method,
6175+
groupby_kwargs=groupby_kwargs,
6176+
is_series=is_series,
6177+
agg_args=agg_args,
6178+
agg_kwargs=agg_kwargs,
6179+
)
6180+
)
6181+
qc = self._groupby_rolling_internal(
6182+
rolling_kwargs=rolling_kwargs,
6183+
rolling_method=rolling_method,
6184+
groupby_kwargs=groupby_kwargs,
6185+
is_series=is_series,
6186+
agg_args=agg_args,
6187+
agg_kwargs=agg_kwargs,
6188+
)
6189+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
6190+
6191+
def _groupby_rolling_internal(
6192+
self,
6193+
rolling_kwargs: dict[str, Any],
6194+
rolling_method: AggFuncType,
6195+
groupby_kwargs: dict[str, Any],
6196+
is_series: bool,
6197+
agg_args: Any,
6198+
agg_kwargs: dict[str, Any],
61316199
) -> "SnowflakeQueryCompiler":
61326200
"""
61336201
Return a rolling grouper, providing rolling functionality per group.
@@ -6548,6 +6616,43 @@ def groupby_get_group(
65486616
agg_kwargs: dict[str, Any],
65496617
drop: bool = False,
65506618
**kwargs: dict[str, Any],
6619+
) -> "SnowflakeQueryCompiler":
6620+
"""
6621+
Wrapper around _groupby_get_group_internal to be supported in faster pandas.
6622+
"""
6623+
relaxed_query_compiler = None
6624+
if self._relaxed_query_compiler is not None:
6625+
relaxed_query_compiler = (
6626+
self._relaxed_query_compiler._groupby_get_group_internal(
6627+
by=by,
6628+
axis=axis,
6629+
groupby_kwargs=groupby_kwargs,
6630+
agg_args=agg_args,
6631+
agg_kwargs=agg_kwargs,
6632+
drop=drop,
6633+
**kwargs,
6634+
)
6635+
)
6636+
qc = self._groupby_get_group_internal(
6637+
by=by,
6638+
axis=axis,
6639+
groupby_kwargs=groupby_kwargs,
6640+
agg_args=agg_args,
6641+
agg_kwargs=agg_kwargs,
6642+
drop=drop,
6643+
**kwargs,
6644+
)
6645+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
6646+
6647+
def _groupby_get_group_internal(
6648+
self,
6649+
by: Any,
6650+
axis: int,
6651+
groupby_kwargs: dict[str, Any],
6652+
agg_args: tuple[Any],
6653+
agg_kwargs: dict[str, Any],
6654+
drop: bool = False,
6655+
**kwargs: dict[str, Any],
65516656
) -> "SnowflakeQueryCompiler":
65526657
"""
65536658
Get all rows that match a given group name in the `by` column.

0 commit comments

Comments
 (0)