Skip to content

Commit

Permalink
Aalam fix daily precommit (#2516)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam committed Oct 26, 2024
1 parent aee8d51 commit 51c4938
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def traverse(root: "TreeNode") -> None:
current_level = next_level

def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
# When a SQL query is a select statement, its encoded_node_id_with_query
# contains an underscore ("_"), which separates the query id from the node type name.
is_valid_candidate = "_" in encoded_node_id_with_query
if not is_valid_candidate:
return False

is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
Expand Down
12 changes: 8 additions & 4 deletions tests/integ/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def test_no_valid_nodes_found(session, caplog):

def test_large_query_breakdown_external_cte_ref(session):
session._cte_optimization_enabled = True
if not session.sql_simplifier_enabled:
sql_simplifier_enabled = session.sql_simplifier_enabled
if not sql_simplifier_enabled:
set_bounds(session, 50, 90)

base_select = session.sql("select 1 as A, 2 as B")
Expand Down Expand Up @@ -153,9 +154,9 @@ def test_large_query_breakdown_external_cte_ref(session):
assert summary_value["breakdown_failure_summary"] == [
{
"num_external_cte_ref_nodes": 2,
"num_non_pipeline_breaker_nodes": 4,
"num_non_pipeline_breaker_nodes": 4 if sql_simplifier_enabled else 2,
"num_nodes_below_lower_bound": 28,
"num_nodes_above_upper_bound": 1,
"num_nodes_above_upper_bound": 1 if sql_simplifier_enabled else 0,
"num_valid_nodes": 0,
"num_partitions_made": 0,
}
Expand All @@ -165,7 +166,7 @@ def test_large_query_breakdown_external_cte_ref(session):
def test_breakdown_at_with_query_node(session):
session._cte_optimization_enabled = True
if not session.sql_simplifier_enabled:
pass
set_bounds(session, 40, 80)

df0 = session.sql("select 1 as A, 2 as B")
for i in range(7):
Expand All @@ -189,6 +190,9 @@ def test_breakdown_at_with_query_node(session):

def test_large_query_breakdown_with_cte_optimization(session):
"""Test large query breakdown works with cte optimized plan"""
if not session.cte_optimization_enabled:
pytest.skip("CTE optimization is not enabled")

if not session.sql_simplifier_enabled:
# The complexity bounds are updated because nested selected calculation is not
# supported when the sql simplifier is disabled.
Expand Down
3 changes: 1 addition & 2 deletions tests/integ/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,7 @@ def process_data(df_, thread_id):
assert match is not None, query
file_format_name = match.group()
unique_create_file_format_queries.add(file_format_name)
else:
assert query.startswith("DROP FILE FORMAT")
elif query.startswith("DROP FILE FORMAT"):
match = re.search(r"SNOWPARK_TEMP_FILE_FORMAT_[\w]+", query)
assert match is not None, query
file_format_name = match.group()
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
def test_case1():
nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)]
for i, node in enumerate(nodes):
node.encoded_node_id_with_query = i
node.encoded_node_id_with_query = f"{i}_{i}"
node.source_plan = None
if i == 5:
node.cumulative_node_complexity = {PlanNodeCategory.COLUMN: 80000}
Expand All @@ -39,15 +39,15 @@ def test_case1():
nodes[5].children_plan_nodes = []
nodes[6].children_plan_nodes = []

expected_duplicate_subtree_ids = {2, 5}
expected_duplicate_subtree_ids = {"2_2", "5_5"}
expected_repeated_node_complexity = [0, 3, 0, 2, 0, 0, 0]
return nodes[0], expected_duplicate_subtree_ids, expected_repeated_node_complexity


def test_case2():
nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)]
for i, node in enumerate(nodes):
node.encoded_node_id_with_query = i
node.encoded_node_id_with_query = f"{i}_{i}"
node.source_plan = None
if i == 2:
node.cumulative_node_complexity = {PlanNodeCategory.COLUMN: 2000000}
Expand All @@ -65,7 +65,7 @@ def test_case2():
nodes[5].children_plan_nodes = []
nodes[6].children_plan_nodes = [nodes[4], nodes[4]]

expected_duplicate_subtree_ids = {2, 4, 6}
expected_duplicate_subtree_ids = {"2_2", "4_4", "6_6"}
expected_repeated_node_complexity = [0, 0, 0, 0, 2, 8, 2]
return nodes[0], expected_duplicate_subtree_ids, expected_repeated_node_complexity

Expand Down

0 comments on commit 51c4938

Please sign in to comment.