Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Capture AST for DataFrame.pivot and DataFrame.unpivot #2287

Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5adc406
Catch up with Python formatting
sfc-gh-oplaton Jul 29, 2024
df29bef
Update ast_pb2.py
sfc-gh-oplaton Jul 29, 2024
7b15fdd
Merge branch 'ls-SNOW-1491199-merge-phase0-server-side' into oplaton/…
sfc-gh-oplaton Aug 23, 2024
d93b477
Reformat the repository
sfc-gh-oplaton Sep 4, 2024
6da3133
Update everything needed to make tests pass
sfc-gh-oplaton Sep 4, 2024
9487fcf
Merge branch 'oplaton/tmp-doctests' into oplaton/SNOW-1491297-pivot
sfc-gh-oplaton Sep 4, 2024
6e6c529
Disable df_random_split.test temporarily
sfc-gh-oplaton Sep 4, 2024
7d36377
Merge branch 'oplaton/tmp-doctests' into oplaton/SNOW-1491297-pivot
sfc-gh-oplaton Sep 4, 2024
f53a328
Update ast_pb2
sfc-gh-oplaton Sep 5, 2024
b1fbc7e
Disable the base64 expectation tests temporarily
sfc-gh-oplaton Sep 5, 2024
6fe0265
Add an expectation test for DataFrame.pivot
sfc-gh-oplaton Sep 6, 2024
6806d5c
Collect the AST for DataFrame.pivot
sfc-gh-oplaton Sep 6, 2024
9c9c1f2
Merge branch 'ls-SNOW-1491199-merge-phase0-server-side' into oplaton/…
sfc-gh-oplaton Sep 6, 2024
6224fd7
Update ast_pb2
sfc-gh-oplaton Sep 6, 2024
a288f6b
Update ast_pb2
sfc-gh-oplaton Sep 6, 2024
10e0608
Update the pivot expectations
sfc-gh-oplaton Sep 6, 2024
857878f
Add unpivot expectation test
sfc-gh-oplaton Sep 6, 2024
6156625
Update ast_pb2
sfc-gh-oplaton Sep 9, 2024
e25046c
Update ast_pb2
sfc-gh-oplaton Sep 11, 2024
da81d62
Update the encoded AST for the pivot test
sfc-gh-oplaton Sep 11, 2024
20dc168
Unpivot implementation
sfc-gh-oplaton Sep 11, 2024
1bcfab7
Don't pass ast_stmt to _with_plan
sfc-gh-oplaton Sep 11, 2024
3b4e118
Merge branch 'ls-SNOW-1491199-merge-phase0-server-side' into oplaton/…
sfc-gh-oplaton Sep 11, 2024
b9ab82e
Update ast_pb2
sfc-gh-oplaton Sep 11, 2024
189d0fd
Disable the unpivot test temporarily
sfc-gh-oplaton Sep 12, 2024
e1ca763
Update ast_pb2
sfc-gh-oplaton Sep 12, 2024
490cca4
Merge branch 'ls-SNOW-1491199-merge-phase0-server-side' into oplaton/…
sfc-gh-oplaton Sep 12, 2024
01e4800
Update ast_pb2
sfc-gh-oplaton Sep 12, 2024
42bc70c
Update expectations for the DataFrame.pivot test
sfc-gh-oplaton Sep 12, 2024
9c1d422
Undo the random_split rename
sfc-gh-oplaton Sep 12, 2024
f562e0e
Simplify result declaration and initialization
sfc-gh-oplaton Sep 12, 2024
99b9a37
Add a comment about re-enabling the disabled unpivot test
sfc-gh-oplaton Sep 13, 2024
709a590
Emit AST conditionally from unpivot
sfc-gh-oplaton Sep 13, 2024
d222050
Rename the ast_stmt parameter in the RelationalGroupedDataFrame initi…
sfc-gh-oplaton Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,375 changes: 687 additions & 688 deletions src/snowflake/snowpark/_internal/proto/ast_pb2.py

Large diffs are not rendered by default.

26 changes: 23 additions & 3 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
build_expr_from_snowpark_column_or_col_name,
build_expr_from_snowpark_column_or_sql_str,
build_expr_from_snowpark_column_or_table_fn,
build_proto_from_pivot_values,
fill_ast_for_column,
with_src_position,
)
Expand Down Expand Up @@ -2192,7 +2193,12 @@ def pivot(
"""

if _emit_ast:
sfc-gh-oplaton marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("TODO SNOW-1491297: Add coverage for pivot.")
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_pivot, stmt)
self.set_ast_ref(ast.df)
build_expr_from_snowpark_column_or_col_name(ast.pivot_col, pivot_col)
build_proto_from_pivot_values(ast.values, values)
build_expr_from_python_val(ast.default_on_null, default_on_null)

target_df, pc, pivot_values, default_on_null = prepare_pivot_arguments(
self, "DataFrame.pivot", pivot_col, values, default_on_null
Expand All @@ -2204,6 +2210,7 @@ def pivot(
snowflake.snowpark.relational_grouped_dataframe._PivotType(
pc[0], pivot_values, default_on_null
sfc-gh-oplaton marked this conversation as resolved.
Show resolved Hide resolved
),
ast_stmt=stmt,
)

@df_api_usage
Expand Down Expand Up @@ -2237,19 +2244,32 @@ def unpivot(
---------------------------------------------
<BLANKLINE>
"""
# AST.
sfc-gh-oplaton marked this conversation as resolved.
Show resolved Hide resolved
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_unpivot, stmt)
self.set_ast_ref(ast.df)
ast.value_column = value_column
ast.name_column = name_column
for c in column_list:
build_expr_from_snowpark_column_or_col_name(ast.column_list.add(), c)

