diff --git a/CHANGELOG.md b/CHANGELOG.md index ce92524bb7b..c4a6eebe1cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### Bug Fixes - Fixed a bug in `session.read.csv` that caused an error when setting `PARSE_HEADER = True` in an externally defined file format. +- Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries. ### Snowpark Local Testing Updates diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index c976b0f977a..55d33577d55 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1224,11 +1224,15 @@ def __init__(self, *set_operands: SetOperand, analyzer: "Analyzer") -> None: if operand.selectable.pre_actions: if not self.pre_actions: self.pre_actions = [] - self.pre_actions.extend(operand.selectable.pre_actions) + for action in operand.selectable.pre_actions: + if action not in self.pre_actions: + self.pre_actions.append(copy(action)) if operand.selectable.post_actions: if not self.post_actions: self.post_actions = [] - self.post_actions.extend(operand.selectable.post_actions) + for action in operand.selectable.post_actions: + if action not in self.post_actions: + self.post_actions.append(copy(action)) self._nodes.append(operand.selectable) def __deepcopy__(self, memodict={}) -> "SetStatement": # noqa: B006 diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 68b2ea29f3a..7c79362b58f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -614,32 +614,29 @@ def build_binary( } api_calls = [*select_left.api_calls, *select_right.api_calls] + # Need to do a deduplication to avoid repeated query. + merged_queries = select_left.queries[:-1].copy() + for query in select_right.queries[:-1]: + if query not in merged_queries: + merged_queries.append(copy.copy(query)) + + post_actions = select_left.post_actions.copy() + for post_action in select_right.post_actions: + if post_action not in post_actions: + post_actions.append(copy.copy(post_action)) + referenced_ctes: Set[str] = set() if ( self.session.cte_optimization_enabled and self.session._query_compilation_stage_enabled ): - # When the cte optimization and the new compilation stage is enabled, the - # queries, referred cte tables, and post actions propagated from - # left and right can have duplicated queries if there is a common CTE block referenced - # by both left and right. - # Need to do a deduplication to avoid repeated query. - merged_queries = select_left.queries[:-1].copy() - for query in select_right.queries[:-1]: - if query not in merged_queries: - merged_queries.append(copy.copy(query)) - + # When the cte optimization and the new compilation stage is enabled, + # the referred cte tables are propagated from left and right can have + # duplicated queries if there is a common CTE block referenced by + # both left and right. referenced_ctes.update(select_left.referenced_ctes) referenced_ctes.update(select_right.referenced_ctes) - post_actions = select_left.post_actions.copy() - for post_action in select_right.post_actions: - if post_action not in post_actions: - post_actions.append(copy.copy(post_action)) - else: - merged_queries = select_left.queries[:-1] + select_right.queries[:-1] - post_actions = select_left.post_actions + select_right.post_actions - queries = merged_queries + [ Query( sql_generator( diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index d478a1e2f17..579c3b8e5d6 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -246,12 +246,21 @@ def update_resolvable_node( node.analyzer = query_generator # update the pre_actions and post_actions for the set statement - node.pre_actions, node.post_actions = [], [] + node.pre_actions, node.post_actions = None, None for operand in node.set_operands: if operand.selectable.pre_actions: - node.pre_actions.extend(operand.selectable.pre_actions) + if node.pre_actions is None: + node.pre_actions = [] + for action in operand.selectable.pre_actions: + if action not in node.pre_actions: + node.pre_actions.append(action) + if operand.selectable.post_actions: - node.post_actions.extend(operand.selectable.post_actions) + if node.post_actions is None: + node.post_actions = [] + for action in operand.selectable.post_actions: + if action not in node.post_actions: + node.post_actions.append(action) elif isinstance(node, (SelectSnowflakePlan, SelectTableFunction)): assert node.snowflake_plan is not None diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index 4851b45ee3d..1014cae44c9 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -702,7 +702,7 @@ def foo(row): ], ) @sql_count_checker( - query_count=18, + query_count=13, udtf_count=1, join_count=3, high_count_expected=True, diff --git a/tests/integ/modin/frame/test_iloc.py b/tests/integ/modin/frame/test_iloc.py index 5b382fbc3dd..dd090683835 100644 --- a/tests/integ/modin/frame/test_iloc.py +++ b/tests/integ/modin/frame/test_iloc.py @@ -1052,13 +1052,13 @@ def iloc_helper(df): ) high_count_reason = """ - 11 queries includes 5 queries to prepare the temp table for df, including create, insert, - drop the temp table (3.). and alter session to set and unset query_tag (2) and one select query. + 6 queries includes queries to create, insert, and drop the temp table (3), alter session + to set and unset query_tag (2) and one select query. Another 5 query to prepare the temp table for df again due to the fact it is used in another join even though it is in the same query. - 16 queries add extra 5 queries to prepare the temp table for key + 11 queries add extra 5 queries to prepare the temp table for key """ - query_count = 11 if len(key) < 300 else 16 + query_count = 6 if len(key) < 300 else 11 _test_df_iloc_with_1k_shape( native_df_1k_1k, iloc_helper, query_count, 2, high_count_reason ) diff --git a/tests/integ/scala/test_dataframe_suite.py b/tests/integ/scala/test_dataframe_suite.py index 79d655d7dd5..09329d9724a 100644 --- a/tests/integ/scala/test_dataframe_suite.py +++ b/tests/integ/scala/test_dataframe_suite.py @@ -476,7 +476,7 @@ def test_non_select_query_composition_unionall(session): reason="This is testing query generation", run=False, ) -def test_non_select_query_composition_self_union(session): +def test_non_select_query_composition_self_union(session, sql_simplifier_enabled): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: session.sql( @@ -487,7 +487,10 @@ def test_non_select_query_composition_self_union(session): union = df.union(df).select('"name"').filter(col('"name"') == table_name) assert len(union.collect()) == 1 - assert len(union._plan.queries) == 3 + if sql_simplifier_enabled: + assert len(union._plan.queries) == 3 + else: + assert len(union._plan.queries) == 2 finally: Utils.drop_table(session, table_name) @@ -497,7 +500,7 @@ def test_non_select_query_composition_self_union(session): reason="This is testing query generation", run=False, ) -def test_non_select_query_composition_self_unionall(session): +def test_non_select_query_composition_self_unionall(session, sql_simplifier_enabled): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) try: session.sql( @@ -508,7 +511,10 @@ def test_non_select_query_composition_self_unionall(session): union = df.union_all(df).select('"name"').filter(col('"name"') == table_name) assert len(union.collect()) == 2 - assert len(union._plan.queries) == 3 + if sql_simplifier_enabled: + assert len(union._plan.queries) == 3 + else: + assert len(union._plan.queries) == 2 finally: Utils.drop_table(session, table_name) diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index 3ba5d8a029d..7311ddbdc3c 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ b/tests/integ/scala/test_snowflake_plan_suite.py @@ -128,7 +128,7 @@ def test_multiple_queries(session): Utils.drop_table(session, table_name2) -def test_execution_queries_and_queries(session): +def test_execution_queries_and_post_actions(session): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df1 = df.select("a", "b") # create a df where cte optimization can be applied diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 48060947c39..9345d0528fd 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -116,7 +116,7 @@ def test_unary(session, action): lambda x, y: x.join(y.select("a"), how="left", rsuffix="_y"), ], ) -def test_binary(session, action, sql_simplifier_enabled): +def test_binary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) check_result(session, action(df, df), expect_cte_optimized=True) @@ -134,19 +134,8 @@ def test_binary(session, action, sql_simplifier_enabled): check_result(session, df3, expect_cte_optimized=True) plan_queries = df3.queries # check the number of queries - binary_build_used = (not sql_simplifier_enabled) or ( - "JOIN" in plan_queries["queries"][-1] - ) - # TODO (SNOW-1569005): Deduplicate queries during plan build for binary operators. This currently - # is only fixed when _query_compilation_stage_enabled for build_binary. - if session._query_compilation_stage_enabled and binary_build_used: - num_queries = 3 - num_post_actions = 1 - else: - num_queries = 5 - num_post_actions = 2 - assert len(plan_queries["queries"]) == num_queries - assert len(plan_queries["post_actions"]) == num_post_actions + assert len(plan_queries["queries"]) == 3 + assert len(plan_queries["post_actions"]) == 1 @pytest.mark.parametrize(