From e8765d2c3ecf8a8b250c62dbd6e80314cfe237a6 Mon Sep 17 00:00:00 2001
From: Yun Zou
Date: Fri, 27 Sep 2024 18:22:31 -0700
Subject: [PATCH] [SNOW-1632898] Adjust SelectStatement projection complexity
 calculation (#2340)

1. Which Jira issue is this PR addressing? Make sure that there is an
   accompanying issue to your PR.

   SNOW-1632898

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
  @snowflakedb/local-testing
- [ ] I am adding new logging messages
- [ ] I am adding a new telemetry message
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing parity
  changes.

3. Please describe how your code solves the related issue.

   1) Adjust the projection complexity calculation and the cumulative
      complexity for SelectStatement when
      _merge_projection_complexity_with_subquery is set to True.
   2) Update the complexity calculation for SnowflakePlan to get the
      complexity directly from the source plan, and update
      reset_cumulative_node_complexity to also reset the source plan's
      cumulative complexity.
---
 .../analyzer/query_plan_analysis_utils.py     |  18 +
 .../_internal/analyzer/select_statement.py    | 120 +++++-
 .../_internal/analyzer/snowflake_plan.py      |  18 +-
 .../snowpark/_internal/compiler/utils.py      |   5 +
 tests/integ/test_deepcopy.py                  | 111 ++---
 tests/integ/test_large_query_breakdown.py     |  26 +-
 .../integ/test_nested_select_plan_analysis.py | 390 ++++++++++++++++--
 tests/integ/test_query_plan_analysis.py       |  14 +-
 tests/unit/test_query_plan_analysis.py        | 135 +++++-
 9 files changed, 719 insertions(+), 118 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
index bb39dfd7be8..5b53f1f8080 100644
--- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
@@ -51,6 +51,24 @@ def sum_node_complexities(
     return dict(counter_sum)
 
 
+def subtract_complexities(
+    complexities1: Dict[PlanNodeCategory, int],
+    complexities2: Dict[PlanNodeCategory, int],
+) -> Dict[PlanNodeCategory, int]:
+    """
+    This is a helper function that computes the element-wise difference
+    complexities1 - complexities2.
+    """
+
+    result_complexities = complexities1.copy()
+    for key, value in complexities2.items():
+        if key in result_complexities:
+            result_complexities[key] -= value
+        else:
+            result_complexities[key] = -value
+
+    return result_complexities
+
+
 def get_complexity_score(
     cumulative_node_complexity: Dict[PlanNodeCategory, int]
 ) -> int:
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 9bfb37152a1..ca1ca485416 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -25,6 +25,7 @@
 from snowflake.snowpark._internal.analyzer.cte_utils import encode_id
 from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
     PlanNodeCategory,
+    subtract_complexities,
     sum_node_complexities,
 )
 from snowflake.snowpark._internal.analyzer.table_function import (
@@ -656,6 +657,11 @@ def __init__(
         # _merge_projection_complexity_with_subquery is used to indicate that it is valid to merge
         # the projection complexity of current SelectStatement with subquery.
         self._merge_projection_complexity_with_subquery = False
+        # cached list of projection complexities, where each projection complexity is adjusted
+        # with the subquery projection if _merge_projection_complexity_with_subquery is True.
+        self._projection_complexities: Optional[
+            List[Dict[PlanNodeCategory, int]]
+        ] = None
 
     def __copy__(self):
         new = SelectStatement(
@@ -704,6 +710,11 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement":  # noqa: B006
         copied._merge_projection_complexity_with_subquery = (
             self._merge_projection_complexity_with_subquery
         )
+        copied._projection_complexities = (
+            deepcopy(self._projection_complexities)
+            if self._projection_complexities is not None
+            else None
+        )
         return copied
 
     @property
@@ -844,17 +855,7 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         complexity = {}
         # projection component
         complexity = (
-            sum_node_complexities(
-                complexity,
-                *(
-                    getattr(
-                        expr,
-                        "cumulative_node_complexity",
-                        {PlanNodeCategory.COLUMN: 1},
-                    )  # type: ignore
-                    for expr in self.projection
-                ),
-            )
+            sum_node_complexities(*self.projection_complexities)
             if self.projection
             else complexity
         )
@@ -894,6 +895,27 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         )
         return complexity
 
+    @property
+    def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+        if self._cumulative_node_complexity is None:
+            self._cumulative_node_complexity = super().cumulative_node_complexity
+            if self._merge_projection_complexity_with_subquery:
+                # if _merge_projection_complexity_with_subquery is true, the subquery
+                # projection complexity has already been merged with the current projection
+                # complexity, and we need to adjust the cumulative_node_complexity by
+                # subtracting the from_ projection complexity.
+                assert isinstance(self.from_, SelectStatement)
+                self._cumulative_node_complexity = subtract_complexities(
+                    self._cumulative_node_complexity,
+                    sum_node_complexities(*self.from_.projection_complexities),
+                )
+
+        return self._cumulative_node_complexity
+
+    @cumulative_node_complexity.setter
+    def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
+        self._cumulative_node_complexity = value
+
     @property
     def referenced_ctes(self) -> Set[str]:
         return self.from_.referenced_ctes
@@ -913,6 +935,82 @@ def to_subqueryable(self) -> "Selectable":
             return new
         return self
 
+    def get_projection_name_complexity_map(
+        self,
+    ) -> Optional[Dict[str, Dict[PlanNodeCategory, int]]]:
+        """
+        Get a map between the projection column name and its complexity. If the name or
+        the projection complexity is missing for any column, None is returned.
+        """
+        if (
+            (not self._column_states)
+            or (not self.projection)
+            or (not self._column_states.projection)
+        ):
+            return None
+
+        if len(self.projection) != len(self._column_states.projection):
+            return None
+
+        projection_complexities = self.projection_complexities
+        if len(self._column_states.projection) != len(projection_complexities):
+            return None
+        else:
+            return {
+                attribute.name: complexity
+                for complexity, attribute in zip(
+                    projection_complexities, self._column_states.projection
+                )
+            }
+
+    @property
+    def projection_complexities(self) -> List[Dict[PlanNodeCategory, int]]:
+        """
+        Return the cumulative complexity for each projection expression. The
+        complexity is merged with the subquery projection complexity if
+        _merge_projection_complexity_with_subquery is True.
+ """ + if self.projection is None: + return [] + + if self._projection_complexities is None: + if self._merge_projection_complexity_with_subquery: + assert isinstance( + self.from_, SelectStatement + ), "merge with none SelectStatement is not valid" + subquery_projection_name_complexity_map = ( + self.from_.get_projection_name_complexity_map() + ) + assert ( + subquery_projection_name_complexity_map is not None + ), "failed to extract dependent column map from subquery" + self._projection_complexities = [] + for proj in self.projection: + # For a projection expression that dependents on columns [col1, col2, col1], + # and whose original cumulative_node_complexity is proj_complexity, the + # new complexity can be calculated as + # proj_complexity - {PlanNodeCategory.COLUMN: 1} + col1_complexity + # - {PlanNodeCategory.COLUMN: 1} + col2_complexity + # - {PlanNodeCategory.COLUMN: 1} + col1_complexity + dependent_columns = proj.dependent_column_names_with_duplication() + projection_complexity = proj.cumulative_node_complexity + for dependent_column in dependent_columns: + dependent_column_complexity = ( + subquery_projection_name_complexity_map[dependent_column] + ) + projection_complexity[PlanNodeCategory.COLUMN] -= 1 + projection_complexity = sum_node_complexities( + projection_complexity, dependent_column_complexity + ) + + self._projection_complexities.append(projection_complexity) + else: + self._projection_complexities = [ + expr.cumulative_node_complexity for expr in self.projection + ] + + return self._projection_complexities + def select(self, cols: List[Expression]) -> "SelectStatement": """Build a new query. This SelectStatement will be the subquery of the new query. Possibly flatten the new query and the subquery (self) to form a new flattened query. diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 9750deba3f9..9bd8bbd4ae1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -25,7 +25,6 @@ from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, @@ -441,16 +440,25 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: @property def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self._cumulative_node_complexity is None: - self._cumulative_node_complexity = sum_node_complexities( - self.individual_node_complexity, - *(node.cumulative_node_complexity for node in self.children_plan_nodes), - ) + # if source plan is available, the source plan complexity + # is the snowflake plan complexity. 
+            if self.source_plan:
+                self._cumulative_node_complexity = (
+                    self.source_plan.cumulative_node_complexity
+                )
+            else:
+                self._cumulative_node_complexity = {}
         return self._cumulative_node_complexity
 
     @cumulative_node_complexity.setter
     def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
         self._cumulative_node_complexity = value
 
+    def reset_cumulative_node_complexity(self) -> None:
+        self._cumulative_node_complexity = None
+        if self.source_plan:
+            self.source_plan.reset_cumulative_node_complexity()
+
     def __copy__(self) -> "SnowflakePlan":
         if self.session._cte_optimization_enabled:
             return SnowflakePlan(
diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py
index 4ec6980f22a..b9eaf40ce93 100644
--- a/src/snowflake/snowpark/_internal/compiler/utils.py
+++ b/src/snowflake/snowpark/_internal/compiler/utils.py
@@ -136,6 +136,9 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta
 
     elif isinstance(parent, SelectStatement):
         parent.from_ = to_selectable(new_child, query_generator)
+        # once the subquery is updated, set _merge_projection_complexity_with_subquery to False to
+        # disable the projection complexity merge
+        parent._merge_projection_complexity_with_subquery = False
 
     elif isinstance(parent, SetStatement):
         new_child_as_selectable = to_selectable(new_child, query_generator)
@@ -235,6 +238,8 @@ def update_resolvable_node(
         # the projection expression can be re-analyzed during code generation
         node._projection_in_str = None
         node.analyzer = query_generator
+        # reset the _projection_complexities field to re-calculate the complexities
+        node._projection_complexities = None
 
         # update the pre_actions and post_actions for the select statement
         node.pre_actions = node.from_.pre_actions
diff --git a/tests/integ/test_deepcopy.py b/tests/integ/test_deepcopy.py
index 93d0a49fbf2..5f60b4d5a06 100644
--- a/tests/integ/test_deepcopy.py
+++ b/tests/integ/test_deepcopy.py
@@ -39,7 +39,9 @@
     random_name_for_temp_object,
 )
 from snowflake.snowpark.column import CaseExpr, Column
+from snowflake.snowpark.dataframe import DataFrame
 from snowflake.snowpark.functions import col, lit, seq1, uniform
+from tests.utils import Utils
 
 pytestmark = [
     pytest.mark.xfail(
@@ -50,6 +52,57 @@
 ]
 
 
+def create_df_with_deep_nested_with_column_dependencies(
+    session, temp_table_name, nest_level: int
+) -> DataFrame:
+    """
+    This creates a sample table with 11 columns, and builds a DataFrame on top
+    of it with nest_level layers of nested with_columns calls whose column
+    expressions reference the table columns.
+    """
+    # create a table with 11 columns (1 int column and 10 string columns) for testing
+    struct_fields = [T.StructField("intCol", T.IntegerType(), True)]
+    for i in range(1, 11):
+        struct_fields.append(T.StructField(f"col{i}", T.StringType(), True))
+    schema = T.StructType(struct_fields)
+
+    Utils.create_table(
+        session, temp_table_name, attribute_to_schema_string(schema), is_temporary=True
+    )
+
+    df = session.table(temp_table_name)
+
+    def get_col_ref_expression(iter_num: int, col_func: Callable) -> Column:
+        ref_cols = [F.lit(str(iter_num))]
+        for i in range(1, 5):
+            col_name = f"col{i}"
+            ref_col = col_func(df[col_name])
+            ref_cols.append(ref_col)
+        return F.concat(*ref_cols)
+
+    for i in range(1, nest_level):
+        int_col = df["intCol"]
+        col1_base = get_col_ref_expression(i, F.initcap)
+        case_expr: Optional[CaseExpr] = None
+        # generate the condition expression based on the number of conditions
+        for j in range(1, 3):
+            if j == 1:
+                cond_col = int_col < 100
+                col_ref_expr = get_col_ref_expression(i, F.upper)
+            else:
+                cond_col = int_col < 300
+                col_ref_expr = get_col_ref_expression(i, F.lower)
+            case_expr = (
+                F.when(cond_col, col_ref_expr)
+                if case_expr is None
+                else case_expr.when(cond_col, col_ref_expr)
+            )
+
+        col1 = case_expr.otherwise(col1_base)
+
+        df = df.with_columns(["col1"], [col1])
+
+    return df
+
+
 def verify_column_state(
     copied_state: ColumnStateDict, original_state: ColumnStateDict
 ) -> None:
@@ -314,49 +367,15 @@ def test_create_or_replace_view(session):
 
 
 def test_deep_nested_select(session):
-    temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
-    # create a tabel with 11 columns (1 int columns and 10 string columns) for testing
-    struct_fields = [T.StructField("intCol", T.IntegerType(), True)]
-    for i in range(1, 11):
-        struct_fields.append(T.StructField(f"col{i}", T.StringType(), True))
-    schema = T.StructType(struct_fields)
-    session.sql(
-        f"create temp table {temp_table_name}({attribute_to_schema_string(schema)})"
-    ).collect()
-    df = session.table(temp_table_name)
-
-    def get_col_ref_expression(iter_num: int, col_func: Callable) -> Column:
-        ref_cols = [F.lit(str(iter_num))]
-        for i in range(1, 5):
-            col_name = f"col{i}"
-            ref_col = col_func(df[col_name])
-            ref_cols.append(ref_col)
-        return F.concat(*ref_cols)
-
-    for i in range(1, 20):
-        int_col = df["intCol"]
-        col1_base = get_col_ref_expression(i, F.initcap)
-        case_expr: Optional[CaseExpr] = None
-        # generate the condition expression based on the number of conditions
-        for j in range(1, 3):
-            if j == 1:
-                cond_col = int_col < 100
-                col_ref_expr = get_col_ref_expression(i, F.upper)
-            else:
-                cond_col = int_col < 300
-                col_ref_expr = get_col_ref_expression(i, F.lower)
-            case_expr = (
-                F.when(cond_col, col_ref_expr)
-                if case_expr is None
-                else case_expr.when(cond_col, col_ref_expr)
-            )
-
-        col1 = case_expr.otherwise(col1_base)
-
-        df = df.with_columns(["col1"], [col1])
-
-    # make a copy of the final df plan
-    copied_plan = copy.deepcopy(df._plan)
-    # skip the checking of plan attribute for this plan, because the plan is complicated for
-    # compilation, and attribute issues describing call which will timeout during server compilation.
-    check_copied_plan(copied_plan, df._plan, skip_attribute=True)
+    temp_table_name = Utils.random_table_name()
+    try:
+        df = create_df_with_deep_nested_with_column_dependencies(
+            session, temp_table_name, 20
+        )
+        # make a copy of the final df plan
+        copied_plan = copy.deepcopy(df._plan)
+        # skip checking the plan attributes for this plan: the plan is too complicated to
+        # compile, and fetching attributes issues a describe call that can time out on the server.
+        check_copied_plan(copied_plan, df._plan, skip_attribute=True)
+    finally:
+        Utils.drop_table(session, temp_table_name)
diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py
index e52eadd9421..31171792df3 100644
--- a/tests/integ/test_large_query_breakdown.py
+++ b/tests/integ/test_large_query_breakdown.py
@@ -35,7 +35,7 @@ def large_query_df(session):
     df1 = base_df.with_column("A", col("A") + lit(1))
     df2 = base_df.with_column("B", col("B") + lit(1))
 
-    for i in range(100):
+    for i in range(110):
         df1 = df1.with_column("A", col("A") + lit(i))
         df2 = df2.with_column("B", col("B") + lit(i))
     df1 = df1.group_by(col("A")).agg(sum_distinct(col("B")).alias("B"))
@@ -93,7 +93,7 @@ def test_no_valid_nodes_found(session, large_query_df, caplog):
     df1 = base_df.with_column("A", col("A") + lit(1))
     df2 = base_df.with_column("B", col("B") + lit(1))
 
-    for i in range(102):
+    for i in range(160):
         df1 = df1.with_column("A", col("A") + lit(i))
         df2 = df2.with_column("B", col("B") + lit(i))
 
@@ -116,9 +116,9 @@ def test_large_query_breakdown_with_cte_optimization(session):
     df2 = df1.filter(col("b") == 2).union_all(df1)
     df3 = df1.with_column("a", col("a") + 1)
 
-    for i in range(100):
-        df2 = df2.with_column("a", col("a") + i)
-        df3 = df3.with_column("b", col("b") + i)
+    for i in range(7):
+        df2 = df2.with_column("a", col("a") + i + col("a"))
+        df3 = df3.with_column("b", col("b") + i + col("b"))
     df2 = df2.group_by("a").agg(sum_distinct(col("b")).alias("b"))
     df3 = df3.group_by("b").agg(sum_distinct(col("a")).alias("a"))
@@ -224,9 +224,9 @@ def test_pivot_unpivot(session):
         schema=["A", "dept", "jan", "feb"],
     )
 
-    for i in range(100):
-        df_pivot = df_pivot.with_column("A", col("A") + lit(i))
-        df_unpivot = df_unpivot.with_column("A", col("A") + lit(i))
+    for i in range(6):
+        df_pivot = df_pivot.with_column("A", col("A") + lit(i) + col("A"))
+        df_unpivot = df_unpivot.with_column("A", col("A") + lit(i) + col("A"))
     df_pivot = df_pivot.pivot("month", ["JAN", "FEB"]).sum("B")
     df_unpivot = df_unpivot.unpivot("sales", "month", ["jan", "feb"])
@@ -249,7 +249,7 @@ def test_sort(session):
     df1 = base_df.with_column("A", col("A") + lit(1))
     df2 = base_df.with_column("B", col("B") + lit(1))
 
-    for i in range(100):
+    for i in range(160):
         df1 = df1.with_column("A", col("A") + lit(i))
         df2 = df2.with_column("B", col("B") + lit(i))
     df1_with_sort = df1.order_by("A")
@@ -285,7 +285,7 @@ def test_multiple_query_plan(session, large_query_df):
     df1 = base_df.with_column("A", col("A") + lit(1))
     df2 = base_df.with_column("B", col("B") + lit(1))
 
-    for i in range(100):
+    for i in range(160):
         df1 = df1.with_column("A", col("A") + lit(i))
         df2 = df2.with_column("B", col("B") + lit(i))
     df1 = df1.group_by(col("A")).agg(sum_distinct(col("B")).alias("B"))
@@ -377,7 +377,7 @@ def test_add_parent_plan_uuid_to_statement_params(session, large_query_df):
         session._conn, "run_query", wraps=session._conn.run_query
     ) as patched_run_query:
         result = large_query_df.collect()
-        Utils.check_answer(result, [Row(1, 4954), Row(2, 4953)])
+        Utils.check_answer(result, [Row(1, 5999), Row(2, 5998)])
 
         plan = large_query_df._plan
         # 1 for current transaction, 1 for partition, 1 for main query, 1 for post action
@@ -404,8 +404,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
     assert large_query_df.queries["post_actions"][0].startswith(
         "DROP TABLE If EXISTS"
     )
-
-    set_bounds(session, 300, 412)
+    set_bounds(session, 300, 455)
     assert len(large_query_df.queries["queries"]) == 3
len(large_query_df.queries["post_actions"]) == 2 assert large_query_df.queries["queries"][0].startswith( @@ -420,7 +419,6 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): assert large_query_df.queries["post_actions"][1].startswith( "DROP TABLE If EXISTS" ) - set_bounds(session, 0, 300) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 diff --git a/tests/integ/test_nested_select_plan_analysis.py b/tests/integ/test_nested_select_plan_analysis.py index ee38aac5f50..81fb95afc5a 100644 --- a/tests/integ/test_nested_select_plan_analysis.py +++ b/tests/integ/test_nested_select_plan_analysis.py @@ -2,8 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from typing import Dict, Optional + import pytest +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, +) from snowflake.snowpark._internal.analyzer.select_statement import SelectStatement from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.functions import ( @@ -19,6 +24,11 @@ min as min_, ) from snowflake.snowpark.window import Window +from tests.integ.test_deepcopy import ( + create_df_with_deep_nested_with_column_dependencies, +) +from tests.integ.test_query_plan_analysis import assert_df_subtree_query_complexity +from tests.utils import Utils pytestmark = [ pytest.mark.xfail( @@ -28,6 +38,7 @@ ) ] + paramList = [False, True] @@ -44,51 +55,161 @@ def setup(request, session): @pytest.fixture(scope="function") def simple_dataframe(session) -> DataFrame: + """ + The complexity of the simple_dataframe is {COLUMN: 6, LITERAL: 9}, and corresponds to the following query: + + SELECT "A", "B", "C" FROM ( + SELECT $1 AS "A", $2 AS "B", $3 AS "C" FROM VALUES ( + 1 :: INT, \'a\' :: STRING, 2 :: INT), (2 :: INT, \'b\' :: STRING, 3 :: INT), (3 :: INT, \'c\' :: STRING, 7 :: INT)) + """ return session.create_dataframe( [[1, "a", 2], [2, "b", 3], [3, "c", 7]], schema=["a", "b", "c"] ) +@pytest.fixture(scope="function") +def sample_table(session): + table_name = Utils.random_table_name() + Utils.create_table( + session, table_name, "a int, b int, c int, d int", is_temporary=True + ) + session._run_query( + f"insert into {table_name}(a, b, c, d) values " "(1, 2, 3, 4), (5, 6, 7, 8)" + ) + yield table_name + Utils.drop_table(session, table_name) + + def verify_dataframe_select_statement( - df: DataFrame, can_be_merged_when_enabled: bool + df: DataFrame, + can_be_merged_when_enabled: bool, + complexity_before_merge: Dict[PlanNodeCategory, int], + complexity_after_merge: Optional[Dict[PlanNodeCategory, int]] = None, ) -> None: assert isinstance(df._plan.source_plan, SelectStatement) if not df.session.large_query_breakdown_enabled: # if large query breakdown is disabled, _merge_projection_complexity_with_subquery will always be false assert df._plan.source_plan._merge_projection_complexity_with_subquery is False + assert_df_subtree_query_complexity(df, complexity_before_merge) else: assert ( df._plan.source_plan._merge_projection_complexity_with_subquery == can_be_merged_when_enabled ) + if can_be_merged_when_enabled: + assert ( + complexity_after_merge + ), "no complexity after merge is provided for validation" + assert_df_subtree_query_complexity(df, complexity_after_merge) + else: + assert_df_subtree_query_complexity(df, complexity_before_merge) def test_simple_valid_nested_select(simple_dataframe): df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select( 
(col("a") + 3).as_("a"), "c" ) - verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True) + verify_dataframe_select_statement( + df_res, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.COLUMN: 8, + PlanNodeCategory.LITERAL: 11, + }, + complexity_after_merge={ + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.LITERAL: 11, + PlanNodeCategory.COLUMN: 5, + }, + ) # add one more select df_res = df_res.select(col("a") * 2, (col("c") + 2).as_("d")) - verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True) + verify_dataframe_select_statement( + df_res, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.LOW_IMPACT: 4, + PlanNodeCategory.COLUMN: 10, + PlanNodeCategory.LITERAL: 13, + }, + complexity_after_merge={ + PlanNodeCategory.LOW_IMPACT: 4, + PlanNodeCategory.LITERAL: 13, + PlanNodeCategory.COLUMN: 5, + }, + ) def test_nested_select_with_star(simple_dataframe): df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select("*") + print(df_res._plan.cumulative_node_complexity) # star will be automatically flattened, the complexity won't be flattened - verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False) + verify_dataframe_select_statement( + df_res, + can_be_merged_when_enabled=False, + complexity_before_merge={ + PlanNodeCategory.LOW_IMPACT: 1, + PlanNodeCategory.COLUMN: 6, + PlanNodeCategory.LITERAL: 10, + }, + ) df_res = df_res.select((col("a") + 3).as_("a"), "c") - verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True) + verify_dataframe_select_statement( + df_res, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.COLUMN: 8, + PlanNodeCategory.LITERAL: 11, + }, + complexity_after_merge={ + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.LITERAL: 11, + PlanNodeCategory.COLUMN: 5, + }, + ) def test_nested_select_with_valid_function_expressions(simple_dataframe): df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select( concat("a", "b").as_("a"), initcap("c").as_("c"), "b" ) - verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True) + verify_dataframe_select_statement( + df_res, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.FUNCTION: 2, + PlanNodeCategory.COLUMN: 10, + PlanNodeCategory.LOW_IMPACT: 1, + PlanNodeCategory.LITERAL: 10, + }, + complexity_after_merge={ + PlanNodeCategory.FUNCTION: 2, + PlanNodeCategory.COLUMN: 7, + PlanNodeCategory.LOW_IMPACT: 1, + PlanNodeCategory.LITERAL: 10, + }, + ) + df_res = df_res.select(concat("a", initcap(concat("b", "c"))), add_months("a", 5)) - verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True) + verify_dataframe_select_statement( + df_res, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.FUNCTION: 6, + PlanNodeCategory.COLUMN: 14, + PlanNodeCategory.LITERAL: 11, + PlanNodeCategory.LOW_IMPACT: 1, + }, + complexity_after_merge={ + PlanNodeCategory.FUNCTION: 7, + PlanNodeCategory.COLUMN: 9, + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.LITERAL: 12, + }, + ) def test_nested_select_with_window_functions(simple_dataframe): @@ -99,7 +220,21 @@ def test_nested_select_with_window_functions(simple_dataframe): df_res = simple_dataframe.select( avg("a").over(window1).as_("a"), avg("b").over(window2).as_("b") ).select((col("a") + 1).as_("a"), "b") - 
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False)
+
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 6,
+            PlanNodeCategory.COLUMN: 10,
+            PlanNodeCategory.LITERAL: 11,
+            PlanNodeCategory.WINDOW: 2,
+            PlanNodeCategory.FUNCTION: 2,
+            PlanNodeCategory.PARTITION_BY: 1,
+            PlanNodeCategory.ORDER_BY: 2,
+            PlanNodeCategory.OTHERS: 2,
+        },
+    )
@@ -110,7 +245,15 @@ def test_nested_select_with_table_functions(session):
     df = session.table_function(
         "split_to_table", lit("split words to table"), lit(" ")
     )
 
     df_res = df.select((col("a") + 1).as_("a"), "b", "c")
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False)
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 1,
+            PlanNodeCategory.COLUMN: 3,
+            PlanNodeCategory.LITERAL: 1,
+        },
+    )
@@ -118,65 +261,129 @@ def test_nested_select_with_valid_builtin_function(simple_dataframe):
     df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select(
         builtin("nvl")(col("a"), col("b")).as_("a"),
         builtin("nvl2")(col("b"), col("c")).as_("c"),
     )
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True)
+
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=True,
+        complexity_before_merge={
+            PlanNodeCategory.FUNCTION: 2,
+            PlanNodeCategory.COLUMN: 10,
+            PlanNodeCategory.LOW_IMPACT: 1,
+            PlanNodeCategory.LITERAL: 10,
+        },
+        complexity_after_merge={
+            PlanNodeCategory.FUNCTION: 2,
+            PlanNodeCategory.COLUMN: 7,
+            PlanNodeCategory.LOW_IMPACT: 1,
+            PlanNodeCategory.LITERAL: 10,
+        },
+    )
 
 
 def test_nested_select_with_agg_functions(simple_dataframe):
     df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select(
         avg("a").as_("a"), min_("c").as_("c")
     )
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False)
-    df_res = simple_dataframe.select(max_("a"))
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False)
-
-
-def test_nested_select_with_limit_filter_order_by(simple_dataframe):
-    """
-    df_res_filtered = (
-        simple_dataframe.filter(col("a") == 1)
-        .select((col("a") + 1).as_("a"), "b", "c")
-        .select((col("a") + 1).as_("a"), "b")
    )
    verify_dataframe_select_statement(df_res_filtered, can_be_merged_when_enabled=False)

    df_res_limit = (
        simple_dataframe.select((col("a") + 1).as_("a"), "b", "c")
        .limit(10, 5)
        .select(concat("a", "b").as_("a"), initcap("c").as_("c"), "b")
    )
    verify_dataframe_select_statement(df_res_limit, can_be_merged_when_enabled=False)
-    """
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.FUNCTION: 2,
+            PlanNodeCategory.COLUMN: 8,
+            PlanNodeCategory.LOW_IMPACT: 1,
+            PlanNodeCategory.LITERAL: 10,
+        },
+    )
 
+    df_res = simple_dataframe.select(
+        max_("a").as_("a"), (min_("c") + 1).as_("c")
+    ).select(col("a") + 1, "c")
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.COLUMN: 7,
+            PlanNodeCategory.LITERAL: 11,
+            PlanNodeCategory.FUNCTION: 2,
+        },
+    )
+
+
+def test_nested_select_with_limit_filter_order_by(simple_dataframe):
     def_order_by_filter = (
         simple_dataframe.select((col("a") + 1).as_("a"), "b", "c")
         .order_by(col("a"))
         .filter(col("a") == 1)
     )
     df_res = def_order_by_filter.select((col("a") + 2).as_("a"))
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False)
+
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 3,
+            PlanNodeCategory.COLUMN: 9,
+            PlanNodeCategory.LITERAL: 12,
+            PlanNodeCategory.FILTER: 1,
+            PlanNodeCategory.OTHERS: 1,
+            PlanNodeCategory.ORDER_BY: 1,
+        },
+    )
 
 
 def test_select_with_dependency_within_same_level(simple_dataframe):
     df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select(
         (col("a") + 2).as_("d"), (col("d") + 1).as_("e")
     )
-    # star will be automatically flattened, the complexity won't be flattened
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=False)
+
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 3,
+            PlanNodeCategory.COLUMN: 8,
+            PlanNodeCategory.LITERAL: 12,
+        },
+    )
 
 
 def test_select_with_duplicated_columns(simple_dataframe):
-    def_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select(
+    df_res = simple_dataframe.select((col("a") + 1).as_("a"), "b", "c").select(
         (col("a") + 2).as_("b"), (col("b") + 1).as_("b")
     )
-    verify_dataframe_select_statement(def_res, can_be_merged_when_enabled=True)
+
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=True,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 3,
+            PlanNodeCategory.COLUMN: 8,
+            PlanNodeCategory.LITERAL: 12,
+        },
+        complexity_after_merge={
+            PlanNodeCategory.LOW_IMPACT: 3,
+            PlanNodeCategory.LITERAL: 12,
+            PlanNodeCategory.COLUMN: 5,
+        },
+    )
 
 
 def test_select_with_dollar_dependency(simple_dataframe):
-    def_res = simple_dataframe.select((col("a") + 1), "b", "c").select(
+    df_res = simple_dataframe.select((col("a") + 1), "b", "c").select(
         (col("$1") + 2).as_("b"), col("$2").as_("c")
     )
-    verify_dataframe_select_statement(def_res, can_be_merged_when_enabled=False)
+
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.COLUMN: 8,
+            PlanNodeCategory.LITERAL: 11,
+        },
+    )
@@ -185,7 +392,116 @@ def test_valid_after_invalid_nested_select(simple_dataframe):
     df_res_filtered = (
         simple_dataframe.filter(col("a") == 2)
         .select((col("a") + 1).as_("a"), "b", "c")
         .select((col("a") + 1).as_("a"), "b")
     )
-    verify_dataframe_select_statement(df_res_filtered, can_be_merged_when_enabled=False)
+    verify_dataframe_select_statement(
+        df_res_filtered,
+        can_be_merged_when_enabled=False,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 3,
+            PlanNodeCategory.COLUMN: 9,
+            PlanNodeCategory.LITERAL: 12,
+            PlanNodeCategory.FILTER: 1,
+        },
+    )
 
     df_res = df_res_filtered.select((col("a") + 2).as_("a"), (col("b") + 2).as_("b"))
-    verify_dataframe_select_statement(df_res, can_be_merged_when_enabled=True)
+    verify_dataframe_select_statement(
+        df_res,
+        can_be_merged_when_enabled=True,
+        complexity_before_merge={
+            PlanNodeCategory.LOW_IMPACT: 5,
+            PlanNodeCategory.COLUMN: 11,
+            PlanNodeCategory.LITERAL: 14,
+            PlanNodeCategory.FILTER: 1,
+        },
+        complexity_after_merge={
+            PlanNodeCategory.LOW_IMPACT: 5,
+            PlanNodeCategory.LITERAL: 14,
+            PlanNodeCategory.COLUMN: 9,
+            PlanNodeCategory.FILTER: 1,
+        },
+    )
+
+
+def test_simple_nested_select_with_repeated_column_dependency(session, sample_table):
+    df = session.table(sample_table)
+    df_select = df.select((col("a") + 1).as_("a"), "b", "c")
1).as_("a"), "b", "c") + assert_df_subtree_query_complexity( + df_select, + { + PlanNodeCategory.LOW_IMPACT: 1, + PlanNodeCategory.COLUMN: 4, + PlanNodeCategory.LITERAL: 1, + }, + ) + + df_select = df_select.select((col("a") + 3).as_("a"), "c") + # the two select complexity can be merged when large query breakdown enabled, + # and will be equivalent to the complexity of + # df.select((col("a") + 1 + 3).as_("a"), "c") + verify_dataframe_select_statement( + df_select, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.LITERAL: 2, + PlanNodeCategory.COLUMN: 6, + PlanNodeCategory.LOW_IMPACT: 2, + }, + complexity_after_merge={ + PlanNodeCategory.LITERAL: 2, + PlanNodeCategory.COLUMN: 3, + PlanNodeCategory.LOW_IMPACT: 2, + }, + ) + + # add one more select with duplicated reference + df_select = df_select.select( + col("a") * 2 + col("a") + col("c"), (col("c") + 2).as_("d") + ) + print(df_select._plan.cumulative_node_complexity) + # the complexity can be continue merged with the previous select, and the whole tree complexity + # will be equivalent to df.select((col("a") + 3 + 1) * 2 + (col("a") + 3 + 1) + col("c"), (col("c") + 2).as_("d") + verify_dataframe_select_statement( + df_select, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.LOW_IMPACT: 6, + PlanNodeCategory.COLUMN: 10, + PlanNodeCategory.LITERAL: 4, + }, + complexity_after_merge={ + PlanNodeCategory.LOW_IMPACT: 8, + PlanNodeCategory.COLUMN: 5, + PlanNodeCategory.LITERAL: 6, + }, + ) + + +def test_deep_nested_with_columns(session): + temp_table_name = Utils.random_table_name() + try: + df = create_df_with_deep_nested_with_column_dependencies( + session, temp_table_name, 5 + ) + verify_dataframe_select_statement( + df, + can_be_merged_when_enabled=True, + complexity_before_merge={ + PlanNodeCategory.COLUMN: 97, + PlanNodeCategory.CASE_WHEN: 4, + PlanNodeCategory.LOW_IMPACT: 8, + PlanNodeCategory.LITERAL: 20, + PlanNodeCategory.FUNCTION: 60, + }, + complexity_after_merge={ + PlanNodeCategory.COLUMN: 389, + PlanNodeCategory.CASE_WHEN: 40, + PlanNodeCategory.LOW_IMPACT: 80, + PlanNodeCategory.LITERAL: 200, + PlanNodeCategory.FUNCTION: 600, + }, + ) + print(df._plan.cumulative_node_complexity) + finally: + Utils.drop_table(session, temp_table_name) diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 8a0fa98a1ad..960eb0c4fca 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -32,12 +32,18 @@ ] -@pytest.fixture(autouse=True) -def setup(session): +paramList = [False, True] + + +@pytest.fixture(params=paramList, autouse=True) +def setup(request, session): is_simplifier_enabled = session._sql_simplifier_enabled + large_query_breakdown_enabled = session.large_query_breakdown_enabled + session.large_query_breakdown_enabled = request.param session._sql_simplifier_enabled = True yield session._sql_simplifier_enabled = is_simplifier_enabled + session.large_query_breakdown_enabled = large_query_breakdown_enabled @pytest.fixture(scope="module") @@ -57,7 +63,9 @@ def get_cumulative_node_complexity(df: DataFrame) -> Dict[str, int]: return df._plan.cumulative_node_complexity -def assert_df_subtree_query_complexity(df: DataFrame, estimate: Dict[str, int]): +def assert_df_subtree_query_complexity( + df: DataFrame, estimate: Dict[PlanNodeCategory, int] +): assert ( get_cumulative_node_complexity(df) == estimate ), f"query = {df.queries['queries'][-1]}" diff --git 
index 541997f658c..f10fe2f02bc 100644
--- a/tests/unit/test_query_plan_analysis.py
+++ b/tests/unit/test_query_plan_analysis.py
@@ -12,11 +12,17 @@
     UNION,
     UNION_ALL,
 )
-from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression
+from snowflake.snowpark._internal.analyzer.expression import (
+    Attribute,
+    Expression,
+    NamedExpression,
+)
 from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
     PlanNodeCategory,
+    subtract_complexities,
 )
 from snowflake.snowpark._internal.analyzer.select_statement import (
+    ColumnStateDict,
     Selectable,
     SelectableEntity,
     SelectSnowflakePlan,
@@ -33,6 +39,7 @@
 )
 from snowflake.snowpark._internal.analyzer.table_function import TableFunctionExpression
 from snowflake.snowpark._internal.analyzer.unary_plan_node import Project
+from snowflake.snowpark.types import StringType
 
 
 @pytest.mark.parametrize("node_type", [LogicalPlan, SnowflakePlan, Selectable])
@@ -157,7 +164,14 @@ def test_select_statement_individual_node_complexity(
     plan_node = SelectStatement(from_=from_, analyzer=mock_analyzer)
     setattr(plan_node, attribute, value)
 
-    assert plan_node.individual_node_complexity == expected_stat
+    if attribute == "projection" and isinstance(value[0], NamedExpression):
+        # NamedExpression is not a valid projection expression for SelectStatement,
+        # and there are no individual_node_complexity or cumulative_node_complexity
+        # attributes associated with it
+        with pytest.raises(AttributeError):
+            plan_node.individual_node_complexity
+    else:
+        assert plan_node.individual_node_complexity == expected_stat
 
 
 def test_select_table_function_individual_node_complexity(
@@ -191,3 +205,120 @@ def test_set_statement_individual_node_complexity(mock_analyzer, set_operator):
 
     plan_node = SetStatement(*set_operands, analyzer=mock_analyzer)
     assert plan_node.individual_node_complexity == {PlanNodeCategory.SET_OPERATION: 1}
+
+
+@pytest.mark.parametrize(
+    "complexity1, complexity2, expected_result",
+    [
+        (
+            {
+                PlanNodeCategory.COLUMN: 20,
+                PlanNodeCategory.LITERAL: 5,
+                PlanNodeCategory.FUNCTION: 3,
+            },
+            {
+                PlanNodeCategory.COLUMN: 11,
+                PlanNodeCategory.LITERAL: 4,
+                PlanNodeCategory.FUNCTION: 1,
+            },
+            {
+                PlanNodeCategory.COLUMN: 9,
+                PlanNodeCategory.LITERAL: 1,
+                PlanNodeCategory.FUNCTION: 2,
+            },
+        ),
+        (
+            {
+                PlanNodeCategory.COLUMN: 20,
+                PlanNodeCategory.LITERAL: 5,
+                PlanNodeCategory.FUNCTION: 3,
+            },
+            {PlanNodeCategory.COLUMN: 11, PlanNodeCategory.FUNCTION: 1},
+            {
+                PlanNodeCategory.COLUMN: 9,
+                PlanNodeCategory.LITERAL: 5,
+                PlanNodeCategory.FUNCTION: 2,
+            },
+        ),
+        (
+            {
+                PlanNodeCategory.COLUMN: 20,
+                PlanNodeCategory.LITERAL: 5,
+                PlanNodeCategory.FUNCTION: 3,
+            },
+            {PlanNodeCategory.LITERAL: 11, PlanNodeCategory.FUNCTION: 1},
+            {
+                PlanNodeCategory.COLUMN: 20,
+                PlanNodeCategory.LITERAL: -6,
+                PlanNodeCategory.FUNCTION: 2,
+            },
+        ),
+        (
+            {
+                PlanNodeCategory.COLUMN: 20,
+                PlanNodeCategory.LITERAL: 5,
+                PlanNodeCategory.FUNCTION: 3,
+            },
+            {
+                PlanNodeCategory.COLUMN: 11,
+                PlanNodeCategory.LITERAL: 1,
+                PlanNodeCategory.FILTER: 1,
+                PlanNodeCategory.CASE_WHEN: 2,
+            },
+            {
+                PlanNodeCategory.COLUMN: 9,
+                PlanNodeCategory.LITERAL: 4,
+                PlanNodeCategory.FUNCTION: 3,
+                PlanNodeCategory.FILTER: -1,
+                PlanNodeCategory.CASE_WHEN: -2,
+            },
+        ),
+    ],
+)
+def test_subtract_complexities(complexity1, complexity2, expected_result):
+    assert subtract_complexities(complexity1, complexity2) == expected_result
+
+
+def test_select_statement_get_complexity_map_no_column_state(mock_analyzer):
+    mock_from = mock.create_autospec(Selectable)
+    mock_from.pre_actions = None
+    mock_from.post_actions = None
+    mock_from.expr_to_alias = {}
+    mock_from.df_aliased_col_name_to_real_col_name = {}
+    select_statement = SelectStatement(analyzer=mock_analyzer, from_=mock_from)
+
+    assert select_statement.get_projection_name_complexity_map() is None
+    assert select_statement.projection_complexities == []
+
+    select_statement._column_states = mock.create_autospec(ColumnStateDict)
+    select_statement.projection = [Expression()]
+    mock_from._column_states = None
+
+    assert select_statement.get_projection_name_complexity_map() is None
+
+
+def test_select_statement_get_complexity_map_mismatch_projection_length(mock_analyzer):
+    mock_from = mock.create_autospec(Selectable)
+    mock_from.pre_actions = None
+    mock_from.post_actions = None
+    mock_from.expr_to_alias = {}
+    mock_from.df_aliased_col_name_to_real_col_name = {}
+
+    # create a select_statement with 2 projections
+    select_statement = SelectStatement(
+        analyzer=mock_analyzer, projection=[Expression(), Expression()], from_=mock_from
+    )
+    column_states = ColumnStateDict()
+    column_states.projection = [Attribute("A", StringType())]
+    select_statement._column_states = column_states
+
+    assert select_statement.get_projection_name_complexity_map() is None
+
+    # update column states projection length to match the projection length
+    column_states.projection = [
+        Attribute("A", StringType()),
+        Attribute("B", StringType()),
+    ]
+    select_statement._projection_complexities = []
+
+    assert select_statement.get_projection_name_complexity_map() is None
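
To make the merge arithmetic in `projection_complexities` concrete, here is a minimal, self-contained sketch of the same computation. Plain strings and dicts stand in for `PlanNodeCategory` and Snowpark expression types, and `merge_projection_complexity` / `sum_complexities` are illustrative stand-ins, not the library's API:

```python
from collections import Counter
from typing import Dict, List

# category names stand in for PlanNodeCategory members
COLUMN, LITERAL, LOW_IMPACT = "COLUMN", "LITERAL", "LOW_IMPACT"


def sum_complexities(*complexities: Dict[str, int]) -> Dict[str, int]:
    # mirrors sum_node_complexities: element-wise addition of category counts
    total: Counter = Counter()
    for complexity in complexities:
        total.update(complexity)
    return dict(total)


def merge_projection_complexity(
    proj_complexity: Dict[str, int],
    dependent_columns: List[str],
    subquery_map: Dict[str, Dict[str, int]],
) -> Dict[str, int]:
    # for each dependent column reference (duplicates included), trade the
    # COLUMN: 1 contribution of the reference for the full complexity of the
    # subquery projection that defines it
    result = proj_complexity.copy()
    for column_name in dependent_columns:
        result[COLUMN] -= 1
        result = sum_complexities(result, subquery_map[column_name])
    return result


# df.select((col("a") + 1).as_("a")).select((col("a") + 3).as_("a")):
# the outer projection a + 3 has complexity {COLUMN: 1, LITERAL: 1, LOW_IMPACT: 1}
# and depends on the subquery projection a = a + 1 with the same complexity.
outer = {COLUMN: 1, LITERAL: 1, LOW_IMPACT: 1}
subquery = {"a": {COLUMN: 1, LITERAL: 1, LOW_IMPACT: 1}}
merged = merge_projection_complexity(outer, ["a"], subquery)
assert merged == {COLUMN: 1, LITERAL: 2, LOW_IMPACT: 2}  # i.e. (a + 1) + 3
```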
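Likewise, the cumulative-complexity adjustment in `SelectStatement.cumulative_node_complexity` pairs the new `subtract_complexities` helper with the existing `sum_node_complexities`. A quick sanity check of the intended arithmetic, using the functions from this patch (the complexity values here are made up for illustration):

```python
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
    PlanNodeCategory,
    subtract_complexities,
    sum_node_complexities,
)

# cumulative complexity computed with the merged projection complexities
cumulative = {PlanNodeCategory.COLUMN: 8, PlanNodeCategory.LITERAL: 11}
# projection complexities of the from_ subquery, which are already folded
# into the merged projections and therefore must be subtracted once
subquery_projections = [
    {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.LITERAL: 1},
    {PlanNodeCategory.COLUMN: 1},
]
adjusted = subtract_complexities(
    cumulative, sum_node_complexities(*subquery_projections)
)
assert adjusted == {PlanNodeCategory.COLUMN: 5, PlanNodeCategory.LITERAL: 10}
```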