Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1748403: Use Aggregate.aggregate_expressions to infer quoted identifiers #2526

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def infer_metadata(
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
from snowflake.snowpark._internal.analyzer.unary_plan_node import (
Aggregate,
Filter,
Project,
Sample,
Expand All @@ -97,6 +98,13 @@ def infer_metadata(
# When source_plan is a SnowflakeValues, metadata is already defined locally
elif isinstance(source_plan, SnowflakeValues):
attributes = source_plan.output
# When source_plan is Aggregate or Project, we already have quoted_identifiers
elif isinstance(source_plan, Aggregate):
quoted_identifiers = infer_quoted_identifiers_from_expressions(
source_plan.aggregate_expressions, # type: ignore
analyzer,
df_aliased_col_name_to_real_col_name,
)
elif isinstance(source_plan, Project):
quoted_identifiers = infer_quoted_identifiers_from_expressions(
source_plan.project_list, # type: ignore
Expand Down
37 changes: 36 additions & 1 deletion tests/integ/test_reduce_describe_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@
TempObjectType,
random_name_for_temp_object,
)
from snowflake.snowpark.functions import col, count, lit, seq2, table_function
from snowflake.snowpark.functions import (
avg,
col,
count,
lit,
max as max_,
min as min_,
seq2,
table_function,
)
from snowflake.snowpark.session import (
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED,
Session,
Expand Down Expand Up @@ -157,6 +166,15 @@ def setup(request, session):
),
]

# Parameter table for aggregate-style DataFrame operations. Each entry is a
# pair of (callable taking a DataFrame and returning the transformed
# DataFrame, list of quoted identifiers expected for the result).
agg_df_ops_expected_quoted_identifiers = [
    (lambda df: df.agg(avg("a").as_("a"), count("b")), ['"A"', '"COUNT(B)"']),
    (lambda df: df.agg(avg("a").as_('"a"'), count("b")).select('"a"'), ['"a"']),
    (lambda df: df.group_by("a").agg(avg("b")), ['"A"', '"AVG(B)"']),
    (lambda df: df.rollup("a").agg(min_("b")), ['"A"', '"MIN(B)"']),
    (lambda df: df.cube("a").agg(max_("b")), ['"A"', '"MAX(B)"']),
    (lambda df: df.distinct(), ['"A"', '"B"']),
]


def check_attributes_equality(attrs1: List[Attribute], attrs2: List[Attribute]) -> None:
for attr1, attr2 in zip(attrs1, attrs2):
Expand Down Expand Up @@ -261,6 +279,23 @@ def test_snowflake_values(session):
assert df._plan.quoted_identifiers == expected_quoted_identifiers


@pytest.mark.parametrize(
    "action,expected_quoted_identifiers",
    agg_df_ops_expected_quoted_identifiers,
)
def test_aggregate(session, action, expected_quoted_identifiers):
    """Check quoted identifiers of a plan produced by an aggregate operation.

    A two-column DataFrame (``a``, ``b``) is built and passed through
    ``action``. With ``session.reduce_describe_query_enabled`` set, the plan
    metadata is expected to already carry the quoted identifiers, under a
    ``SqlCounter`` configured for zero describe queries. Otherwise the
    metadata's quoted identifiers are expected to be ``None`` and the check
    runs under a ``SqlCounter`` configured for one describe query.
    """
    source_df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
    result_df = action(source_df)

    if not session.reduce_describe_query_enabled:
        # Nothing is inferred locally, so the identifiers are only available
        # through the plan itself.
        with SqlCounter(query_count=0, describe_count=1):
            assert result_df._plan._metadata.quoted_identifiers is None
            assert result_df._plan.quoted_identifiers == expected_quoted_identifiers
        return

    with SqlCounter(query_count=0, describe_count=0):
        assert (
            result_df._plan._metadata.quoted_identifiers
            == expected_quoted_identifiers
        )
        assert result_df._plan.quoted_identifiers == expected_quoted_identifiers


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP")
def test_reduce_describe_query_enabled_on_session(db_parameters):
with Session.builder.configs(db_parameters).create() as new_session:
Expand Down
Loading