Skip to content

Commit

Permalink
[SNOW-1731783] Refactor node query comparison for Repeated subquery e…
Browse files Browse the repository at this point in the history
…limination (#2437)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.
   
Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

SNOW-1731783

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.

In the previous repeated subquery elimination work, we changed the node
comparison for SnowflakePlan and Selectable to
```
    def __eq__(self, other: "SnowflakePlan") -> bool:
        if not isinstance(other, SnowflakePlan):
            return False
        if self._id is not None and other._id is not None:
            return isinstance(other, SnowflakePlan) and self._id == other._id
        else:
            return super().__eq__(other)

    def __hash__(self) -> int:
        return hash(self._id) if self._id else super().__hash__()
```
where the id is generated from the query and its query parameters. This
means two nodes are treated as the same if they have the same type and the
same query. This makes sense when we do repeated subquery elimination, but it
is not what other transformations expect.

Refactor the comparison to make sure that we only use the id comparison
for repeated subquery eliminations, not for others.
  • Loading branch information
sfc-gh-yzou authored Oct 16, 2024
1 parent dc24147 commit 952599b
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 121 deletions.
109 changes: 80 additions & 29 deletions src/snowflake/snowpark/_internal/analyzer/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union
from typing import TYPE_CHECKING, Optional, Set, Union

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
SPACE,
Expand All @@ -24,11 +24,9 @@
TreeNode = Union[SnowflakePlan, Selectable]


def find_duplicate_subtrees(
root: "TreeNode",
) -> Tuple[Set["TreeNode"], Dict["TreeNode", Set["TreeNode"]]]:
def find_duplicate_subtrees(root: "TreeNode") -> Set[str]:
"""
Returns a set containing all duplicate subtrees in query plan tree.
Returns a set of TreeNode encoded_id that indicates all duplicate subtrees in query plan tree.
The root of a duplicate subtree is defined as a duplicate node, if
- it appears more than once in the tree, AND
- one of its parent is unique (only appear once) in the tree, OR
Expand All @@ -49,8 +47,8 @@ def find_duplicate_subtrees(
This function is used to only include nodes that should be converted to CTEs.
"""
node_count_map = defaultdict(int)
node_parents_map = defaultdict(set)
id_count_map = defaultdict(int)
id_parents_map = defaultdict(set)

def traverse(root: "TreeNode") -> None:
"""
Expand All @@ -60,32 +58,39 @@ def traverse(root: "TreeNode") -> None:
while len(current_level) > 0:
next_level = []
for node in current_level:
node_count_map[node] += 1
id_count_map[node.encoded_node_id_with_query] += 1
for child in node.children_plan_nodes:
node_parents_map[child].add(node)
id_parents_map[child.encoded_node_id_with_query].add(
node.encoded_node_id_with_query
)
next_level.append(child)
current_level = next_level

def is_duplicate_subtree(node: "TreeNode") -> bool:
is_duplicate_node = node_count_map[node] > 1
def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
node_count_map[n] == 1 for n in node_parents_map[node]
id_count_map[id] == 1
for id in id_parents_map[encoded_node_id_with_query]
)
if is_any_parent_unique_node:
return True
else:
has_multi_parents = len(node_parents_map[node]) > 1
has_multi_parents = len(id_parents_map[encoded_node_id_with_query]) > 1
if has_multi_parents:
return True
return False

traverse(root)
duplicated_node = {node for node in node_count_map if is_duplicate_subtree(node)}
return duplicated_node, node_parents_map
duplicated_node = {
encoded_node_id_with_query
for encoded_node_id_with_query in id_count_map
if is_duplicate_subtree(encoded_node_id_with_query)
}
return duplicated_node


def create_cte_query(root: "TreeNode", duplicate_plan_set: Set["TreeNode"]) -> str:
def create_cte_query(root: "TreeNode", duplicated_node_ids: Set[str]) -> str:
from snowflake.snowpark._internal.analyzer.select_statement import Selectable

plan_to_query_map = {}
Expand All @@ -110,32 +115,41 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None:

while stack2:
node = stack2.pop()
if node in plan_to_query_map:
if node.encoded_node_id_with_query in plan_to_query_map:
continue

if not node.children_plan_nodes or not node.placeholder_query:
plan_to_query_map[node] = (
plan_to_query_map[node.encoded_node_id_with_query] = (
node.sql_query
if isinstance(node, Selectable)
else node.queries[-1].sql
)
else:
plan_to_query_map[node] = node.placeholder_query
plan_to_query_map[
node.encoded_node_id_with_query
] = node.placeholder_query
for child in node.children_plan_nodes:
# replace the placeholder (id) with child query
plan_to_query_map[node] = plan_to_query_map[node].replace(
child._id, plan_to_query_map[child]
plan_to_query_map[
node.encoded_node_id_with_query
] = plan_to_query_map[node.encoded_node_id_with_query].replace(
child.encoded_query_id,
plan_to_query_map[child.encoded_node_id_with_query],
)

# duplicate subtrees will be converted CTEs
if node in duplicate_plan_set:
if node.encoded_node_id_with_query in duplicated_node_ids:
# when a subquery is converted to a CTE in the WITH clause,
# it will be replaced by `SELECT * from TEMP_TABLE` in the original query
table_name = random_name_for_temp_object(TempObjectType.CTE)
select_stmt = project_statement([], table_name)
duplicate_plan_to_table_name_map[node] = table_name
duplicate_plan_to_cte_map[node] = plan_to_query_map[node]
plan_to_query_map[node] = select_stmt
duplicate_plan_to_table_name_map[
node.encoded_node_id_with_query
] = table_name
duplicate_plan_to_cte_map[
node.encoded_node_id_with_query
] = plan_to_query_map[node.encoded_node_id_with_query]
plan_to_query_map[node.encoded_node_id_with_query] = select_stmt

build_plan_to_query_map_in_post_order(root)

Expand All @@ -144,16 +158,53 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None:
list(duplicate_plan_to_cte_map.values()),
list(duplicate_plan_to_table_name_map.values()),
)
final_query = with_stmt + SPACE + plan_to_query_map[root]
final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_node_id_with_query]
return final_query


def encode_id(
query: str, query_params: Optional[Sequence[Any]] = None
) -> Optional[str]:
def encoded_query_id(node) -> Optional[str]:
    """
    Derive a short, content-based id from a node's query and query parameters.

    The query text (followed by "#<params>" when query parameters are present)
    is hashed with sha256.

    Returns:
        The first 10 hex characters of the digest if encoding succeeds.
        Otherwise, None (a warning is logged).
    """
    from snowflake.snowpark._internal.analyzer.select_statement import SelectSQL
    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan

    if isinstance(node, SnowflakePlan):
        # A plan is identified by its last (main) query.
        main_query = node.queries[-1]
        sql_text, params = main_query.sql, main_query.params
    elif isinstance(node, SelectSQL):
        # For SelectSQL, the original SQL is used to encode its ID,
        # which might be a non-select SQL.
        sql_text, params = node.original_sql, node.query_params
    else:
        sql_text, params = node.sql_query, node.query_params

    payload = sql_text
    if params:
        payload = f"{sql_text}#{params}"

    try:
        digest = hashlib.sha256(payload.encode()).hexdigest()
    except Exception as ex:
        logging.warning(f"Encode SnowflakePlan ID failed: {ex}")
        return None
    return digest[:10]


def encode_node_id_with_query(node: "TreeNode") -> str:
    """
    Encode an id for the given TreeNode.

    If the query and query parameters can be encoded successfully using sha256,
    the result is "<encoded query id>_<node type name>".
    Otherwise, the result is the original node id, i.e. str(id(node)).
    """
    query_hash = encoded_query_id(node)
    if query_hash is None:
        # Encoding failed: fall back to the object's identity.
        return str(id(node))
    return f"{query_hash}_{type(node).__name__}"
57 changes: 26 additions & 31 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
)

import snowflake.snowpark._internal.utils
from snowflake.snowpark._internal.analyzer.cte_utils import encode_id
from snowflake.snowpark._internal.analyzer.cte_utils import (
encode_node_id_with_query,
encoded_query_id,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
PlanState,
Expand Down Expand Up @@ -239,17 +242,6 @@ def __init__(
self._api_calls = api_calls.copy() if api_calls is not None else None
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None

def __eq__(self, other: "Selectable") -> bool:
if not isinstance(other, Selectable):
return False
if self._id is not None and other._id is not None:
return type(self) is type(other) and self._id == other._id
else:
return super().__eq__(other)

def __hash__(self) -> int:
return hash(self._id) if self._id else super().__hash__()

@property
@abstractmethod
def sql_query(self) -> str:
Expand All @@ -263,9 +255,20 @@ def placeholder_query(self) -> Optional[str]:
pass

@cached_property
def _id(self) -> Optional[str]:
"""Returns the id of this Selectable logical plan."""
return encode_id(self.sql_query, self.query_params)
def encoded_node_id_with_query(self) -> str:
"""
Returns an encoded node id of this Selectable logical plan.
Note that the encoding algorithm uses queries as content, and returns the same id for
two selectable node with same queries. This is currently used by repeated subquery
elimination to detect two nodes with same query, please use it with careful.
"""
return encode_node_id_with_query(self)

@cached_property
def encoded_query_id(self) -> Optional[str]:
"""Returns an encoded id of the queries for this Selectable logical plan."""
return encoded_query_id(self)

@property
@abstractmethod
Expand Down Expand Up @@ -506,14 +509,6 @@ def sql_query(self) -> str:
def placeholder_query(self) -> Optional[str]:
return None

@property
def _id(self) -> Optional[str]:
"""
Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID,
which might be a non-select SQL.
"""
return encode_id(self.original_sql, self.query_params)

@property
def query_params(self) -> Optional[Sequence[Any]]:
return self._query_param
Expand Down Expand Up @@ -591,9 +586,9 @@ def sql_query(self) -> str:
def placeholder_query(self) -> Optional[str]:
return self._snowflake_plan.placeholder_query

@property
def _id(self) -> Optional[str]:
return self._snowflake_plan._id
@cached_property
def encoded_query_id(self) -> Optional[str]:
return self._snowflake_plan.encoded_query_id

@property
def schema_query(self) -> Optional[str]:
Expand Down Expand Up @@ -793,9 +788,9 @@ def sql_query(self) -> str:
if (
self.analyzer.session._cte_optimization_enabled
and (not self.analyzer.session._query_compilation_stage_enabled)
and self.from_._id
and self.from_.encoded_query_id
):
placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}"
placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}"
self._sql_query = self.placeholder_query.replace(placeholder, from_clause)
else:
where_clause = (
Expand Down Expand Up @@ -825,7 +820,7 @@ def sql_query(self) -> str:
def placeholder_query(self) -> str:
if self._placeholder_query:
return self._placeholder_query
from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}"
from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}"
if not self.has_clause and not self.projection:
self._placeholder_query = from_clause
return self._placeholder_query
Expand Down Expand Up @@ -1429,9 +1424,9 @@ def sql_query(self) -> str:
@property
def placeholder_query(self) -> Optional[str]:
if not self._placeholder_query:
sql = f"({self.set_operands[0].selectable._id})"
sql = f"({self.set_operands[0].selectable.encoded_query_id})"
for i in range(1, len(self.set_operands)):
sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable._id})"
sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_query_id})"
self._placeholder_query = sql
return self._placeholder_query

Expand Down
39 changes: 17 additions & 22 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
)
from snowflake.snowpark._internal.analyzer.cte_utils import (
create_cte_query,
encode_id,
encode_node_id_with_query,
encoded_query_id,
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.analyzer.expression import Attribute
Expand Down Expand Up @@ -257,9 +258,13 @@ def __init__(
# It is used for optimization, by replacing a subquery with a CTE
self.placeholder_query = placeholder_query
# encode an id for CTE optimization. This is generated based on the main
# query and the associated query parameters. We use this id for equality comparison
# to determine if two plans are the same.
self._id = encode_id(queries[-1].sql, queries[-1].params)
# query, query parameters and the node type. We use this id for equality
# comparison to determine if two plans are the same.
self.encoded_node_id_with_query = encode_node_id_with_query(self)
# encode id for the main query and query parameters, this is currently only used
# by the create_cte_query process.
# TODO (SNOW-1541096) remove this field along with removing the old cte implementation
self.encoded_query_id = encoded_query_id(self)
self.referenced_ctes: Set[WithQueryBlock] = (
referenced_ctes.copy() if referenced_ctes else set()
)
Expand All @@ -272,17 +277,6 @@ def __init__(
if session.reduce_describe_query_enabled and self.source_plan is not None:
self._attributes = infer_metadata(self.source_plan)

def __eq__(self, other: "SnowflakePlan") -> bool:
if not isinstance(other, SnowflakePlan):
return False
if self._id is not None and other._id is not None:
return isinstance(other, SnowflakePlan) and self._id == other._id
else:
return super().__eq__(other)

def __hash__(self) -> int:
return hash(self._id) if self._id else super().__hash__()

@property
def uuid(self) -> str:
return self._uuid
Expand Down Expand Up @@ -354,7 +348,7 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan":
return self

# if there is no duplicate node, no optimization will be performed
duplicate_plan_set, _ = find_duplicate_subtrees(self)
duplicate_plan_set = find_duplicate_subtrees(self)
if not duplicate_plan_set:
return self

Expand Down Expand Up @@ -425,7 +419,7 @@ def output_dict(self) -> Dict[str, Any]:

@cached_property
def num_duplicate_nodes(self) -> int:
duplicated_nodes, _ = find_duplicate_subtrees(self)
duplicated_nodes = find_duplicate_subtrees(self)
return len(duplicated_nodes)

@cached_property
Expand Down Expand Up @@ -597,8 +591,9 @@ def build(
new_schema_query = schema_query or sql_generator(child.schema_query)

placeholder_query = (
sql_generator(select_child._id)
if self.session._cte_optimization_enabled and select_child._id is not None
sql_generator(select_child.encoded_query_id)
if self.session._cte_optimization_enabled
and select_child.encoded_query_id is not None
else None
)

Expand Down Expand Up @@ -636,10 +631,10 @@ def build_binary(
schema_query = sql_generator(left_schema_query, right_schema_query)

placeholder_query = (
sql_generator(select_left._id, select_right._id)
sql_generator(select_left.encoded_query_id, select_right.encoded_query_id)
if self.session._cte_optimization_enabled
and select_left._id is not None
and select_right._id is not None
and select_left.encoded_query_id is not None
and select_right.encoded_query_id is not None
else None
)

Expand Down
Loading

0 comments on commit 952599b

Please sign in to comment.