Skip to content

Commit a5c807b

Browse files
authored
SNOW-1480718: Support Series.str.translate (#1776)
1 parent 9372edf commit a5c807b

File tree

7 files changed

+265
-5
lines changed

7 files changed

+265
-5
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
- Added distributed tracing using open telemetry APIs for table stored procedure function in `DataFrame`:
1010
- _execute_and_get_query_id
1111

12+
### Snowpark pandas API Updates
13+
14+
#### New Features
15+
- Added partial support for `Series.str.translate` where the values in the `table` are single-codepoint strings.
16+
1217
## 1.19.0 (2024-06-25)
1318

1419
### Snowpark Python API Updates

docs/source/modin/series.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,5 @@ Series
285285
Series.str.split
286286
Series.str.startswith
287287
Series.str.strip
288+
Series.str.translate
288289
Series.str.upper

docs/source/modin/supported/series_str_supported.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ the method in the left column.
129129
+-----------------------------+---------------------------------+----------------------------------------------------+
130130
| ``title`` | Y | |
131131
+-----------------------------+---------------------------------+----------------------------------------------------+
132-
| ``translate`` | N | |
132+
| ``translate`` | P | ``N`` if any value in `table` has multiple |
133+
| | | characters. |
133134
+-----------------------------+---------------------------------+----------------------------------------------------+
134135
| ``upper`` | Y | |
135136
+-----------------------------+---------------------------------+----------------------------------------------------+

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
timestamp_ntz_from_parts,
126126
to_date,
127127
to_variant,
128+
translate,
128129
trim,
129130
uniform,
130131
upper,
@@ -13979,8 +13980,66 @@ def str_rstrip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompil
1397913980
def str_swapcase(self) -> None:
1398013981
ErrorMessage.method_not_implemented_error("swapcase", "Series.str")
1398113982

13982-
def str_translate(self, table: dict) -> "SnowflakeQueryCompiler":
    """
    Map all characters in the string through the given mapping table.

    Equivalent to standard :meth:`str.translate`.

    Parameters
    ----------
    table : dict
        Table is a mapping of Unicode ordinals to Unicode ordinals, strings, or
        None. Unmapped characters are left untouched.
        Characters mapped to None are deleted. :meth:`str.maketrans` is a
        helper function for making translation tables.

    Returns
    -------
    SnowflakeQueryCompiler representing results of the string operation.

    Raises
    ------
    ValueError
        If any string key in `table` is not exactly one character long
        (mirrors the error raised by :meth:`str.maketrans`).
    NotImplementedError
        If any value in `table` is a string longer than one codepoint, since
        Snowflake TRANSLATE only supports 1-to-1 character mappings.
    """
    # Snowflake SQL TRANSLATE:
    # TRANSLATE(<subject>, <sourceAlphabet>, <targetAlphabet>)
    # Characters in the <sourceAlphabet> string are mapped to the corresponding entry in <targetAlphabet>.
    # If <sourceAlphabet> is longer than <targetAlphabet>, then the trailing characters of <sourceAlphabet>
    # are removed from the input string.
    #
    # Because TRANSLATE only supports 1-to-1 character mappings, any entries with multi-character
    # values must be handled by REPLACE instead. Keys that are not exactly one character long are
    # always invalid.
    single_char_pairs = {}
    none_keys = set()
    for key, value in table.items():
        # Treat integers as unicode codepoints
        if isinstance(key, int):
            key = chr(key)
        if isinstance(value, int):
            value = chr(value)
        if len(key) != 1:
            # Mimic error from str.maketrans
            raise ValueError(
                f"Invalid mapping key '{key}'. String keys in translate table must be of length 1."
            )
        if value is not None and len(value) > 1:
            raise NotImplementedError(
                f"Invalid mapping value '{value}' for key '{key}'. Snowpark pandas currently only "
                "supports unicode ordinals or 1-codepoint strings as values in str.translate mappings. "
                "Consider using Series.str.replace to replace multiple characters."
            )
        if value is None or len(value) == 0:
            # Keys mapped to None (or the empty string) are deleted from the input.
            none_keys.add(key)
        else:
            single_char_pairs[key] = value
    # Deleted keys must come last: TRANSLATE drops source characters that have no
    # corresponding entry in the (shorter) target alphabet.
    source_alphabet = "".join(single_char_pairs.keys()) + "".join(none_keys)
    target_alphabet = "".join(single_char_pairs.values())
    return SnowflakeQueryCompiler(
        self._modin_frame.apply_snowpark_function_to_data_columns(
            lambda col_name: translate(
                col(col_name),
                pandas_lit(source_alphabet),
                pandas_lit(target_alphabet),
            )
        )
    )
1398414043

1398514044
def str_wrap(self, width: int, **kwargs: Any) -> None:
1398614045
ErrorMessage.method_not_implemented_error("wrap", "Series.str")

src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,74 @@ def normalize():
960960
pass
961961

962962
def translate():
    """
    Map all characters in the string through the given mapping table.

    Equivalent to standard :meth:`str.translate`.

    Parameters
    ----------
    table : dict
        Table is a mapping of Unicode ordinals to Unicode ordinals, strings, or
        None. Unmapped characters are left untouched.
        Characters mapped to None are deleted. :meth:`str.maketrans` is a
        helper function for making translation tables.

    Returns
    -------
    Series

    Examples
    --------
    >>> ser = pd.Series(["El niño", "Françoise"])
    >>> mytable = str.maketrans({'ñ': 'n', 'ç': 'c'})
    >>> ser.str.translate(mytable) # doctest: +NORMALIZE_WHITESPACE
    0      El nino
    1    Francoise
    dtype: object

    Notes
    -----
    Snowpark pandas internally uses the Snowflake SQL `TRANSLATE` function to implement this
    operation. Since this function uses strings instead of unicode codepoints, it will accept
    mappings containing string keys that would be invalid in pandas.

    The following example fails silently in vanilla pandas without `str.maketrans`:

    >>> import pandas
    >>> pandas.Series("aaa").str.translate({"a": "A"})
    0    aaa
    dtype: object
    >>> pandas.Series("aaa").str.translate(str.maketrans({"a": "A"}))
    0    AAA
    dtype: object

    The same code works in Snowpark pandas without `str.maketrans`:

    >>> pd.Series("aaa").str.translate({"a": "A"})
    0    AAA
    dtype: object
    >>> pd.Series("aaa").str.translate(str.maketrans({"a": "A"}))
    0    AAA
    dtype: object

    Furthermore, due to restrictions in the underlying SQL, Snowpark pandas currently requires
    all string values to be one unicode codepoint in length. To create replacements of multiple
    characters, chain calls to `Series.str.replace` as needed.

    Vanilla pandas code:

    >>> import pandas
    >>> pandas.Series("ab").str.translate(str.maketrans({"a": "A", "b": "BBB"}))
    0    ABBB
    dtype: object

    Snowpark pandas equivalent:

    >>> pd.Series("ab").str.translate({"a": "A"}).str.replace("b", "BBB")
    0    ABBB
    dtype: object
    """
9641031

9651032
def isalnum():
9661033
pass
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#
2+
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import modin.pandas as pd
6+
import pandas as native_pd
7+
import pytest
8+
9+
import snowflake.snowpark.modin.plugin # noqa: F401
10+
from tests.integ.modin.sql_counter import sql_count_checker
11+
from tests.integ.modin.utils import (
12+
assert_snowpark_pandas_equal_to_pandas,
13+
create_test_series,
14+
eval_snowpark_pandas_result,
15+
)
16+
17+
18+
@pytest.mark.parametrize(
    "data, table",
    [
        (
            # Simple 1-element mapping
            ["aaaaa", "bbbaaa", "cafdsaf;lh"],
            str.maketrans("a", "b"),
        ),
        (
            # Mapping with mixed str, unicode code points, and Nones
            ["aaaaa", "fjkdsajk", "cjghgjqk", "yubikey"],
            str.maketrans(
                {ord("a"): "A", ord("f"): None, "y": "z", "k": None, ord("j"): ""}
            ),
        ),
        (
            # Mapping with special characters
            [
                "Peña",
                "Ordoñez",
                "Raúl",
                "Ibañez",
                "François",
                "øen",
                "2πr = τ",
                "München",
            ],
            str.maketrans(
                {
                    "ñ": "n",
                    "ú": "u",
                    "ç": "c",
                    "ø": "o",
                    "τ": "t",
                    "π": "p",
                    "ü": "u",
                }
            ),
        ),
        (
            # Mapping with compound emojis. Each item in the series renders as a single emoji,
            # but is actually 4 characters. Calling `len` on each element correctly returns 4.
            # https://unicode.org/emoji/charts/emoji-zwj-sequences.html
            # Inputs:
            # - "head shaking horizontally" = 1F642 + 200D + 2194 + FE0F
            # - "heart on fire" = 2764 + FE0F + 200D + 1F525
            # - "judge" = 1F9D1 + 200D + 2696 + FE0F
            # Outputs:
            # - "head shaking vertically" = 1F642 + 200D + 2195 + FE0F
            # - "mending heart" = 2764 + FE0F + 200D + 1FA79
            # - "health worker" = 1F9D1 + 200D + 2695 + FE0F
            ["🙂‍↔️", "❤️‍🔥", "🧑‍⚖️"],
            {
                0x2194: 0x2195,
                0x1F525: 0x1FA79,
                0x2696: 0x2695,
            },
        ),
    ],
)
@sql_count_checker(query_count=1)
def test_translate(data, table):
    # Translation should match native pandas when the table is normalized via str.maketrans.
    eval_snowpark_pandas_result(
        *create_test_series(data), lambda ser: ser.str.translate(table)
    )
83+
84+
85+
@sql_count_checker(query_count=1)
def test_translate_without_maketrans():
    # pandas requires all table keys to be unicode ordinal values, and does not know how to handle
    # string keys that were not converted to ordinals via `ord` or `str.maketrans`. Since Snowflake
    # SQL uses strings in its mappings, we accept string keys as well as ordinals.
    data = ["aaaaa", "fjkdsajk", "cjghgjqk", "yubikey"]
    table = {ord("a"): "A", ord("f"): None, "y": "z", "k": None}
    # The result a user would get from native pandas with a fully normalized table.
    normalized_result = native_pd.Series(data).str.translate(str.maketrans(table))
    assert_snowpark_pandas_equal_to_pandas(
        pd.Series(data).str.translate(table),
        normalized_result,
    )
    # Mappings for "y" and "k" are ignored if not passed through str.maketrans because they are
    # not unicode ordinals
    raw_result = native_pd.Series(data).str.translate(table)
    assert not raw_result.equals(normalized_result)
104+
105+
106+
@pytest.mark.parametrize(
    "table, error",
    [
        ({"😶‍🌫️": "a"}, ValueError),  # This emoji key is secretly 4 code points
        ({"aa": "a"}, ValueError),  # Key is 2 chars
        # Mapping 1 char to multiple is valid in vanilla pandas, but we don't support this
        (
            {ord("a"): "😶‍🌫️"},
            NotImplementedError,
        ),  # This emoji value is secretly 4 code points
        ({ord("a"): "aa"}, NotImplementedError),  # Value is 2 chars
    ],
)
@sql_count_checker(query_count=0)
def test_translate_invalid_mappings(table, error):
    data = ["aaaaa", "fjkdsajk", "cjghgjqk", "yubikey"]
    # native pandas silently treats all of these cases as no-ops. However, since Snowflake SQL uses
    # strings as mappings instead of a dict construct, passing these arguments to the equivalent
    # SQL argument would either cause an inscrutable error or unexpected changes to the output series.
    snow_ser, native_ser = create_test_series(data)
    # Confirm native pandas accepts the table without raising (silent no-op there),
    # while Snowpark pandas rejects it eagerly with a descriptive error.
    native_ser.str.translate(table)
    with pytest.raises(error):
        snow_ser.str.translate(table)

tests/unit/modin/test_series_strings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def test_str_cat_no_others(mock_str_register, mock_series):
5656
(lambda s: s.str.rindex("abc", start=1), "rindex"),
5757
(lambda s: s.str.swapcase(), "swapcase"),
5858
(lambda s: s.str.normalize("NFC"), "normalize"),
59-
(lambda s: s.str.translate(str.maketrans("a", "b")), "translate"),
6059
(lambda s: s.str.isalnum(), "isalnum"),
6160
(lambda s: s.str.isalpha(), "isalpha"),
6261
(lambda s: s.str.isnumeric(), "isnumeric"),

0 commit comments

Comments
 (0)