From 6dfcd8c023bd0325b09ee4121f2f4af036db1f67 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 7 Nov 2024 11:08:06 -0800 Subject: [PATCH 1/3] make it easy to debug --- .../snowpark/_internal/analyzer/analyzer.py | 1 + .../_internal/analyzer/snowflake_plan.py | 1 + src/snowflake/snowpark/dataframe.py | 64 +++++-------------- 3 files changed, 18 insertions(+), 48 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 447ae13793f..9c95493b1b4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -807,6 +807,7 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: use_maps = {} # get counts of expr_to_alias keys counts = Counter() + # TODO join: key here? we are only keeping non-shared expr_to_alias keys for v in resolved_children.values(): if v.expr_to_alias: counts.update(list(v.expr_to_alias.keys())) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index c907790449c..860907ba616 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -587,6 +587,7 @@ def build_binary( common_columns = set(select_left.expr_to_alias.keys()).intersection( select_right.expr_to_alias.keys() ) + # TODO join: related to alias new_expr_to_alias = { k: v for k, v in { diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index a87abf04c0c..7095a4cf1b4 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -96,7 +96,6 @@ from snowflake.snowpark._internal.telemetry import ( add_api_call, adjust_api_subcalls, - df_api_usage, df_collect_api_telemetry, df_to_relational_group_df_api_usage, ) @@ -585,7 +584,6 @@ def collect( ) -> AsyncJob: ... # pragma: no cover - @df_collect_api_telemetry def collect( self, *, @@ -616,7 +614,6 @@ def collect( case_sensitive=case_sensitive, ) - @df_collect_api_telemetry def collect_nowait( self, *, @@ -674,7 +671,6 @@ def _internal_collect_with_tag_no_telemetry( _internal_collect_with_tag_no_telemetry ) - @df_collect_api_telemetry def _execute_and_get_query_id( self, *, statement_params: Optional[Dict[str, str]] = None ) -> str: @@ -709,7 +705,6 @@ def to_local_iterator( ) -> AsyncJob: ... # pragma: no cover - @df_collect_api_telemetry def to_local_iterator( self, *, @@ -786,7 +781,6 @@ def to_pandas( ) -> AsyncJob: ... # pragma: no cover - @df_collect_api_telemetry def to_pandas( self, *, @@ -862,7 +856,6 @@ def to_pandas_batches( ) -> AsyncJob: ... # pragma: no cover - @df_collect_api_telemetry def to_pandas_batches( self, *, @@ -913,7 +906,6 @@ def to_pandas_batches( **kwargs, ) - @df_api_usage def to_df(self, *names: Union[str, Iterable[str]]) -> "DataFrame": """ Creates a new DataFrame containing columns with the specified names. 
@@ -948,7 +940,6 @@ def to_df(self, *names: Union[str, Iterable[str]]) -> "DataFrame": new_cols.append(Column(attr).alias(name)) return self.select(new_cols) - @df_collect_api_telemetry def to_snowpark_pandas( self, index_col: Optional[Union[str, List[str]]] = None, @@ -1087,7 +1078,6 @@ def col(self, col_name: str) -> Column: else: return Column(self._resolve(col_name)) - @df_api_usage def select( self, *cols: Union[ @@ -1219,7 +1209,6 @@ def select( return self._with_plan(Project(names, join_plan or self._plan)) - @df_api_usage def select_expr(self, *exprs: Union[str, Iterable[str]]) -> "DataFrame": """ Projects a set of SQL expressions and returns a new :class:`DataFrame`. @@ -1251,7 +1240,6 @@ def select_expr(self, *exprs: Union[str, Iterable[str]]) -> "DataFrame": selectExpr = select_expr - @df_api_usage def drop( self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], @@ -1325,7 +1313,6 @@ def drop( else: return self.select(list(keep_col_names)) - @df_api_usage def filter(self, expr: ColumnOrSqlExpr) -> "DataFrame": """Filters rows based on the specified conditional expression (similar to WHERE in SQL). @@ -1359,7 +1346,6 @@ def filter(self, expr: ColumnOrSqlExpr) -> "DataFrame": ) ) - @df_api_usage def sort( self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], @@ -1497,7 +1483,6 @@ def alias(self, name: str): ] = attr.name return _copy - @df_api_usage def agg( self, *exprs: Union[Column, Tuple[ColumnOrName, str], Dict[str, str]], @@ -1688,7 +1673,6 @@ def cube( snowflake.snowpark.relational_grouped_dataframe._CubeType(), ) - @df_api_usage def distinct(self) -> "DataFrame": """Returns a new DataFrame that contains only the rows with distinct values from the current DataFrame. @@ -1813,7 +1797,6 @@ def pivot( ), ) - @df_api_usage def unpivot( self, value_column: str, name_column: str, column_list: List[ColumnOrName] ) -> "DataFrame": @@ -1858,7 +1841,6 @@ def unpivot( ) return self._with_plan(unpivot_plan) - @df_api_usage def limit(self, n: int, offset: int = 0) -> "DataFrame": """Returns a new DataFrame that contains at most ``n`` rows from the current DataFrame, skipping ``offset`` rows from the beginning (similar to LIMIT and OFFSET in SQL). @@ -1891,7 +1873,6 @@ def limit(self, n: int, offset: int = 0) -> "DataFrame": return self._with_plan(self._select_statement.limit(n, offset=offset)) return self._with_plan(Limit(Literal(n), Literal(offset), self._plan)) - @df_api_usage def union(self, other: "DataFrame") -> "DataFrame": """Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (``other``), excluding any duplicate rows. Both input @@ -1925,7 +1906,6 @@ def union(self, other: "DataFrame") -> "DataFrame": ) return self._with_plan(UnionPlan(self._plan, other._plan, is_all=False)) - @df_api_usage def union_all(self, other: "DataFrame") -> "DataFrame": """Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (``other``), including any duplicate rows. Both input @@ -1961,7 +1941,6 @@ def union_all(self, other: "DataFrame") -> "DataFrame": ) return self._with_plan(UnionPlan(self._plan, other._plan, is_all=True)) - @df_api_usage def union_by_name(self, other: "DataFrame") -> "DataFrame": """Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (``other``), excluding any duplicate rows. 
@@ -1987,7 +1966,6 @@ def union_by_name(self, other: "DataFrame") -> "DataFrame": """ return self._union_by_name_internal(other, is_all=False) - @df_api_usage def union_all_by_name(self, other: "DataFrame") -> "DataFrame": """Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (``other``), including any duplicate rows. @@ -2061,7 +2039,6 @@ def _union_by_name_internal( df = self._with_plan(UnionPlan(self._plan, right_child._plan, is_all)) return df - @df_api_usage def intersect(self, other: "DataFrame") -> "DataFrame": """Returns a new DataFrame that contains the intersection of rows from the current DataFrame and another DataFrame (``other``). Duplicate rows are @@ -2095,7 +2072,6 @@ def intersect(self, other: "DataFrame") -> "DataFrame": ) return self._with_plan(Intersect(self._plan, other._plan)) - @df_api_usage def except_(self, other: "DataFrame") -> "DataFrame": """Returns a new DataFrame that contains all the rows from the current DataFrame except for the rows that also appear in the ``other`` DataFrame. Duplicate rows are eliminated. @@ -2129,7 +2105,6 @@ def except_(self, other: "DataFrame") -> "DataFrame": ) return self._with_plan(Except(self._plan, other._plan)) - @df_api_usage def natural_join( self, right: "DataFrame", how: Optional[str] = None, **kwargs ) -> "DataFrame": @@ -2192,7 +2167,6 @@ def natural_join( return self._with_plan(select_plan) return self._with_plan(join_plan) - @df_api_usage def join( self, right: "DataFrame", @@ -2526,7 +2500,6 @@ def join( raise TypeError("Invalid type for join. Must be Dataframe") - @df_api_usage def join_table_function( self, func: Union[str, List[str], TableFunctionCall], @@ -2670,7 +2643,6 @@ def join_table_function( TableFunctionJoin(self._plan, func_expr, right_cols=new_col_names) ) - @df_api_usage def cross_join( self, right: "DataFrame", @@ -2784,14 +2756,23 @@ def _join_dataframes( match_condition._expression if match_condition is not None else None, ) if self._select_statement: - return self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - join_logical_plan, analyzer=self._session._analyzer - ), - analyzer=self._session._analyzer, - ) + analyzer = self._session._analyzer + select_statement_snowflake_plan = analyzer.create_select_snowflake_plan( + join_logical_plan, analyzer=analyzer + ) + select_statement_logic_plan = analyzer.create_select_statement( + from_=select_statement_snowflake_plan, analyzer=analyzer ) + new_df = self._with_plan(select_statement_logic_plan) + return new_df + # return self._with_plan( + # self._session._analyzer.create_select_statement( + # from_=self._session._analyzer.create_select_snowflake_plan( + # join_logical_plan, analyzer=self._session._analyzer + # ), + # analyzer=self._session._analyzer, + # ) + # ) return self._with_plan(join_logical_plan) def _join_dataframes_internal( @@ -2830,7 +2811,6 @@ def _join_dataframes_internal( ) return self._with_plan(join_logical_plan) - @df_api_usage def with_column( self, col_name: str, col: Union[Column, TableFunctionCall] ) -> "DataFrame": @@ -2876,7 +2856,6 @@ def with_column( """ return self.with_columns([col_name], [col]) - @df_api_usage def with_columns( self, col_names: List[str], values: List[Union[Column, TableFunctionCall]] ) -> "DataFrame": @@ -3030,7 +3009,6 @@ def write(self) -> DataFrameWriter: return self._writer - @df_collect_api_telemetry def copy_into_table( self, table_name: Union[str, Iterable[str]], @@ -3195,7 +3173,6 
@@ def copy_into_table( ), )._internal_collect_with_tag_no_telemetry(statement_params=statement_params) - @df_collect_api_telemetry def show( self, n: int = 10, @@ -3231,7 +3208,6 @@ def show( extra_warning_text="Use `DataFrame.join_table_function()` instead.", extra_doc_string="Use :meth:`join_table_function` instead.", ) - @df_api_usage def flatten( self, input: ColumnOrName, @@ -3405,7 +3381,6 @@ def row_to_string(row: List[str]) -> str: + line ) - @df_collect_api_telemetry def create_or_replace_view( self, name: Union[str, Iterable[str]], @@ -3448,7 +3423,6 @@ def create_or_replace_view( ), ) - @df_collect_api_telemetry def create_or_replace_dynamic_table( self, name: Union[str, Iterable[str]], @@ -3541,7 +3515,6 @@ def create_or_replace_dynamic_table( ), ) - @df_collect_api_telemetry def create_or_replace_temp_view( self, name: Union[str, Iterable[str]], @@ -3715,7 +3688,6 @@ def first( take = first - @df_api_usage def sample( self, frac: Optional[float] = None, n: Optional[int] = None ) -> "DataFrame": @@ -3860,7 +3832,6 @@ def describe(self, *cols: Union[str, List[str]]) -> "DataFrame": ) return res_df - @df_api_usage def rename( self, col_or_mapper: Union[ColumnOrName, dict], @@ -3932,7 +3903,6 @@ def rename( return self._with_plan(rename_plan) - @df_api_usage def with_column_renamed(self, existing: ColumnOrName, new: str) -> "DataFrame": """Returns a DataFrame with the specified column ``existing`` renamed as ``new``. @@ -3996,7 +3966,6 @@ def with_column_renamed(self, existing: ColumnOrName, new: str) -> "DataFrame": ] return self.select(new_columns) - @df_collect_api_telemetry def cache_result( self, *, statement_params: Optional[Dict[str, str]] = None ) -> "Table": @@ -4094,7 +4063,6 @@ def cache_result( cached_df.is_cached = True return cached_df - @df_collect_api_telemetry def random_split( self, weights: List[float], From 1b6666534644a7cb02c2efb8a41d5727d92ac68c Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 11 Nov 2024 23:35:13 -0800 Subject: [PATCH 2/3] test --- .../snowpark/_internal/analyzer/analyzer.py | 74 ++++++++++++++++++- .../snowpark/_internal/analyzer/expression.py | 17 ++++- .../_internal/analyzer/snowflake_plan.py | 39 +++++++++- src/snowflake/snowpark/dataframe.py | 54 +++++++++++--- 4 files changed, 166 insertions(+), 18 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 9c95493b1b4..332e8a8eae7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -167,6 +167,7 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.generated_alias_maps = {} self.subquery_plans = [] self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None + self.conflicted_maps_to_use = None def analyze( self, @@ -366,8 +367,32 @@ def analyze( return expr.sql if isinstance(expr, Attribute): + # TODO: which plan expr_to_alias to use? 
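+            # e.g. when both children of a join alias the same underlying
+            # expression id differently:
+            #     left  = df.select(col("a").alias("x"))
+            #     right = df.select(col("a").alias("y"))
+            #     left.join(right, ...)
+            # expr_to_alias then holds two candidates for one expr_id; the
+            # sketch below walks the plan tree and keeps the alias recorded
+            # for the plan closest to the attribute's own plan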
+            expr_plan = expr.plan
+
+            def traverse_child(root):
+                # depth of expr_plan below root, or -1 if it is not reachable
+                if root == expr_plan:
+                    return 0
+                for child in root.children_plan_nodes:
+                    res = traverse_child(child)
+                    if res == -1:
+                        continue
+                    return res + 1
+                return -1
+
             assert self.alias_maps_to_use is not None
-            name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
+            name = self.alias_maps_to_use.get(expr.expr_id)
+            if not name:
+                if expr.expr_id not in self.conflicted_maps_to_use:
+                    name = expr.name
+                else:
+                    # find the plan that has the least depth
+                    min_depth = float("inf")
+                    for tmp_name, plan in self.conflicted_maps_to_use[expr.expr_id]:
+                        depth = traverse_child(plan)
+                        if depth < min_depth:
+                            min_depth = depth
+                            name = tmp_name
             return quote_name(name)
 
         if isinstance(expr, UnresolvedAttribute):
@@ -675,6 +700,7 @@ def binary_operator_extractor(
         df_aliased_col_name_to_real_col_name,
         parse_local_name=False,
     ) -> str:
+
         if self.session.eliminate_numeric_sql_value_cast_enabled:
             left_sql_expr = self.to_sql_try_avoid_cast(
                 expr.left, df_aliased_col_name_to_real_col_name, parse_local_name
@@ -802,12 +828,16 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
 
         if isinstance(logical_plan, Selectable):
             # Selectable doesn't have children. It already has the expr_to_alias dict.
-            self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
+            self.alias_maps_to_use = (
+                logical_plan.expr_to_alias.copy()
+            )  # logical_plan.expr_to_alias.copy()
         else:
+            if isinstance(logical_plan, Join):
+                print("break point")
+
             use_maps = {}
             # get counts of expr_to_alias keys
             counts = Counter()
-            # TODO join: key here? we are only keeping non-shared expr_to_alias keys
             for v in resolved_children.values():
                 if v.expr_to_alias:
                     counts.update(list(v.expr_to_alias.keys()))
@@ -820,6 +850,44 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
                     {p: q for p, q in v.expr_to_alias.items() if counts[p] < 2}
                 )
 
+            def find_diff_values(*dicts):
+                # Store values by key for comparison
+                value_map = defaultdict(set)
+                # Store duplicate keys with differing values
+                diff_keys = {}
+
+                # Gather all values for each key from all dictionaries
+                for d in dicts:
+                    for key, value in d.items():
+                        value_map[key].add(value)
+
+                # Identify keys with differing values
+                for key, values in value_map.items():
+                    if (
+                        len(values) > 1
+                    ):  # Only consider keys with more than one unique value
+                        diff_keys[key] = list(values)
+
+                return set(diff_keys.keys())
+
+            dup_keys_with_different_values = find_diff_values(
+                *[v.expr_to_alias for v in resolved_children.values()]
+            )
+            self.conflicted_maps_to_use = defaultdict(list)
+            # if the logical plan already has a conflicted alias map and there are no children
+            if (
+                getattr(logical_plan, "conflicted_alias_map", {})
+                and not resolved_children
+            ):
+                self.conflicted_maps_to_use = logical_plan.conflicted_alias_map
+            else:
+                for plan in resolved_children.values():
+                    if plan.conflicted_alias_map:
+                        self.conflicted_maps_to_use.update(plan.conflicted_alias_map)
+                    for k, v in plan.expr_to_alias.items():
+                        if k in dup_keys_with_different_values:
+                            self.conflicted_maps_to_use[k].append((v, plan))
+
             self.alias_maps_to_use = use_maps
 
         res = self.do_resolve_with_resolved_children(
diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index a7cb5fd97a9..0d80df8457b 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -149,11 +149,17 @@ def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
 class NamedExpression:
     name: str
     _expr_id: Optional[uuid.UUID] = None
+    # sequential int ids are deterministic across runs and easier to read while
+    # debugging than uuid4 (note: the Optional[uuid.UUID] annotations here and
+    # on expr_id below are now stale; expr ids are ints)
+    id = 0
+
+    @staticmethod
+    def get_next_id():
+        NamedExpression.id += 1
+        return NamedExpression.id
 
     @property
     def expr_id(self) -> uuid.UUID:
         if not self._expr_id:
-            self._expr_id = uuid.uuid4()
+            self._expr_id = NamedExpression.get_next_id()
         return self._expr_id
 
     def __copy__(self):
@@ -222,20 +228,25 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
 
 
 class Attribute(Expression, NamedExpression):
-    def __init__(self, name: str, datatype: DataType, nullable: bool = True) -> None:
+    def __init__(
+        self, name: str, datatype: DataType, nullable: bool = True, plan=None
+    ) -> None:
         super().__init__()
         self.name = name
         self.datatype: DataType = datatype
         self.nullable = nullable
+        self.plan = plan
 
-    def with_name(self, new_name: str) -> "Attribute":
+    def with_name(self, new_name: str, plan) -> "Attribute":
         if self.name == new_name:
+            self.plan = plan
             return self
         else:
             return Attribute(
                 snowflake.snowpark._internal.utils.quote_name(new_name),
                 self.datatype,
                 self.nullable,
+                plan=plan,
             )
 
     @property
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
index 860907ba616..02d21862cd4 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -225,12 +225,18 @@ def __init__(
         referenced_ctes: Optional[Dict[WithQueryBlock, int]] = None,
         *,
         session: "snowflake.snowpark.session.Session",
+        conflicted_alias_map=None,
     ) -> None:
         super().__init__()
         self.queries = queries
         self.schema_query = schema_query
         self.post_actions = post_actions if post_actions else []
-        self.expr_to_alias = expr_to_alias if expr_to_alias else {}
+        self.expr_to_alias = (
+            expr_to_alias if expr_to_alias else {}
+        )  # used to store the old alias map info; TODO: when to use the old map vs. the new one?
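+        # shape (see build_binary below): {expr_id: [(alias, plan), ...]},
+        # one entry per child plan that produced a different alias for the
+        # same expression id; the analyzer later picks the nearest plan's alias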
+ self.conflicted_alias_map = ( + conflicted_alias_map if conflicted_alias_map else {} + ) # use to store conflicted alias map info self.session = session self.source_plan = source_plan self.is_ddl_on_temp_object = is_ddl_on_temp_object @@ -503,7 +509,23 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006 return copied_plan def add_aliases(self, to_add: Dict) -> None: + # intersect_keys = set(self.expr_to_alias.keys()).intersection(to_add.keys()) + conflicted = False + intersected_keys = set(to_add.keys()).intersection( + set(self.expr_to_alias.keys()) + ) + for key in intersected_keys: + if to_add[key] != self.expr_to_alias[key]: + conflicted = True + print( + f"Alias conflict for key {key}, to_add: {to_add[key]} and self.expr_to_alias: {self.expr_to_alias[key]}" + ) + + if conflicted: + print(f"before conflict, self.expr_to_alias: {self.expr_to_alias}") self.expr_to_alias = {**self.expr_to_alias, **to_add} + if conflicted: + print(f"after conflict, self.expr_to_alias: {self.expr_to_alias}") class SnowflakePlanBuilder: @@ -596,6 +618,17 @@ def build_binary( }.items() if k not in common_columns } + conflicted_alias_map = defaultdict(list) + for k in common_columns: + if select_left.expr_to_alias[k] == select_right.expr_to_alias[k]: + new_expr_to_alias[k] = select_left.expr_to_alias[k] + else: + conflicted_alias_map[k].append( + (select_left.expr_to_alias[k], select_left) + ) + conflicted_alias_map[k].append( + (select_right.expr_to_alias[k], select_right) + ) api_calls = [*select_left.api_calls, *select_right.api_calls] # Need to do a deduplication to avoid repeated query. @@ -643,7 +676,11 @@ def build_binary( api_calls=api_calls, session=self.session, referenced_ctes=referenced_ctes, + conflicted_alias_map=conflicted_alias_map, ) + # ret_plan.conflicted_alias_map = conflicted_alias_map + + # return ret_plan def query( self, diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 7095a4cf1b4..4840ec5fd10 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -186,8 +186,18 @@ _UNALIASED_REGEX = re.compile(f"""._[a-zA-Z0-9]{{{_NUM_PREFIX_DIGITS}}}_(.*)""") +distinct_id = 0 + + +def _get_next_distinct_id(): + global distinct_id + distinct_id += 1 + return distinct_id + + def _generate_prefix(prefix: str) -> str: - return f"{prefix}_{generate_random_alphanumeric(_NUM_PREFIX_DIGITS)}_" + return f"{prefix}_{_get_next_distinct_id()}_" + # return f"{prefix}_{generate_random_alphanumeric(_NUM_PREFIX_DIGITS)}_" def _get_unaliased(col_name: str) -> List[str]: @@ -275,12 +285,16 @@ def _disambiguate( ] ) + assert lhs_remapped._plan.children_plan_nodes[0] == lhs._plan + rhs_remapped = rhs.select( [ _alias_if_needed(rhs, name, rhs_prefix, rsuffix, common_col_names) for name in rhs_names ] ) + + assert rhs_remapped._plan.children_plan_nodes[0] == rhs._plan return lhs_remapped, rhs_remapped @@ -1205,7 +1219,11 @@ def select( analyzer=self._session._analyzer, ).select(names) ) - return self._with_plan(self._select_statement.select(names)) + + new_select = self._select_statement.select(names) + new_df = self._with_plan(new_select) + return new_df + # return self._with_plan(self._select_statement.select(names)) return self._with_plan(Project(names, join_plan or self._plan)) @@ -2789,6 +2807,7 @@ def _join_dataframes_internal( self, right, join_type, [], lsuffix=lsuffix, rsuffix=rsuffix ) join_condition_expr = join_exprs._expression if join_exprs is not None else None + join_condition_expr.extra_information = 
(lhs._plan, rhs._plan) match_condition_expr = ( match_condition._expression if match_condition is not None else None ) @@ -2800,15 +2819,28 @@ def _join_dataframes_internal( match_condition_expr, ) if self._select_statement: - return self._with_plan( - self._session._analyzer.create_select_statement( - from_=self._session._analyzer.create_select_snowflake_plan( - join_logical_plan, - analyzer=self._session._analyzer, - ), - analyzer=self._session._analyzer, - ) + analyzer = self._session._analyzer + select_snowflake_plan = analyzer.create_select_snowflake_plan( + join_logical_plan, + analyzer=analyzer, + ) + select_statement = analyzer.create_select_statement( + from_=select_snowflake_plan, + analyzer=analyzer, ) + new_df = self._with_plan(select_statement) + return new_df + + # if self._select_statement: + # return self._with_plan( + # self._session._analyzer.create_select_statement( + # from_=self._session._analyzer.create_select_snowflake_plan( + # join_logical_plan, + # analyzer=self._session._analyzer, + # ), + # analyzer=self._session._analyzer, + # ) + # ) return self._with_plan(join_logical_plan) def with_column( @@ -4179,7 +4211,7 @@ def _resolve(self, col_name: str) -> Union[Expression, NamedExpression]: normalized_col_name = quote_name(col_name) cols = list(filter(lambda attr: attr.name == normalized_col_name, self._output)) if len(cols) == 1: - return cols[0].with_name(normalized_col_name) + return cols[0].with_name(normalized_col_name, plan=self._plan) else: raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME( col_name From 2297bf392a2857b223dc6a695511febbece1e714 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Wed, 13 Nov 2024 09:27:50 -0800 Subject: [PATCH 3/3] v0 working with join issue --- .../snowpark/_internal/analyzer/analyzer.py | 33 ++++++++---- .../_internal/analyzer/select_statement.py | 2 + .../snowpark/_internal/compiler/utils.py | 2 +- src/snowflake/snowpark/dataframe.py | 52 ++++++++++++------- src/snowflake/snowpark/session.py | 37 +++++++++---- 5 files changed, 86 insertions(+), 40 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 332e8a8eae7..39891cee585 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -367,12 +367,19 @@ def analyze( return expr.sql if isinstance(expr, Attribute): - # TODO: which plan expr_to_alias to use? 
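+            # TODO resolved below: with the sql simplifier on, a Selectable
+            # exposes no children plan nodes to walk, so compare from_ clauses
+            # instead of tree depth; otherwise keep the nearest-plan search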
             expr_plan = expr.plan
+            # for the sql simplifier case, remember which from_ clause
+            # produced this attribute
+            expr_plan_source_plan = None
+            expr_plan_from = None
+            if expr_plan and isinstance(expr_plan.source_plan, SelectStatement):
+                expr_plan_source_plan = expr_plan.source_plan
+                expr_plan_from = expr_plan_source_plan.from_
+
             def traverse_child(root):
                 if root == expr_plan:
                     return 0
+                # this does not work for the sql simplifier case because a
+                # Selectable has no children plan nodes
                 for child in root.children_plan_nodes:
                     res = traverse_child(child)
                     if res == -1:
@@ -386,13 +393,22 @@ def traverse_child(root):
                 if expr.expr_id not in self.conflicted_maps_to_use:
                     name = expr.name
                 else:
-                    # find the plan that has the least depth
+                    # find the plan that has the least depth for the non-sql simplifier case
                     min_depth = float("inf")
                     for tmp_name, plan in self.conflicted_maps_to_use[expr.expr_id]:
-                        depth = traverse_child(plan)
-                        if depth < min_depth:
-                            min_depth = depth
-                            name = tmp_name
+                        if expr_plan_from:
+                            # sql simplifier case: comparing the from_ clauses is enough
+                            if expr_plan_from == plan.source_plan.from_:
+                                name = tmp_name
+                                break
+                        else:
+                            # non-sql simplifier case
+                            depth = traverse_child(plan)
+                            if depth < min_depth and depth != -1:
+                                min_depth = depth
+                                name = tmp_name
+                    if not name:
+                        raise RuntimeError("alias not found")
             return quote_name(name)
 
         if isinstance(expr, UnresolvedAttribute):
@@ -828,9 +844,7 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
 
         if isinstance(logical_plan, Selectable):
             # Selectable doesn't have children. It already has the expr_to_alias dict.
-            self.alias_maps_to_use = (
-                logical_plan.expr_to_alias.copy()
-            )  # logical_plan.expr_to_alias.copy()
+            self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
         else:
             if isinstance(logical_plan, Join):
                 print("break point")
@@ -893,6 +907,7 @@ def find_diff_values(*dicts):
         res = self.do_resolve_with_resolved_children(
             logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
         )
+        res.conflicted_alias_map.update(self.conflicted_maps_to_use)
         res.df_aliased_col_name_to_real_col_name.update(
             df_aliased_col_name_to_real_col_name
         )
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 4e9d97c8a9c..2427a7d744c 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -1113,6 +1113,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
             # there is no need to flatten the projection complexity since the child
             # select projection is already flattened with the current select.
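+            # two outcomes are asserted below: if the projection is flattened,
+            # new shares self's from_; otherwise select() wraps self in a new
+            # layer and new.from_.from_ is self's from_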
             new._merge_projection_complexity_with_subquery = False
+            assert self.from_ == new.from_
         else:
             new = SelectStatement(
                 projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer
             )
@@ -1124,6 +1125,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
                     self,
                 )
             )
+            assert self.from_ == new.from_.from_
         new.flatten_disabled = disable_next_level_flatten
         assert new.projection is not None
diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py
index eae9376c0b4..b0b98475a5f 100644
--- a/src/snowflake/snowpark/_internal/compiler/utils.py
+++ b/src/snowflake/snowpark/_internal/compiler/utils.py
@@ -405,7 +405,7 @@ def get_name(node: Optional[LogicalPlan]) -> str:
     sql_size = len(sql_text)
     sql_preview = sql_text[:50]
 
-    return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}"
+    return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}, {node.expr_to_alias}"
 
 g = graphviz.Graph(format="png")
diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py
index 4840ec5fd10..c23472e94e9 100644
--- a/src/snowflake/snowpark/dataframe.py
+++ b/src/snowflake/snowpark/dataframe.py
@@ -272,29 +272,36 @@ def _disambiguate(
     lhs_prefix = _generate_prefix("l") if not suffix_provided else ""
     rhs_prefix = _generate_prefix("r") if not suffix_provided else ""
 
-    lhs_remapped = lhs.select(
-        [
-            _alias_if_needed(
-                lhs,
-                name,
-                lhs_prefix,
-                lsuffix,
-                [] if isinstance(join_type, (LeftSemi, LeftAnti)) else common_col_names,
-            )
-            for name in lhs_names
-        ]
-    )
+    aliased_cols = [
+        _alias_if_needed(
+            lhs,
+            name,
+            lhs_prefix,
+            lsuffix,
+            [] if isinstance(join_type, (LeftSemi, LeftAnti)) else common_col_names,
+        )
+        for name in lhs_names
+    ]
+    lhs_remapped = lhs.select(aliased_cols)
 
-    assert lhs_remapped._plan.children_plan_nodes[0] == lhs._plan
+    if lhs.session.sql_simplifier_enabled:
+        assert lhs_remapped._select_statement.from_ == lhs._select_statement.from_
+    else:
+        assert lhs_remapped._plan.children_plan_nodes[0] == lhs._plan
 
-    rhs_remapped = rhs.select(
-        [
-            _alias_if_needed(rhs, name, rhs_prefix, rsuffix, common_col_names)
-            for name in rhs_names
-        ]
-    )
+    aliased_cols = [
+        _alias_if_needed(rhs, name, rhs_prefix, rsuffix, common_col_names)
+        for name in rhs_names
+    ]
+    rhs_remapped = rhs.select(aliased_cols)
 
-    assert rhs_remapped._plan.children_plan_nodes[0] == rhs._plan
+    if rhs.session.sql_simplifier_enabled:
+        # case 1: the select was flattened; case 2: it was not, so one more
+        # from_ layer was added
+        assert rhs_remapped._select_statement.from_ == rhs._select_statement.from_ or (
+            rhs_remapped._select_statement.from_.from_ == rhs._select_statement.from_
+        )
+    else:
+        assert rhs_remapped._plan.children_plan_nodes[0] == rhs._plan
 
     return lhs_remapped, rhs_remapped
@@ -538,6 +545,7 @@ def __init__(
         is_cached: bool = False,
     ) -> None:
         self._session = session
+        # SelectStatement will create a new SnowflakePlan during resolve if one is not already bound to it
         self._plan = self._session._analyzer.resolve(plan)
         if isinstance(plan, (SelectStatement, MockSelectStatement)):
             self._select_statement = plan
@@ -1221,6 +1229,10 @@ def select(
             )
 
             new_select = self._select_statement.select(names)
+            # case 1: the new select was flattened; case 2: it was not
+            assert (self._select_statement.from_ == new_select.from_) or (
+                self._select_statement.from_ == new_select.from_.from_
+            )
             new_df = self._with_plan(new_select)
             return new_df
             # return self._with_plan(self._select_statement.select(names))
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index fb1c8dc5469..2308edf0254 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -3032,16 +3032,33 @@ def convert_row_to_list(
             project_columns.append(column(name))
 
         if self.sql_simplifier_enabled:
-            df = DataFrame(
-                self,
-                self._analyzer.create_select_statement(
-                    from_=self._analyzer.create_select_snowflake_plan(
-                        SnowflakeValues(attrs, converted, schema_query=schema_query),
-                        analyzer=self._analyzer,
-                    ),
-                    analyzer=self._analyzer,
-                ),
-            ).select(project_columns)
+            analyzer = self._analyzer
+            snowflake_values = SnowflakeValues(
+                attrs, converted, schema_query=schema_query
+            )
+            select_snowflake_plan = analyzer.create_select_snowflake_plan(
+                snowflake_plan=snowflake_values,
+                analyzer=analyzer,
+            )
+            select_statement = analyzer.create_select_statement(
+                from_=select_snowflake_plan, analyzer=analyzer
+            )
+            # accessing select_statement.snowflake_plan here would create a new snowflake plan
+            df = DataFrame(self, select_statement)
+            assert df._plan == select_statement.snowflake_plan
+            assert df._select_statement == select_statement
+            # select() returns a new DataFrame; keep the result instead of dropping it
+            df = df.select(project_columns)
+
+            # df = DataFrame(
+            #     self,
+            #     self._analyzer.create_select_statement(
+            #         from_=self._analyzer.create_select_snowflake_plan(
+            #             snowflake_values,
+            #             analyzer=self._analyzer,
+            #         ),
+            #         analyzer=self._analyzer,
+            #     ),
+            # ).select(project_columns)
         else:
             df = DataFrame(
                 self, SnowflakeValues(attrs, converted, schema_query=schema_query)