Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aling/dimond shaped join #2610

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
self.generated_alias_maps = {}
self.subquery_plans = []
self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None
self.conflicted_maps_to_use = None

def analyze(
self,
Expand Down Expand Up @@ -366,8 +367,48 @@ def analyze(
return expr.sql

if isinstance(expr, Attribute):
expr_plan = expr.plan

# for sql simplifier case
expr_plan_source_plan = None
expr_plan_from = None
if expr_plan and isinstance(expr_plan.source_plan, SelectStatement):
expr_plan_source_plan = expr_plan.source_plan
expr_plan_from = expr_plan_source_plan.from_

def traverse_child(root):
if root == expr_plan:
return 0
# this does not work for sql simplifier the case because selectable has no children nodes
for child in root.children_plan_nodes:
res = traverse_child(child)
if res == -1:
continue
return res + 1
return -1

assert self.alias_maps_to_use is not None
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
name = self.alias_maps_to_use.get(expr.expr_id)
if not name:
if expr.expr_id not in self.conflicted_maps_to_use:
name = expr.name
else:
# find the plan that has the least depth for non-sql simplifier case
min_depth = float("inf")
for tmp_name, plan in self.conflicted_maps_to_use[expr.expr_id]:
if expr_plan_from:
# sql simplifier case, we just need to compare the from_ case
if expr_plan_from == plan.source_plan.from_:
name = tmp_name
break
else:
# non sql simplifier case
depth = traverse_child(plan)
if depth < min_depth and depth != -1:
min_depth = depth
name = tmp_name
if not name:
raise RuntimeError("alias is not found")
return quote_name(name)

if isinstance(expr, UnresolvedAttribute):
Expand Down Expand Up @@ -675,6 +716,7 @@ def binary_operator_extractor(
df_aliased_col_name_to_real_col_name,
parse_local_name=False,
) -> str:

if self.session.eliminate_numeric_sql_value_cast_enabled:
left_sql_expr = self.to_sql_try_avoid_cast(
expr.left, df_aliased_col_name_to_real_col_name, parse_local_name
Expand Down Expand Up @@ -804,6 +846,9 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
# Selectable doesn't have children. It already has the expr_to_alias dict.
self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
else:
if isinstance(logical_plan, Join):
print("break point")

use_maps = {}
# get counts of expr_to_alias keys
counts = Counter()
Expand All @@ -819,11 +864,50 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
{p: q for p, q in v.expr_to_alias.items() if counts[p] < 2}
)

def find_diff_values(*dicts):
# Store values by key for comparison
value_map = defaultdict(set)
# Store duplicate keys with differing values
diff_keys = {}

# Gather all values for each key from all dictionaries
for d in dicts:
for key, value in d.items():
value_map[key].add(value)

# Identify keys with differing values
for key, values in value_map.items():
if (
len(values) > 1
): # Only consider keys with more than one unique value
diff_keys[key] = list(values)

return set(diff_keys.keys())

dup_keys_with_different_values = find_diff_values(
*[v.expr_to_alias for v in resolved_children.values()]
)
self.conflicted_maps_to_use = defaultdict(list)
# if the logic plan has a conflict alias map already and there is no children
if (
getattr(logical_plan, "conflicted_alias_map", {})
and not resolved_children
):
self.conflicted_maps_to_use = logical_plan.conflicted_alias_map
else:
for plan in resolved_children.values():
if plan.conflicted_alias_map:
self.conflicted_maps_to_use.update(plan.conflicted_alias_map)
for k, v in plan.expr_to_alias.items():
if k in dup_keys_with_different_values:
self.conflicted_maps_to_use[k].append((v, plan))

self.alias_maps_to_use = use_maps

res = self.do_resolve_with_resolved_children(
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
)
res.conflicted_alias_map.update(self.conflicted_maps_to_use)
res.df_aliased_col_name_to_real_col_name.update(
df_aliased_col_name_to_real_col_name
)
Expand Down
17 changes: 14 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,17 @@ def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
class NamedExpression:
name: str
_expr_id: Optional[uuid.UUID] = None
id = 0

@staticmethod
def get_next_id():
NamedExpression.id += 1
return NamedExpression.id

@property
def expr_id(self) -> uuid.UUID:
if not self._expr_id:
self._expr_id = uuid.uuid4()
self._expr_id = NamedExpression.get_next_id()
return self._expr_id

def __copy__(self):
Expand Down Expand Up @@ -222,20 +228,25 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:


class Attribute(Expression, NamedExpression):
def __init__(self, name: str, datatype: DataType, nullable: bool = True) -> None:
def __init__(
self, name: str, datatype: DataType, nullable: bool = True, plan=None
) -> None:
super().__init__()
self.name = name
self.datatype: DataType = datatype
self.nullable = nullable
self.plan = plan

def with_name(self, new_name: str) -> "Attribute":
def with_name(self, new_name: str, plan) -> "Attribute":
if self.name == new_name:
self.plan = plan
return self
else:
return Attribute(
snowflake.snowpark._internal.utils.quote_name(new_name),
self.datatype,
self.nullable,
plan=plan,
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
# there is no need to flatten the projection complexity since the child
# select projection is already flattened with the current select.
new._merge_projection_complexity_with_subquery = False
assert self.from_ == new.from_
else:
new = SelectStatement(
projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer
Expand All @@ -1124,6 +1125,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
self,
)
)
assert self.from_ == new.from_.from_

new.flatten_disabled = disable_next_level_flatten
assert new.projection is not None
Expand Down
40 changes: 39 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,18 @@ def __init__(
referenced_ctes: Optional[Dict[WithQueryBlock, int]] = None,
*,
session: "snowflake.snowpark.session.Session",
conflicted_alias_map=None,
) -> None:
super().__init__()
self.queries = queries
self.schema_query = schema_query
self.post_actions = post_actions if post_actions else []
self.expr_to_alias = expr_to_alias if expr_to_alias else {}
self.expr_to_alias = (
expr_to_alias if expr_to_alias else {}
) # use to stored old alias map info, TODO: when use old one and when to use new one?
self.conflicted_alias_map = (
conflicted_alias_map if conflicted_alias_map else {}
) # use to store conflicted alias map info
self.session = session
self.source_plan = source_plan
self.is_ddl_on_temp_object = is_ddl_on_temp_object
Expand Down Expand Up @@ -503,7 +509,23 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
return copied_plan

def add_aliases(self, to_add: Dict) -> None:
# intersect_keys = set(self.expr_to_alias.keys()).intersection(to_add.keys())
conflicted = False
intersected_keys = set(to_add.keys()).intersection(
set(self.expr_to_alias.keys())
)
for key in intersected_keys:
if to_add[key] != self.expr_to_alias[key]:
conflicted = True
print(
f"Alias conflict for key {key}, to_add: {to_add[key]} and self.expr_to_alias: {self.expr_to_alias[key]}"
)

if conflicted:
print(f"before conflict, self.expr_to_alias: {self.expr_to_alias}")
self.expr_to_alias = {**self.expr_to_alias, **to_add}
if conflicted:
print(f"after conflict, self.expr_to_alias: {self.expr_to_alias}")


class SnowflakePlanBuilder:
Expand Down Expand Up @@ -587,6 +609,7 @@ def build_binary(
common_columns = set(select_left.expr_to_alias.keys()).intersection(
select_right.expr_to_alias.keys()
)
# TODO join: related to alias
new_expr_to_alias = {
k: v
for k, v in {
Expand All @@ -595,6 +618,17 @@ def build_binary(
}.items()
if k not in common_columns
}
conflicted_alias_map = defaultdict(list)
for k in common_columns:
if select_left.expr_to_alias[k] == select_right.expr_to_alias[k]:
new_expr_to_alias[k] = select_left.expr_to_alias[k]
else:
conflicted_alias_map[k].append(
(select_left.expr_to_alias[k], select_left)
)
conflicted_alias_map[k].append(
(select_right.expr_to_alias[k], select_right)
)
api_calls = [*select_left.api_calls, *select_right.api_calls]

# Need to do a deduplication to avoid repeated query.
Expand Down Expand Up @@ -642,7 +676,11 @@ def build_binary(
api_calls=api_calls,
session=self.session,
referenced_ctes=referenced_ctes,
conflicted_alias_map=conflicted_alias_map,
)
# ret_plan.conflicted_alias_map = conflicted_alias_map

# return ret_plan

def query(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def get_name(node: Optional[LogicalPlan]) -> str:
sql_size = len(sql_text)
sql_preview = sql_text[:50]

return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}"
return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}, {node.expr_to_alias}"

g = graphviz.Graph(format="png")

Expand Down
Loading
Loading