Skip to content

Commit

Permalink
SNOW-1438001: Add support for list values in Series.str.len (#2594)
Browse files Browse the repository at this point in the history
<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1438001

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [ ] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.

   Add support for list values in Series.str.len.
  • Loading branch information
sfc-gh-helmeleegy authored Nov 12, 2024
1 parent ebc749d commit 4c1be6a
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
- Added support for `pd.read_html` (Uses native pandas for processing).
- Added support for `pd.read_xml` (Uses native pandas for processing).
- Added support for aggregation functions `"size"` and `len` in `GroupBy.aggregate`, `DataFrame.aggregate`, and `Series.aggregate`.
- Added support for list values in `Series.str.len`.

#### Bug Fixes

Expand Down
3 changes: 2 additions & 1 deletion docs/source/modin/supported/series_str_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ the method in the left column.
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``join`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``len`` | Y | |
| ``len`` | P | Only string and list data values are supported. |
| | | All column values must be of the same type. |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``ljust`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
from snowflake.snowpark.functions import (
abs as abs_,
array_construct,
array_size,
bround,
builtin,
cast,
Expand Down Expand Up @@ -16553,12 +16554,23 @@ def str_len(self, **kwargs: Any) -> "SnowflakeQueryCompiler":
-------
SnowflakeQueryCompiler representing result of the string operation.
"""
# TODO SNOW-1438001: Handle dict, list, and tuple values for Series.str.len().
return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(
lambda col: self._replace_non_str(col, length(col))
# TODO SNOW-1438001: Handle dict, and tuple values for Series.str.len().
col = self._modin_frame.data_column_snowflake_quoted_identifiers[0]
if isinstance(
self._modin_frame.quoted_identifier_to_snowflake_type([col]).get(col),
ArrayType,
):
return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(
lambda col: array_size(col)
)
)
else:
return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(
lambda col: self._replace_non_str(col, length(col))
)
)
)

def str_ljust(self, width: int, fillchar: str = " ") -> None:
ErrorMessage.method_not_implemented_error("ljust", "Series.str")
Expand Down
52 changes: 51 additions & 1 deletion tests/integ/modin/series/test_str_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import pandas as native_pd
import pytest

from snowflake.snowpark._internal.utils import TempObjectType
import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.utils import eval_snowpark_pandas_result
from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
from tests.integ.utils.sql_counter import sql_count_checker

TEST_DATA = [
Expand Down Expand Up @@ -449,6 +450,55 @@ def test_str_len():
eval_snowpark_pandas_result(snow_ser, native_ser, lambda ser: ser.str.len())


@sql_count_checker(query_count=1)
def test_str_len_list():
native_ser = native_pd.Series([["a", "b"], ["c", "d", None], None, []])
snow_ser = pd.Series(native_ser)
eval_snowpark_pandas_result(snow_ser, native_ser, lambda ser: ser.str.len())


@sql_count_checker(query_count=9, udf_count=1)
def test_str_len_list_coin_base(session):
from tests.utils import Utils

table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
Utils.create_table(
session, table_name, "SHARED_CARD_USERS array", is_temporary=True
)
session.sql(
f"""insert into {table_name} (SHARED_CARD_USERS) SELECT PARSE_JSON('["Apple", "Pear", "Cabbage"]')"""
).collect()
session.sql(f"insert into {table_name} values (NULL)").collect()

df = pd.read_snowflake(table_name)

def compute_num_shared_card_users(x):
"""
Helper function to compute the number of shared card users
Input:
- x: the array with the users
Output: Number of shared card users
"""
if x:
return len(x)
else:
return 0

# The following two methods for computing the final result should be identical.

# The first one uses `Series.str.len` followed by `Series.fillna`.
str_len_res = df["SHARED_CARD_USERS"].str.len().fillna(0)

# The second one uses `Series.apply` and a user defined function.
apply_res = df["SHARED_CARD_USERS"].apply(
lambda x: compute_num_shared_card_users(x)
)

assert_series_equal(str_len_res, apply_res, check_dtype=False)


@pytest.mark.parametrize(
"items",
[
Expand Down

0 comments on commit 4c1be6a

Please sign in to comment.