column_exprs = self._convert_cols_to_exprs("unpivot()", column_list)
unpivot_plan = Unpivot(value_column, name_column, column_exprs, self._plan)

df: DataFrame
sfc-gh-oplaton marked this conversation as resolved.
Show resolved Hide resolved
if self._select_statement:
return self._with_plan(
df = self._with_plan(
SelectStatement(
from_=SelectSnowflakePlan(
unpivot_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
)
)
return self._with_plan(unpivot_plan)
else:
df = self._with_plan(unpivot_plan)
df._ast_id = stmt.var_id.bitfield1
sfc-gh-oplaton marked this conversation as resolved.
Show resolved Hide resolved
return df

@df_api_usage
def limit(
Expand Down
38 changes: 38 additions & 0 deletions tests/ast/data/DataFrame.pivot.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
## TEST CASE

df = session.create_dataframe(
[
(1, 10000, "JAN"),
(1, 400, "JAN"),
(2, 4500, "JAN"),
(2, 35000, "JAN"),
(1, 5000, "FEB"),
(1, 3000, "FEB"),
(2, 200, "FEB"),
],
schema=["k", "t", "mo"],
)

df1 = df.pivot("mo", ["JAN", "FEB"]).sum("t").sort("k")

df2 = df.pivot("mo", values=["JAN", "FEB"], default_on_null="Nothing").sum("t").sort("k")

## EXPECTED UNPARSER OUTPUT

df = session.create_dataframe([(1, 10000, "JAN"), (1, 400, "JAN"), (2, 4500, "JAN"), (2, 35000, "JAN"), (1, 5000, "FEB"), (1, 3000, "FEB"), (2, 200, "FEB")], schema=["k", "t", "mo"])

df1 = df.pivot("mo", values=["JAN", "FEB"], default_on_null=None)

df1 = df1.sum("t")

df1 = df1.sort("k")

df2 = df.pivot("mo", values=["JAN", "FEB"], default_on_null="Nothing")

df2 = df2.sum("t")

df2 = df2.sort("k")

## EXPECTED ENCODED AST

CrcICrQICqUI0gWhCAr0BwrxBwqNAZoMiQEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEiGSAh4KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEAESIpICHwoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoQkE4SJPILIQoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoSA0pBTgqNAZoMiQEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEiGSAh4KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEAESIpICHwoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoQkAMSJPILIQoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoSA0pBTgqNAZoMiQEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEiGSAh4KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEAISIpICHwoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoQlCMSJPILIQoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoSA0pBTgqOAZoMigEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEiGSAh4KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEAISI5ICIAoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoQuJECEiTyCyEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEgNKQU4KjQGaDIkBChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoWhIhkgIeChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoWhABEiKSAh8KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEIgnEiTyCyEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEgNGRUIKjQGaDIkBChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoWhIhkgIeChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoWhABEiKSAh8KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaELgXEiTyCyEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEgNGRUIKjQGaDIkBChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoWhIhkgIeChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoWhACEiKSAh8KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEMgBEiTyCyEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShaEgNGRUISDAoKCgFrCgF0CgJtbxoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKFoSBAoCZGYYASICCAEK8wEK8AEK4AGiCNwBCh/iAhwKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEgeKAgQKAggBGiPyCyAKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEgJtbyIaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGcqbxJtCmuiAmgKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEiTyCyEKGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEgNKQU4SJPILIQoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGcSA0ZFQhIFCgNkZjEYAiICCAIKZQpjClT6ClEKA3N1bRImCiLyCx8KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEgF0EAEaBlIECgIIAiIaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGcSBQoDZGYxGAMiAggDCl8KXQpO+ghLEiLyCx8KGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEgFrGAEiB4oCBAoCCAMqGhoWU1JDX1BPU0lUSU9OX1RFU1RfTU9ERShnEgUKA2RmMRgEIgIIBAr8AQr5AQrpAaII5QEKKPILJQoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSB05vdGhpbmcSB4oCBAoCCAEaI/ILIAoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSAm1vIhoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoaSpvEm0Ka6ICaAoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSJPILIQoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSA0pBThIk8gshChoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoaRIDRkVCEgUKA2RmMhgFIgIIBQplCmMKVPoKUQoDc3VtEiYKIvILHwoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSAXQQARoGUgQKAggFIhoaFlNSQ19QT1NJVElPTl9URVNUX01PREUoaRIFCgNkZjIYBiICCAYKXwpdCk76CEsSIvILHwoaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSAWsYASIHigIECgIIBioaGhZTUkNfUE9TSVRJT05fVEVTVF9NT0RFKGkSBQoDZGYyGAciAggHEAEaERIPCg0KBWZpbmFsEAMYCSATIgQQARgV
12 changes: 12 additions & 0 deletions tests/ast/data/DataFrame.unpivot.test.DISABLED
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## TEST CASE

sfc-gh-oplaton marked this conversation as resolved.
Show resolved Hide resolved
df = session.create_dataframe(
[(1, "electronics", 100, 200), (2, "clothes", 100, 300)],
schema=["empid", "dept", "jan", "feb"],
)
df = df.unpivot("sales", "month", ["jan", "feb"]).sort("empid")


## EXPECTED UNPARSER OUTPUT

## EXPECTED ENCODED AST
Loading