Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1569005: Deduplicate queries and post actions with building queries from binary plan #2090

2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Snowpark Python API Updates

#### New Features

- Added support for `snowflake.snowpark.testing.assert_dataframe_equal` that is a util function to check the equality of two Snowpark DataFrames.
- Added support for `Resampler.fillna` and `Resampler.bfill`.
- Added limited support for the `Timedelta` type, including creating `Timedelta` columns and `to_pandas`.
Expand All @@ -26,6 +27,7 @@
- Fixed a bug in `DataFrame.lineage.trace` to split the quoted feature view's name and version correctly.
- Fixed a bug in `Column.isin` that caused invalid sql generation when passed an empty list.
- Fixed a bug that fails to raise NotImplementedError while setting cell with list like item.
- Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries.

### Snowpark Local Testing Updates

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1224,11 +1224,15 @@ def __init__(self, *set_operands: SetOperand, analyzer: "Analyzer") -> None:
if operand.selectable.pre_actions:
if not self.pre_actions:
self.pre_actions = []
self.pre_actions.extend(operand.selectable.pre_actions)
for action in operand.selectable.pre_actions:
if action not in self.pre_actions:
self.pre_actions.append(copy(action))
if operand.selectable.post_actions:
if not self.post_actions:
self.post_actions = []
self.post_actions.extend(operand.selectable.post_actions)
for action in operand.selectable.post_actions:
if action not in self.post_actions:
self.post_actions.append(copy(action))
self._nodes.append(operand.selectable)

def __deepcopy__(self, memodict={}) -> "SetStatement": # noqa: B006
Expand Down
11 changes: 9 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,15 @@ def build_binary(
if post_action not in post_actions:
post_actions.append(copy.copy(post_action))
else:
merged_queries = select_left.queries[:-1] + select_right.queries[:-1]
post_actions = select_left.post_actions + select_right.post_actions
merged_queries = select_left.queries[:-1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can merge the query-merging code for the if/else branches here; we only need parameter protection for referenced_ctes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

referenced_ctes are only enabled when we have _query_compilation_stage_enabled=True, right? This is the False branch.
What kind of protection are you thinking of, exactly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is that the following code is common to both branches of the if/else:

merged_queries = select_left.queries[:-1].copy()
            for query in select_right.queries[:-1]:
                if query not in merged_queries:
                    merged_queries.append(copy.copy(query))
post_actions = select_left.post_actions.copy()
            for post_action in select_right.post_actions:
                if post_action not in post_actions:
                    post_actions.append(copy.copy(post_action))

So we can merge the code under the if/else branches and end up with the following:

merged_queries = select_left.queries[:-1].copy()
            for query in select_right.queries[:-1]:
                if query not in merged_queries:
                    merged_queries.append(copy.copy(query))
post_actions = select_left.post_actions.copy()
            for post_action in select_right.post_actions:
                if post_action not in post_actions:
                    post_actions.append(copy.copy(post_action))

if (
            self.session.cte_optimization_enabled
            and self.session._query_compilation_stage_enabled
        ):
     referenced_ctes.update(select_left.referenced_ctes)
       referenced_ctes.update(select_right.referenced_ctes)

for query in select_right.queries[:-1]:
if query not in merged_queries:
merged_queries.append(copy.copy(query))

post_actions = select_left.post_actions
for action in select_right.post_actions:
if action not in post_actions:
post_actions.append(copy.copy(action))

queries = merged_queries + [
Query(
Expand Down
8 changes: 6 additions & 2 deletions src/snowflake/snowpark/_internal/compiler/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ def generate_queries(
plan_queries = get_snowflake_plan_queries(
snowflake_plan, self.resolved_with_query_block
)
queries.extend(plan_queries[PlanQueryType.QUERIES])
post_actions.extend(plan_queries[PlanQueryType.POST_ACTIONS])
for query in plan_queries[PlanQueryType.QUERIES]:
sfc-gh-aalam marked this conversation as resolved.
Show resolved Hide resolved
if query not in queries:
queries.append(query)
for action in plan_queries[PlanQueryType.POST_ACTIONS]:
if action not in post_actions:
post_actions.append(action)

return {
PlanQueryType.QUERIES: queries,
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/scala/test_snowflake_plan_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_multiple_queries(session):
Utils.drop_table(session, table_name2)


def test_execution_queries_and_queries(session):
def test_execution_queries_and_post_actions(session):
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
df1 = df.select("a", "b")
# create a df where cte optimization can be applied
Expand Down
17 changes: 3 additions & 14 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_unary(session, action):
lambda x, y: x.join(y.select("a"), how="left", rsuffix="_y"),
],
)
def test_binary(session, action, sql_simplifier_enabled):
def test_binary(session, action):
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
check_result(session, action(df, df), expect_cte_optimized=True)

Expand All @@ -134,19 +134,8 @@ def test_binary(session, action, sql_simplifier_enabled):
check_result(session, df3, expect_cte_optimized=True)
plan_queries = df3.queries
# check the number of queries
binary_build_used = (not sql_simplifier_enabled) or (
"JOIN" in plan_queries["queries"][-1]
)
# TODO (SNOW-1569005): Deduplicate queries during plan build for binary operators. This currently
# is only fixed when _query_compilation_stage_enabled for build_binary.
if session._query_compilation_stage_enabled and binary_build_used:
num_queries = 3
num_post_actions = 1
else:
num_queries = 5
num_post_actions = 2
assert len(plan_queries["queries"]) == num_queries
assert len(plan_queries["post_actions"]) == num_post_actions
assert len(plan_queries["queries"]) == 3
assert len(plan_queries["post_actions"]) == 1


@pytest.mark.parametrize(
Expand Down
Loading