Open
Labels
bug (Something isn't working), local testing (Local Testing issues/PRs), status-triage_done (Initial triage done, will be further handled by the driver team)
Description
Please answer these questions before submitting your issue. Thanks!
- What version of Python are you using?
Python 3.12.8
- What are the Snowpark Python and pandas versions in the environment?
pandas==2.3.1
snowflake-snowpark-python==1.35.0
- What did you do?
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark import Session
import pandas as pd
from snowflake.snowpark.functions import get
from snowflake.snowpark.mock import ColumnEmulator, ColumnType, patch

FIX = False

if FIX:

    @patch(get)
    def mock_get(column_expression: ColumnEmulator, value_expression: ColumnEmulator) -> ColumnEmulator:
        """Correctly mock get returning the right type"""

        def get(obj, key):
            try:
                if isinstance(obj, list) and key < len(obj):
                    return obj[key]
                elif isinstance(obj, dict):
                    return obj.get(key, None)
                else:
                    return None
            except KeyError:
                return None

        # pandas.Series.combine does not work here because it will not allow Nones in int columns
        result = []
        for exp, k in zip(column_expression, value_expression):
            result.append(get(exp, k))

        if column_expression.sf_type.datatype == T.ArrayType():
            return_type = column_expression.sf_type.datatype.element_type
        elif column_expression.sf_type.datatype == T.MapType():
            return_type = column_expression.sf_type.datatype.valueType
        else:
            raise ValueError(f"Only arrays and maps are supported, not {column_expression.sf_type.datatype}")
        return ColumnEmulator(result, sf_type=ColumnType(return_type, True), dtype=object)


spark = Session.builder.config("local_testing", True).getOrCreate()


def test_get():
    """Test the mocked get function"""
    data = pd.DataFrame(
        [
            [1, ["a", "b", "c"], 1],
            [2, ["d", "e"], 1],
            [3, ["f", "g", "h"], 1],
        ],
        columns=["ID", "COLUMN", "POS"],
    )
    df_array = spark.createDataFrame(data).select(
        "ID", F.col("COLUMN").cast("array<string>").alias("COLUMN"), "POS"
    )
    df_array.select(F.get(F.col("COLUMN"), F.col("POS")).alias("VALUE")).print_schema()

    data = pd.DataFrame(
        [
            [1, {"a": 1, "b": 3, "c": 5}, "a"],
            [2, {"d": 8, "e": 2}, "d"],
            [3, {"f": 5, "g": 7, "h": 1}, "f"],
        ],
        columns=["ID", "COLUMN", "KEY"],
    )
    df_map = spark.createDataFrame(data).select(
        "ID", F.col("COLUMN").cast("map<string,int>").alias("COLUMN"), "KEY"
    )
    df_map.select(F.get(F.col("COLUMN"), F.col("KEY")).alias("VALUE")).print_schema()


test_get()
- What did you expect to see?
With FIX=False, I get
root
 |-- "VALUE": ArrayType (nullable = True)
 |    |-- element: StringType()
root
 |-- "VALUE": MapType (nullable = True)
 |    |-- key: StringType()
 |    |-- value: StringType()
This is wrong. The get function extracts elements from arrays and maps, so the returned column should have the data type of the contained element, not of the container.
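To make the point concrete, here is a minimal plain-Python sketch of the element lookup, run on the same sample rows as the reproduction above. The helper name `get_element` is hypothetical and this is not Snowpark code; it only illustrates that each result is a scalar taken out of the container, which is why a container type in the schema is surprising.

```python
def get_element(obj, key):
    """Plain-Python stand-in for an element lookup (hypothetical helper)."""
    if isinstance(obj, list):
        # Arrays are indexed by integer position; out-of-range yields None.
        if isinstance(key, int) and 0 <= key < len(obj):
            return obj[key]
        return None
    if isinstance(obj, dict):
        # Maps are looked up by key; a missing key yields None.
        return obj.get(key)
    return None


# Same sample rows as in the reproduction.
arrays = [["a", "b", "c"], ["d", "e"], ["f", "g", "h"]]
positions = [1, 1, 1]
array_values = [get_element(a, p) for a, p in zip(arrays, positions)]

maps = [{"a": 1, "b": 3, "c": 5}, {"d": 8, "e": 2}, {"f": 5, "g": 7, "h": 1}]
keys = ["a", "d", "f"]
map_values = [get_element(m, k) for m, k in zip(maps, keys)]

print(array_values)  # ['b', 'e', 'g']
print(map_values)    # [1, 8, 5]
```

The values are strings for the array column and integers for the map column, never lists or dicts.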
With FIX=True, I get the expected and logical output:
root
 |-- "VALUE": StringType() (nullable = True)
root
 |-- "VALUE": StringType() (nullable = True)
I've fixed the mocked get function by returning the contained type:
if column_expression.sf_type.datatype == T.ArrayType():
    return_type = column_expression.sf_type.datatype.element_type
elif column_expression.sf_type.datatype == T.MapType():
    return_type = column_expression.sf_type.datatype.valueType
else:
    raise ValueError(f"Only arrays and maps are supported, not {column_expression.sf_type.datatype}")
return ColumnEmulator(result, sf_type=ColumnType(return_type, True), dtype=object)

I can open a PR to change the code if you think this is the right approach.
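One detail a PR might want to check is the type test itself. The sketch below uses small dataclass stand-ins (hypothetical, not Snowpark's real classes, and with illustrative attribute names) to show why an `isinstance` check could be more robust than comparing against a default-constructed type: if equality takes the element or value type into account, a map with non-default value type would not equal the default map type. Whether Snowpark's own type equality behaves like these stand-ins is an assumption that should be verified against the library.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins so the sketch runs without Snowpark installed.
# They are NOT the library's real definitions.
@dataclass
class StringType:
    pass


@dataclass
class LongType:
    pass


@dataclass
class ArrayType:
    element_type: object = field(default_factory=StringType)


@dataclass
class MapType:
    key_type: object = field(default_factory=StringType)
    value_type: object = field(default_factory=StringType)


def contained_type(datatype):
    """Return the type of the elements held by an array or map type."""
    if isinstance(datatype, ArrayType):
        return datatype.element_type
    if isinstance(datatype, MapType):
        return datatype.value_type
    raise ValueError(f"Only arrays and maps are supported, not {datatype}")


int_map = MapType(StringType(), LongType())

# With value-based equality, a non-default map differs from the default one,
# so an `== MapType()` test would fall through to the error branch.
print(int_map == MapType())
# The isinstance-based lookup still resolves the value type.
print(contained_type(int_map))
```

With these stand-ins, the equality check fails for the integer-valued map while `contained_type` still resolves its value type, which is the behavior the fix is aiming for.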