Skip to content

Commit

Permalink
[SNOW-1632896] Detect valid condition for flatten nested select proje…
Browse files Browse the repository at this point in the history
…ction complexity (#2326)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.
   
Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

SNOW-1632896

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.
Snowflake's optimizer can merge nested select statements when possible. To
get a more accurate estimate of the query plan complexity, we want to
detect such cases and adjust the complexity calculation.
For example, SELECT COL1 + 2 AS COL1, COL2 FROM (SELECT COL1 + 1 AS
COL1, COL2 FROM TEST_TABLE) can be merged as SELECT (COL1 + 1) + 2 AS
COL1, COL2 FROM TEST_TABLE.

A flag _merge_projection_complexity_with_subquery is added to
SelectStatement to indicate that it is valid to merge the select
projection complexity with that of its child (from_).
  • Loading branch information
sfc-gh-yzou authored Sep 24, 2024
1 parent 272e4e1 commit 965ce9d
Show file tree
Hide file tree
Showing 3 changed files with 543 additions and 0 deletions.
103 changes: 103 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
Alias,
UnresolvedAlias,
)
from snowflake.snowpark._internal.select_projection_complexity_utils import (
has_invalid_projection_merge_functions,
)
from snowflake.snowpark._internal.utils import is_sql_select_statement

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
Expand Down Expand Up @@ -640,6 +643,19 @@ def __init__(
self.from_.api_calls.copy() if self.from_.api_calls is not None else None
) # will be replaced by new api calls if any operation.
self._placeholder_query = None
# indicate whether we should try to merge the projection complexity of the current
# SelectStatement with the projection complexity of from_ during the calculation of
# node complexity. For example:
# SELECT COL1 + 2 as COL1, COL2 FROM (SELECT COL1 + 3 AS COL1, COL2 FROM TABLE_TEST)
# can be merged as follows with snowflake:
# SELECT (COL1 + 3) + 2 AS COL1, COL2 FROM TABLE_TEST
# Therefore, the plan complexity during compilation will change, and the resulting plan
# complexity can be calculated by merging the projection complexity of the two SELECTs.
#
# In Snowpark, we do not generate the query after merging two selects. Flag
# _merge_projection_complexity_with_subquery is used to indicate that it is valid to merge
# the projection complexity of current SelectStatement with subquery.
self._merge_projection_complexity_with_subquery = False

def __copy__(self):
new = SelectStatement(
Expand All @@ -663,6 +679,9 @@ def __copy__(self):
new.df_aliased_col_name_to_real_col_name = (
self.df_aliased_col_name_to_real_col_name
)
new._merge_projection_complexity_with_subquery = (
self._merge_projection_complexity_with_subquery
)

return new

Expand All @@ -682,6 +701,9 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
_deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)
copied._projection_in_str = self._projection_in_str
copied._query_params = deepcopy(self._query_params)
copied._merge_projection_complexity_with_subquery = (
self._merge_projection_complexity_with_subquery
)
return copied

@property
Expand Down Expand Up @@ -918,6 +940,8 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
self.expr_to_alias
) # use copy because we don't want two plans to share the same list. If one mutates, the other ones won't be impacted.
new.flatten_disabled = self.flatten_disabled
# no need to flatten the projection complexity since the select projection is already flattened.
new._merge_projection_complexity_with_subquery = False
return new
disable_next_level_flatten = False
new_column_states = derive_column_states_from_subquery(cols, self)
Expand Down Expand Up @@ -992,10 +1016,21 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
new.from_ = self.from_.to_subqueryable()
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
# there is no need to flatten the projection complexity since the child
# select projection is already flattened with the current select.
new._merge_projection_complexity_with_subquery = False
else:
new = SelectStatement(
projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer
)
new._merge_projection_complexity_with_subquery = (
can_select_projection_complexity_be_merged(
cols,
new_column_states,
self,
)
)

new.flatten_disabled = disable_next_level_flatten
assert new.projection is not None
new._column_states = derive_column_states_from_subquery(
Expand All @@ -1022,6 +1057,7 @@ def filter(self, col: Expression) -> "SelectStatement":
new.post_actions = new.from_.post_actions
new.column_states = self.column_states
new.where = And(self.where, col) if self.where is not None else col
new._merge_projection_complexity_with_subquery = False
else:
new = SelectStatement(
from_=self.to_subqueryable(), where=col, analyzer=self.analyzer
Expand All @@ -1044,6 +1080,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
new.post_actions = new.from_.post_actions
new.order_by = cols + (self.order_by or [])
new.column_states = self.column_states
new._merge_projection_complexity_with_subquery = False
else:
new = SelectStatement(
from_=self.to_subqueryable(),
Expand Down Expand Up @@ -1130,6 +1167,7 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement":
new.column_states = self.column_states
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
new._merge_projection_complexity_with_subquery = False
return new


Expand Down Expand Up @@ -1433,6 +1471,71 @@ def can_clause_dependent_columns_flatten(
return True


def can_select_projection_complexity_be_merged(
    cols: List[Expression],
    column_states: Optional[ColumnStateDict],
    subquery: Selectable,
) -> bool:
    """
    Check whether the projection complexity of subquery can be merged with the
    current projection columns.

    Args:
        cols: the projection column expressions of the current select
        column_states: the column states extracted out of the current projection
            columns on top of subquery.
        subquery: the subquery that the current select is performed on top of

    Returns:
        True only when every check below passes; False as soon as one fails.
        Always False when the session's ``_large_query_breakdown_enabled`` flag
        is off.
    """
    if not subquery.analyzer.session._large_query_breakdown_enabled:
        return False

    # only merge of nested select statement is supported, and subquery must be
    # a SelectStatement
    if column_states is None or (not isinstance(subquery, SelectStatement)):
        return False  # pragma: no cover

    if len(cols) != len(column_states.projection):
        # Failed to extract the attributes of some columns
        return False  # pragma: no cover

    if subquery._column_states is None:
        return False  # pragma: no cover

    # It is not valid to merge the projection complexity if there exists:
    #   1) a column without state extracted
    #   2) a column that depends on columns from the same level
    #   3) a column that depends on $. Theoretically, this could be valid, but
    #      extra analysis is required to check the validity.
    #   4) a column whose projection expression depends on a column that is
    #      not an active column of the subquery
    for proj in column_states.projection:
        column_state = column_states.get(proj.name)
        if column_state is None:
            return False  # pragma: no cover
        if column_state.depend_on_same_level:
            return False
        if column_state.dependent_columns == COLUMN_DEPENDENCY_DOLLAR:
            return False
        if column_state.dependent_columns != COLUMN_DEPENDENCY_ALL:
            for dependent_col in column_state.dependent_columns:
                if dependent_col not in subquery._column_states.active_columns:
                    return False  # pragma: no cover

    # The subquery must not carry a filter, order by, limit or offset.
    # limit_ is compared against None explicitly: LIMIT 0 is a real limit
    # clause, but 0 is falsy and would otherwise slip through this guard.
    if (
        subquery.where
        or subquery.order_by
        or subquery.limit_ is not None
        or subquery.offset
    ):
        return False

    # check if the projection expressions contain invalid functions
    if has_invalid_projection_merge_functions(cols):
        return False

    # check if the subquery projection expressions contain invalid functions
    if has_invalid_projection_merge_functions(subquery.projection):
        return False

    return True


def initiate_column_states(
column_attrs: List[Attribute],
analyzer: "Analyzer",
Expand Down
Loading

0 comments on commit 965ce9d

Please sign in to comment.