From 4c1be6ad26d2cdf872a31c5b61f0cea8debec817 Mon Sep 17 00:00:00 2001 From: Hazem Elmeleegy Date: Mon, 11 Nov 2024 16:49:09 -0800 Subject: [PATCH] SNOW-1438001: Add support for list values in Series.str.len (#2594) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1438001 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k) 3. Please describe how your code solves the related issue. Add support for list values in Series.str.len. --- CHANGELOG.md | 1 + .../modin/supported/series_str_supported.rst | 3 +- .../compiler/snowflake_query_compiler.py | 22 ++++++-- tests/integ/modin/series/test_str_accessor.py | 52 ++++++++++++++++++- 4 files changed, 71 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 34efb8bea39..47204d74d1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ - Added support for `pd.read_html` (Uses native pandas for processing). - Added support for `pd.read_xml` (Uses native pandas for processing). - Added support for aggregation functions `"size"` and `len` in `GroupBy.aggregate`, `DataFrame.aggregate`, and `Series.aggregate`. +- Added support for list values in `Series.str.len`. #### Bug Fixes diff --git a/docs/source/modin/supported/series_str_supported.rst b/docs/source/modin/supported/series_str_supported.rst index ff993344ba2..1ecc6f5f1d5 100644 --- a/docs/source/modin/supported/series_str_supported.rst +++ b/docs/source/modin/supported/series_str_supported.rst @@ -76,7 +76,8 @@ the method in the left column. +-----------------------------+---------------------------------+----------------------------------------------------+ | ``join`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``len`` | Y | | +| ``len`` | P | Only string and list data values are supported. | +| | | All column values must be of the same type. | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``ljust`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index c4df86cf793..bb5badae356 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -85,6 +85,7 @@ from snowflake.snowpark.functions import ( abs as abs_, array_construct, + array_size, bround, builtin, cast, @@ -16553,12 +16554,23 @@ def str_len(self, **kwargs: Any) -> "SnowflakeQueryCompiler": ------- SnowflakeQueryCompiler representing result of the string operation. """ - # TODO SNOW-1438001: Handle dict, list, and tuple values for Series.str.len(). - return SnowflakeQueryCompiler( - self._modin_frame.apply_snowpark_function_to_columns( - lambda col: self._replace_non_str(col, length(col)) + # TODO SNOW-1438001: Handle dict, and tuple values for Series.str.len(). + col = self._modin_frame.data_column_snowflake_quoted_identifiers[0] + if isinstance( + self._modin_frame.quoted_identifier_to_snowflake_type([col]).get(col), + ArrayType, + ): + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns( + lambda col: array_size(col) + ) + ) + else: + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns( + lambda col: self._replace_non_str(col, length(col)) + ) ) - ) def str_ljust(self, width: int, fillchar: str = " ") -> None: ErrorMessage.method_not_implemented_error("ljust", "Series.str") diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py index 20fb46dd81f..23947af463d 100644 --- a/tests/integ/modin/series/test_str_accessor.py +++ b/tests/integ/modin/series/test_str_accessor.py @@ -10,8 +10,9 @@ import pandas as native_pd import pytest +from snowflake.snowpark._internal.utils import TempObjectType import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.utils import eval_snowpark_pandas_result +from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result from tests.integ.utils.sql_counter import sql_count_checker TEST_DATA = [ @@ -449,6 +450,55 @@ def test_str_len(): eval_snowpark_pandas_result(snow_ser, native_ser, lambda ser: ser.str.len()) +@sql_count_checker(query_count=1) +def test_str_len_list(): + native_ser = native_pd.Series([["a", "b"], ["c", "d", None], None, []]) + snow_ser = pd.Series(native_ser) + eval_snowpark_pandas_result(snow_ser, native_ser, lambda ser: ser.str.len()) + + +@sql_count_checker(query_count=9, udf_count=1) +def test_str_len_list_coin_base(session): + from tests.utils import Utils + + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + Utils.create_table( + session, table_name, "SHARED_CARD_USERS array", is_temporary=True + ) + session.sql( + f"""insert into {table_name} (SHARED_CARD_USERS) SELECT PARSE_JSON('["Apple", "Pear", "Cabbage"]')""" + ).collect() + session.sql(f"insert into {table_name} values (NULL)").collect() + + df = pd.read_snowflake(table_name) + + def compute_num_shared_card_users(x): + """ + Helper function to compute the number of shared card users + + Input: + - x: the array with the users + + Output: Number of shared card users + """ + if x: + return len(x) + else: + return 0 + + # The following two methods for computing the final result should be identical. + + # The first one uses `Series.str.len` followed by `Series.fillna`. + str_len_res = df["SHARED_CARD_USERS"].str.len().fillna(0) + + # The second one uses `Series.apply` and a user defined function. + apply_res = df["SHARED_CARD_USERS"].apply( + lambda x: compute_num_shared_card_users(x) + ) + + assert_series_equal(str_len_res, apply_res, check_dtype=False) + + @pytest.mark.parametrize( "items", [