diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index 1746018c7c9..c075b4fbc77 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -74,6 +74,12 @@ def traverse(root: "TreeNode") -> None: current_level = next_level def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: + # when a sql query is a select statement, its encoded_node_id_with_query + # contains _, which is used to separate the query id and node type name. + is_valid_candidate = "_" in encoded_node_id_with_query + if not is_valid_candidate: + return False + is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1 if is_duplicate_node: is_any_parent_unique_node = any( diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index ac9df3b1c2e..67d118012cb 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -118,7 +118,8 @@ def test_no_valid_nodes_found(session, caplog): def test_large_query_breakdown_external_cte_ref(session): session._cte_optimization_enabled = True - if not session.sql_simplifier_enabled: + sql_simplifier_enabled = session.sql_simplifier_enabled + if not sql_simplifier_enabled: set_bounds(session, 50, 90) base_select = session.sql("select 1 as A, 2 as B") @@ -153,9 +154,9 @@ def test_large_query_breakdown_external_cte_ref(session): assert summary_value["breakdown_failure_summary"] == [ { "num_external_cte_ref_nodes": 2, - "num_non_pipeline_breaker_nodes": 4, + "num_non_pipeline_breaker_nodes": 4 if sql_simplifier_enabled else 2, "num_nodes_below_lower_bound": 28, - "num_nodes_above_upper_bound": 1, + "num_nodes_above_upper_bound": 1 if sql_simplifier_enabled else 0, "num_valid_nodes": 0, "num_partitions_made": 0, } @@ -165,7 +166,7 @@ def test_large_query_breakdown_external_cte_ref(session): def test_breakdown_at_with_query_node(session): session._cte_optimization_enabled = True if not session.sql_simplifier_enabled: - pass + set_bounds(session, 40, 80) df0 = session.sql("select 1 as A, 2 as B") for i in range(7): @@ -189,6 +190,9 @@ def test_breakdown_at_with_query_node(session): def test_large_query_breakdown_with_cte_optimization(session): """Test large query breakdown works with cte optimized plan""" + if not session.cte_optimization_enabled: + pytest.skip("CTE optimization is not enabled") + if not session.sql_simplifier_enabled: # the complexity bounds are updated since nested selected calculation is not supported # when sql simplifier disabled diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 00fb2e16663..eba8e8b5477 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -757,8 +757,7 @@ def process_data(df_, thread_id): assert match is not None, query file_format_name = match.group() unique_create_file_format_queries.add(file_format_name) - else: - assert query.startswith("DROP FILE FORMAT") + elif query.startswith("DROP FILE FORMAT"): match = re.search(r"SNOWPARK_TEMP_FILE_FORMAT_[\w]+", query) assert match is not None, query file_format_name = match.group() diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 09a1768f282..d2617dd5d86 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -23,7 +23,7 @@ def test_case1(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node.encoded_node_id_with_query = i + node.encoded_node_id_with_query = f"{i}_{i}" node.source_plan = None if i == 5: node.cumulative_node_complexity = {PlanNodeCategory.COLUMN: 80000} @@ -39,7 +39,7 @@ def test_case1(): nodes[5].children_plan_nodes = [] nodes[6].children_plan_nodes = [] - expected_duplicate_subtree_ids = {2, 5} + expected_duplicate_subtree_ids = {"2_2", "5_5"} expected_repeated_node_complexity = [0, 3, 0, 2, 0, 0, 0] return nodes[0], expected_duplicate_subtree_ids, expected_repeated_node_complexity @@ -47,7 +47,7 @@ def test_case1(): def test_case2(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node.encoded_node_id_with_query = i + node.encoded_node_id_with_query = f"{i}_{i}" node.source_plan = None if i == 2: node.cumulative_node_complexity = {PlanNodeCategory.COLUMN: 2000000} @@ -65,7 +65,7 @@ def test_case2(): nodes[5].children_plan_nodes = [] nodes[6].children_plan_nodes = [nodes[4], nodes[4]] - expected_duplicate_subtree_ids = {2, 4, 6} + expected_duplicate_subtree_ids = {"2_2", "4_4", "6_6"} expected_repeated_node_complexity = [0, 0, 0, 0, 2, 8, 2] return nodes[0], expected_duplicate_subtree_ids, expected_repeated_node_complexity