Skip to content

Commit

Permalink
SNOW-1569005: Deduplicate queries and post actions with building quer…
Browse files Browse the repository at this point in the history
…ies from binary plan (#2090)
  • Loading branch information
sfc-gh-aalam authored Aug 20, 2024
1 parent 695cd14 commit eba49e7
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Bug Fixes

- Fixed a bug in `session.read.csv` that caused an error when setting `PARSE_HEADER = True` in an externally defined file format.
- Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries.

### Snowpark Local Testing Updates

Expand Down
8 changes: 6 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,11 +1224,15 @@ def __init__(self, *set_operands: SetOperand, analyzer: "Analyzer") -> None:
if operand.selectable.pre_actions:
if not self.pre_actions:
self.pre_actions = []
self.pre_actions.extend(operand.selectable.pre_actions)
for action in operand.selectable.pre_actions:
if action not in self.pre_actions:
self.pre_actions.append(copy(action))
if operand.selectable.post_actions:
if not self.post_actions:
self.post_actions = []
self.post_actions.extend(operand.selectable.post_actions)
for action in operand.selectable.post_actions:
if action not in self.post_actions:
self.post_actions.append(copy(action))
self._nodes.append(operand.selectable)

def __deepcopy__(self, memodict={}) -> "SetStatement": # noqa: B006
Expand Down
33 changes: 15 additions & 18 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,32 +614,29 @@ def build_binary(
}
api_calls = [*select_left.api_calls, *select_right.api_calls]

# Need to do a deduplication to avoid repeated query.
merged_queries = select_left.queries[:-1].copy()
for query in select_right.queries[:-1]:
if query not in merged_queries:
merged_queries.append(copy.copy(query))

post_actions = select_left.post_actions.copy()
for post_action in select_right.post_actions:
if post_action not in post_actions:
post_actions.append(copy.copy(post_action))

referenced_ctes: Set[str] = set()
if (
self.session.cte_optimization_enabled
and self.session._query_compilation_stage_enabled
):
# When the cte optimization and the new compilation stage is enabled, the
# queries, referred cte tables, and post actions propagated from
# left and right can have duplicated queries if there is a common CTE block referenced
# by both left and right.
# Need to do a deduplication to avoid repeated query.
merged_queries = select_left.queries[:-1].copy()
for query in select_right.queries[:-1]:
if query not in merged_queries:
merged_queries.append(copy.copy(query))

# When the cte optimization and the new compilation stage are enabled,
# the referenced cte tables propagated from left and right can contain
# duplicates if there is a common CTE block referenced by both left
# and right.
referenced_ctes.update(select_left.referenced_ctes)
referenced_ctes.update(select_right.referenced_ctes)

post_actions = select_left.post_actions.copy()
for post_action in select_right.post_actions:
if post_action not in post_actions:
post_actions.append(copy.copy(post_action))
else:
merged_queries = select_left.queries[:-1] + select_right.queries[:-1]
post_actions = select_left.post_actions + select_right.post_actions

queries = merged_queries + [
Query(
sql_generator(
Expand Down
15 changes: 12 additions & 3 deletions src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,21 @@ def update_resolvable_node(
node.analyzer = query_generator

# update the pre_actions and post_actions for the set statement
node.pre_actions, node.post_actions = [], []
node.pre_actions, node.post_actions = None, None
for operand in node.set_operands:
if operand.selectable.pre_actions:
node.pre_actions.extend(operand.selectable.pre_actions)
if node.pre_actions is None:
node.pre_actions = []
for action in operand.selectable.pre_actions:
if action not in node.pre_actions:
node.pre_actions.append(action)

if operand.selectable.post_actions:
node.post_actions.extend(operand.selectable.post_actions)
if node.post_actions is None:
node.post_actions = []
for action in operand.selectable.post_actions:
if action not in node.post_actions:
node.post_actions.append(action)

elif isinstance(node, (SelectSnowflakePlan, SelectTableFunction)):
assert node.snowflake_plan is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def foo(row):
],
)
@sql_count_checker(
query_count=18,
query_count=13,
udtf_count=1,
join_count=3,
high_count_expected=True,
Expand Down
8 changes: 4 additions & 4 deletions tests/integ/modin/frame/test_iloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,13 +1052,13 @@ def iloc_helper(df):
)

high_count_reason = """
11 queries includes 5 queries to prepare the temp table for df, including create, insert,
drop the temp table (3.). and alter session to set and unset query_tag (2) and one select query.
6 queries includes queries to create, insert, and drop the temp table (3), alter session
to set and unset query_tag (2) and one select query.
Another 5 query to prepare the temp table for df again due to the fact it is used in another
join even though it is in the same query.
16 queries add extra 5 queries to prepare the temp table for key
11 queries add extra 5 queries to prepare the temp table for key
"""
query_count = 11 if len(key) < 300 else 16
query_count = 6 if len(key) < 300 else 11
_test_df_iloc_with_1k_shape(
native_df_1k_1k, iloc_helper, query_count, 2, high_count_reason
)
Expand Down
14 changes: 10 additions & 4 deletions tests/integ/scala/test_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_non_select_query_composition_unionall(session):
reason="This is testing query generation",
run=False,
)
def test_non_select_query_composition_self_union(session):
def test_non_select_query_composition_self_union(session, sql_simplifier_enabled):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
try:
session.sql(
Expand All @@ -487,7 +487,10 @@ def test_non_select_query_composition_self_union(session):
union = df.union(df).select('"name"').filter(col('"name"') == table_name)

assert len(union.collect()) == 1
assert len(union._plan.queries) == 3
if sql_simplifier_enabled:
assert len(union._plan.queries) == 3
else:
assert len(union._plan.queries) == 2
finally:
Utils.drop_table(session, table_name)

Expand All @@ -497,7 +500,7 @@ def test_non_select_query_composition_self_union(session):
reason="This is testing query generation",
run=False,
)
def test_non_select_query_composition_self_unionall(session):
def test_non_select_query_composition_self_unionall(session, sql_simplifier_enabled):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
try:
session.sql(
Expand All @@ -508,7 +511,10 @@ def test_non_select_query_composition_self_unionall(session):
union = df.union_all(df).select('"name"').filter(col('"name"') == table_name)

assert len(union.collect()) == 2
assert len(union._plan.queries) == 3
if sql_simplifier_enabled:
assert len(union._plan.queries) == 3
else:
assert len(union._plan.queries) == 2
finally:
Utils.drop_table(session, table_name)

Expand Down
2 changes: 1 addition & 1 deletion tests/integ/scala/test_snowflake_plan_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_multiple_queries(session):
Utils.drop_table(session, table_name2)


def test_execution_queries_and_queries(session):
def test_execution_queries_and_post_actions(session):
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
df1 = df.select("a", "b")
# create a df where cte optimization can be applied
Expand Down
17 changes: 3 additions & 14 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_unary(session, action):
lambda x, y: x.join(y.select("a"), how="left", rsuffix="_y"),
],
)
def test_binary(session, action, sql_simplifier_enabled):
def test_binary(session, action):
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
check_result(session, action(df, df), expect_cte_optimized=True)

Expand All @@ -134,19 +134,8 @@ def test_binary(session, action, sql_simplifier_enabled):
check_result(session, df3, expect_cte_optimized=True)
plan_queries = df3.queries
# check the number of queries
binary_build_used = (not sql_simplifier_enabled) or (
"JOIN" in plan_queries["queries"][-1]
)
# TODO (SNOW-1569005): Deduplicate queries during plan build for binary operators. This currently
# is only fixed when _query_compilation_stage_enabled for build_binary.
if session._query_compilation_stage_enabled and binary_build_used:
num_queries = 3
num_post_actions = 1
else:
num_queries = 5
num_post_actions = 2
assert len(plan_queries["queries"]) == num_queries
assert len(plan_queries["post_actions"]) == num_post_actions
assert len(plan_queries["queries"]) == 3
assert len(plan_queries["post_actions"]) == 1


@pytest.mark.parametrize(
Expand Down

0 comments on commit eba49e7

Please sign in to comment.