
Commit

[SNOW-1632898] Adjust SelectStatement projection complexity calculation (#2340)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

   Note that if a corresponding GitHub issue exists, you should still
   include the Snowflake Jira issue number. For example, for GitHub
   issue #1400, you should add "SNOW-1335071" here.
   --->

SNOW-1632898

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.

3. Please describe how your code solves the related issue.

1) Adjust the projection complexity calculation and the cumulative
complexity for SelectStatement when
_merge_projection_complexity_with_subquery is set to True.

2) Update the complexity calculation for SnowflakePlan to take the
complexity directly from the source plan, and update
reset_cumulative_node_complexity to also reset the source plan's
cumulative complexity.
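
In rough terms, the merge replaces each bare column reference in an outer
projection with the full complexity of the subquery expression that defines
it, and the cumulative complexity then subtracts the subquery's projection
complexity so it is not counted twice. A minimal sketch with toy dicts
(string keys stand in for PlanNodeCategory counters; none of the names
below are the actual API):

```python
# Subquery projection: col1 = a + b
subquery_col1 = {"column": 2, "low_impact": 1}
# Outer projection: UPPER(col1)
outer_proj = {"function": 1, "column": 1}

# Merge: replace the bare reference to col1 with the full complexity
# of the subquery expression it points to.
merged = dict(outer_proj)
merged["column"] -= 1
for cat, cnt in subquery_col1.items():
    merged[cat] = merged.get(cat, 0) + cnt
print(merged)  # {'function': 1, 'column': 2, 'low_impact': 1}
```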
sfc-gh-yzou authored Sep 28, 2024
1 parent cac2980 commit e8765d2
Showing 9 changed files with 719 additions and 118 deletions.
src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
@@ -51,6 +51,24 @@ def sum_node_complexities(
return dict(counter_sum)


def subtract_complexities(
complexities1: Dict[PlanNodeCategory, int],
complexities2: Dict[PlanNodeCategory, int],
) -> Dict[PlanNodeCategory, int]:
"""
This is a helper function for complexities1 - complexities2.
"""

result_complexities = complexities1.copy()
for key, value in complexities2.items():
if key in result_complexities:
result_complexities[key] -= value
else:
result_complexities[key] = -value

return result_complexities
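

# Example of the per-category difference (illustrative values; COLUMN is
# used elsewhere in this file, FUNCTION and LOW_IMPACT are assumed members):
#   subtract_complexities(
#       {PlanNodeCategory.COLUMN: 5, PlanNodeCategory.FUNCTION: 2},
#       {PlanNodeCategory.COLUMN: 3, PlanNodeCategory.LOW_IMPACT: 1},
#   )
#   returns {COLUMN: 2, FUNCTION: 2, LOW_IMPACT: -1}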


def get_complexity_score(
cumulative_node_complexity: Dict[PlanNodeCategory, int]
) -> int:
120 changes: 109 additions & 11 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -25,6 +25,7 @@
from snowflake.snowpark._internal.analyzer.cte_utils import encode_id
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
subtract_complexities,
sum_node_complexities,
)
from snowflake.snowpark._internal.analyzer.table_function import (
@@ -656,6 +657,11 @@ def __init__(
# _merge_projection_complexity_with_subquery is used to indicate that it is valid to merge
# the projection complexity of current SelectStatement with subquery.
self._merge_projection_complexity_with_subquery = False
# cached list of projection complexities; each projection complexity is
# merged with the subquery projection complexity when
# _merge_projection_complexity_with_subquery is True.
self._projection_complexities: Optional[
List[Dict[PlanNodeCategory, int]]
] = None

def __copy__(self):
new = SelectStatement(
@@ -704,6 +710,11 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement":  # noqa: B006
copied._merge_projection_complexity_with_subquery = (
self._merge_projection_complexity_with_subquery
)
copied._projection_complexities = (
deepcopy(self._projection_complexities)
if self._projection_complexities
else None
)
return copied

@property
@@ -844,17 +855,7 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
complexity = {}
# projection component
complexity = (
sum_node_complexities(
complexity,
*(
getattr(
expr,
"cumulative_node_complexity",
{PlanNodeCategory.COLUMN: 1},
) # type: ignore
for expr in self.projection
),
)
sum_node_complexities(*self.projection_complexities)
if self.projection
else complexity
)
@@ -894,6 +895,27 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
)
return complexity

@property
def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
if self._cumulative_node_complexity is None:
self._cumulative_node_complexity = super().cumulative_node_complexity
if self._merge_projection_complexity_with_subquery:
# if _merge_projection_complexity_with_subquery is true, the subquery
# projection complexity has already been merged with the current projection
# complexity, and we need to adjust the cumulative_node_complexity by
# subtracting the from_ projection complexity.
assert isinstance(self.from_, SelectStatement)
self._cumulative_node_complexity = subtract_complexities(
self._cumulative_node_complexity,
sum_node_complexities(*self.from_.projection_complexities),
)

return self._cumulative_node_complexity

@cumulative_node_complexity.setter
def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
self._cumulative_node_complexity = value
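
# Worked example of the adjustment above (toy numbers, categories assumed):
# super().cumulative_node_complexity sums this node's merged projections
# (which already embed the subquery's projection complexity) with the
# subquery's own cumulative complexity, so the subquery projections are
# counted twice. If the merged projections contribute {COLUMN: 4, FUNCTION: 2}
# and the from_ projections sum to {COLUMN: 2, FUNCTION: 1}, subtracting
# {COLUMN: 2, FUNCTION: 1} removes exactly the double-counted portion.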

@property
def referenced_ctes(self) -> Set[str]:
return self.from_.referenced_ctes
@@ -913,6 +935,82 @@ def to_subqueryable(self) -> "Selectable":
return new
return self

def get_projection_name_complexity_map(
self,
) -> Optional[Dict[str, Dict[PlanNodeCategory, int]]]:
"""
Get a map between the projection column name and its complexity. If name or
projection complexity is missing for any column, None is returned.
"""
if (
(not self._column_states)
or (not self.projection)
or (not self._column_states.projection)
):
return None

if len(self.projection) != len(self._column_states.projection):
return None

projection_complexities = self.projection_complexities
if len(self._column_states.projection) != len(projection_complexities):
return None
else:
return {
attribute.name: complexity
for complexity, attribute in zip(
projection_complexities, self._column_states.projection
)
}
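
# Illustrative shape of the returned map (names and numbers are made up):
#   {"COL1": {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.FUNCTION: 1},
#    "COL2": {PlanNodeCategory.COLUMN: 1}}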

@property
def projection_complexities(self) -> List[Dict[PlanNodeCategory, int]]:
"""
Return the cumulative complexity for each projection expression. The
complexity is merged with the subquery projection complexity if
_merge_projection_complexity_with_subquery is True.
"""
if self.projection is None:
return []

if self._projection_complexities is None:
if self._merge_projection_complexity_with_subquery:
assert isinstance(
self.from_, SelectStatement
), "merge with non-SelectStatement is not valid"
subquery_projection_name_complexity_map = (
self.from_.get_projection_name_complexity_map()
)
assert (
subquery_projection_name_complexity_map is not None
), "failed to extract dependent column map from subquery"
self._projection_complexities = []
for proj in self.projection:
# For a projection expression that depends on columns [col1, col2, col1],
# and whose original cumulative_node_complexity is proj_complexity, the
# new complexity can be calculated as
# proj_complexity - {PlanNodeCategory.COLUMN: 1} + col1_complexity
# - {PlanNodeCategory.COLUMN: 1} + col2_complexity
# - {PlanNodeCategory.COLUMN: 1} + col1_complexity
dependent_columns = proj.dependent_column_names_with_duplication()
# copy so the expression's cached complexity is not mutated in place
projection_complexity = proj.cumulative_node_complexity.copy()
for dependent_column in dependent_columns:
dependent_column_complexity = (
subquery_projection_name_complexity_map[dependent_column]
)
projection_complexity[PlanNodeCategory.COLUMN] -= 1
projection_complexity = sum_node_complexities(
projection_complexity, dependent_column_complexity
)

self._projection_complexities.append(projection_complexity)
else:
self._projection_complexities = [
expr.cumulative_node_complexity for expr in self.projection
]

return self._projection_complexities
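
# Worked example (toy numbers): a projection UPPER(col1) || col2 with
# cumulative complexity {COLUMN: 2, FUNCTION: 2} and dependent columns
# [col1, col2], where the subquery defines
#   col1 -> {COLUMN: 2, LOW_IMPACT: 1}    col2 -> {COLUMN: 1}
# merges to
#   {COLUMN: 2, FUNCTION: 2}
#     - {COLUMN: 1} + {COLUMN: 2, LOW_IMPACT: 1}   # substitute col1
#     - {COLUMN: 1} + {COLUMN: 1}                  # substitute col2
#   = {COLUMN: 3, FUNCTION: 2, LOW_IMPACT: 1}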

def select(self, cols: List[Expression]) -> "SelectStatement":
"""Build a new query. This SelectStatement will be the subquery of the new query.
Possibly flatten the new query and the subquery (self) to form a new flattened query.
18 changes: 13 additions & 5 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -25,7 +25,6 @@

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
sum_node_complexities,
)
from snowflake.snowpark._internal.analyzer.table_function import (
GeneratorTableFunction,
@@ -441,16 +440,25 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
@property
def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
if self._cumulative_node_complexity is None:
self._cumulative_node_complexity = sum_node_complexities(
self.individual_node_complexity,
*(node.cumulative_node_complexity for node in self.children_plan_nodes),
)
# if source plan is available, the source plan complexity
# is the snowflake plan complexity.
if self.source_plan:
self._cumulative_node_complexity = (
self.source_plan.cumulative_node_complexity
)
else:
self._cumulative_node_complexity = {}
return self._cumulative_node_complexity

@cumulative_node_complexity.setter
def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
self._cumulative_node_complexity = value

def reset_cumulative_node_complexity(self) -> None:
self._cumulative_node_complexity = None
if self.source_plan:
self.source_plan.reset_cumulative_node_complexity()
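
# Usage sketch (hypothetical): after a compiler rewrite changes the source
# plan, invalidate the cached totals so the next read recomputes them:
#   plan.reset_cumulative_node_complexity()
#   score = get_complexity_score(plan.cumulative_node_complexity)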

def __copy__(self) -> "SnowflakePlan":
if self.session._cte_optimization_enabled:
return SnowflakePlan(
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/_internal/compiler/utils.py
@@ -136,6 +136,9 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selectable:

elif isinstance(parent, SelectStatement):
parent.from_ = to_selectable(new_child, query_generator)
# once the subquery is updated, set _merge_projection_complexity_with_subquery to False to
# disable the projection complexity merge
parent._merge_projection_complexity_with_subquery = False

elif isinstance(parent, SetStatement):
new_child_as_selectable = to_selectable(new_child, query_generator)
@@ -235,6 +238,8 @@
# the projection expression can be re-analyzed during code generation
node._projection_in_str = None
node.analyzer = query_generator
# reset the _projection_complexities fields to re-calculate the complexities
node._projection_complexities = None

# update the pre_actions and post_actions for the select statement
node.pre_actions = node.from_.pre_actions
111 changes: 65 additions & 46 deletions tests/integ/test_deepcopy.py
@@ -39,7 +39,9 @@
random_name_for_temp_object,
)
from snowflake.snowpark.column import CaseExpr, Column
from snowflake.snowpark.dataframe import DataFrame
from snowflake.snowpark.functions import col, lit, seq1, uniform
from tests.utils import Utils

pytestmark = [
pytest.mark.xfail(
@@ -50,6 +52,57 @@
]


def create_df_with_deep_nested_with_column_dependencies(
session, temp_table_name, nest_level: int
) -> DataFrame:
"""
This creates a sample table with 1
"""
# create a tabel with 11 columns (1 int columns and 10 string columns) for testing
struct_fields = [T.StructField("intCol", T.IntegerType(), True)]
for i in range(1, 11):
struct_fields.append(T.StructField(f"col{i}", T.StringType(), True))
schema = T.StructType(struct_fields)

Utils.create_table(
session, temp_table_name, attribute_to_schema_string(schema), is_temporary=True
)

df = session.table(temp_table_name)

def get_col_ref_expression(iter_num: int, col_func: Callable) -> Column:
ref_cols = [F.lit(str(iter_num))]
for i in range(1, 5):
col_name = f"col{i}"
ref_col = col_func(df[col_name])
ref_cols.append(ref_col)
return F.concat(*ref_cols)

for i in range(1, nest_level):
int_col = df["intCol"]
col1_base = get_col_ref_expression(i, F.initcap)
case_expr: Optional[CaseExpr] = None
# generate the condition expression based on the number of conditions
for j in range(1, 3):
if j == 1:
cond_col = int_col < 100
col_ref_expr = get_col_ref_expression(i, F.upper)
else:
cond_col = int_col < 300
col_ref_expr = get_col_ref_expression(i, F.lower)
case_expr = (
F.when(cond_col, col_ref_expr)
if case_expr is None
else case_expr.when(cond_col, col_ref_expr)
)

col1 = case_expr.otherwise(col1_base)

df = df.with_columns(["col1"], [col1])

return df


def verify_column_state(
copied_state: ColumnStateDict, original_state: ColumnStateDict
) -> None:
@@ -314,49 +367,15 @@ def test_create_or_replace_view(session):


def test_deep_nested_select(session):
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
# create a table with 11 columns (1 int column and 10 string columns) for testing
struct_fields = [T.StructField("intCol", T.IntegerType(), True)]
for i in range(1, 11):
struct_fields.append(T.StructField(f"col{i}", T.StringType(), True))
schema = T.StructType(struct_fields)
session.sql(
f"create temp table {temp_table_name}({attribute_to_schema_string(schema)})"
).collect()
df = session.table(temp_table_name)

def get_col_ref_expression(iter_num: int, col_func: Callable) -> Column:
ref_cols = [F.lit(str(iter_num))]
for i in range(1, 5):
col_name = f"col{i}"
ref_col = col_func(df[col_name])
ref_cols.append(ref_col)
return F.concat(*ref_cols)

for i in range(1, 20):
int_col = df["intCol"]
col1_base = get_col_ref_expression(i, F.initcap)
case_expr: Optional[CaseExpr] = None
# generate the condition expression based on the number of conditions
for j in range(1, 3):
if j == 1:
cond_col = int_col < 100
col_ref_expr = get_col_ref_expression(i, F.upper)
else:
cond_col = int_col < 300
col_ref_expr = get_col_ref_expression(i, F.lower)
case_expr = (
F.when(cond_col, col_ref_expr)
if case_expr is None
else case_expr.when(cond_col, col_ref_expr)
)

col1 = case_expr.otherwise(col1_base)

df = df.with_columns(["col1"], [col1])

# make a copy of the final df plan
copied_plan = copy.deepcopy(df._plan)
# skip checking the plan attributes for this plan: the plan is too complicated
# to compile, and fetching attributes issues a describe call that times out
# during server compilation.
check_copied_plan(copied_plan, df._plan, skip_attribute=True)
temp_table_name = Utils.random_table_name()
try:
df = create_df_with_deep_nested_with_column_dependencies(
session, temp_table_name, 20
)
# make a copy of the final df plan
copied_plan = copy.deepcopy(df._plan)
# skip checking the plan attributes for this plan: the plan is too complicated
# to compile, and fetching attributes issues a describe call that times out
# during server compilation.
check_copied_plan(copied_plan, df._plan, skip_attribute=True)
finally:
Utils.drop_table(session, temp_table_name)