From ffbbf7da9c69e4e2895e131536a15fe0f89c93be Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Mon, 28 Oct 2024 15:25:26 -0700 Subject: [PATCH 1/2] add --- .../_internal/analyzer/metadata_utils.py | 8 ++++ tests/integ/test_reduce_describe_query.py | 37 ++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 9d2f45d0770..6edbde2d6b3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -77,6 +77,7 @@ def infer_metadata( ) from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.analyzer.unary_plan_node import ( + Aggregate, Filter, Project, Sample, @@ -97,6 +98,13 @@ def infer_metadata( # When source_plan is a SnowflakeValues, metadata is already defined locally elif isinstance(source_plan, SnowflakeValues): attributes = source_plan.output + # When source_plan is Aggregate or Project, we already have quoted_identifiers + elif isinstance(source_plan, Aggregate): + quoted_identifiers = infer_quoted_identifiers_from_expressions( + source_plan.aggregate_expressions, # type: ignore + analyzer, + df_aliased_col_name_to_real_col_name, + ) elif isinstance(source_plan, Project): quoted_identifiers = infer_quoted_identifiers_from_expressions( source_plan.project_list, # type: ignore diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py index 18693488bb5..2cb761fd9a5 100644 --- a/tests/integ/test_reduce_describe_query.py +++ b/tests/integ/test_reduce_describe_query.py @@ -14,7 +14,16 @@ TempObjectType, random_name_for_temp_object, ) -from snowflake.snowpark.functions import col, count, lit, seq2, table_function +from snowflake.snowpark.functions import ( + avg, + col, + count, + lit, + max as max_, + min as min_, + seq2, + table_function, +) from snowflake.snowpark.session import ( _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, Session, @@ -157,6 +166,15 @@ def setup(request, session): ), ] +agg_df_ops_expected_quoted_identifiers = [ + (lambda df: df.agg(avg("a").as_("a"), count("b")), ['"A"', '"COUNT(B)"']), + (lambda df: df.agg(avg("a").as_("a"), count("b")).select("a"), ['"A"']), + (lambda df: df.group_by("a").agg(avg("b")), ['"A"', '"AVG(B)"']), + (lambda df: df.rollup("a").agg(min_("b")), ['"A"', '"MIN(B)"']), + (lambda df: df.cube("a").agg(max_("b")), ['"A"', '"MAX(B)"']), + (lambda df: df.distinct(), ['"A"', '"B"']), +] + def check_attributes_equality(attrs1: List[Attribute], attrs2: List[Attribute]) -> None: for attr1, attr2 in zip(attrs1, attrs2): @@ -261,6 +279,23 @@ def test_snowflake_values(session): assert df._plan.quoted_identifiers == expected_quoted_identifiers +@pytest.mark.parametrize( + "action,expected_quoted_identifiers", + agg_df_ops_expected_quoted_identifiers, +) +def test_aggregate(session, action, expected_quoted_identifiers): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df = action(df) + if session.reduce_describe_query_enabled: + with SqlCounter(query_count=0, describe_count=0): + assert df._plan._metadata.quoted_identifiers == expected_quoted_identifiers + assert df._plan.quoted_identifiers == expected_quoted_identifiers + else: + with SqlCounter(query_count=0, describe_count=1): + assert df._plan._metadata.quoted_identifiers is None + assert df._plan.quoted_identifiers == expected_quoted_identifiers + + @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP") def test_reduce_describe_query_enabled_on_session(db_parameters): with Session.builder.configs(db_parameters).create() as new_session: From 52c2f0d7285c0aa97c14bbdc4ec8126873395dd7 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Mon, 28 Oct 2024 15:29:43 -0700 Subject: [PATCH 2/2] update --- tests/integ/test_reduce_describe_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py index 2cb761fd9a5..2b767de57da 100644 --- a/tests/integ/test_reduce_describe_query.py +++ b/tests/integ/test_reduce_describe_query.py @@ -168,7 +168,7 @@ def setup(request, session): agg_df_ops_expected_quoted_identifiers = [ (lambda df: df.agg(avg("a").as_("a"), count("b")), ['"A"', '"COUNT(B)"']), - (lambda df: df.agg(avg("a").as_("a"), count("b")).select("a"), ['"A"']), + (lambda df: df.agg(avg("a").as_('"a"'), count("b")).select('"a"'), ['"a"']), (lambda df: df.group_by("a").agg(avg("b")), ['"A"', '"AVG(B)"']), (lambda df: df.rollup("a").agg(min_("b")), ['"A"', '"MIN(B)"']), (lambda df: df.cube("a").agg(max_("b")), ['"A"', '"MAX(B)"']),