
SNOW-2334855: mock_get returns the datatype of the container, not of the returned value #3772

@pedro-villanueva-bcom

Description


Please answer these questions before submitting your issue. Thanks!

  1. What version of Python are you using?

Python 3.12.8

  2. What are the Snowpark Python and pandas versions in the environment?

pandas==2.3.1
snowflake-snowpark-python==1.35.0

  3. What did you do?
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark import Session
import pandas as pd

from snowflake.snowpark.functions import get
from snowflake.snowpark.mock import ColumnEmulator, ColumnType, patch

FIX = False

if FIX:

    @patch(get)
    def mock_get(column_expression: ColumnEmulator, value_expression: ColumnEmulator) -> ColumnEmulator:
        """Correctly mock get returning the right type"""

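        def get_value(obj, key):
            """Emulate GET for one row: NULL container, NULL key or missing entry -> None."""
            if obj is None or key is None:
                return None
            if isinstance(obj, list):
                # Out-of-range (including negative) indexes return NULL instead of wrapping around
                return obj[key] if 0 <= key < len(obj) else None
            if isinstance(obj, dict):
                # dict.get never raises KeyError, so no try/except is needed
                return obj.get(key)
            return None

        # pandas.Series.combine does not work here because it will not allow Nones in int columns
        result = []
        for exp, k in zip(column_expression, value_expression):
            result.append(get_value(exp, k))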

        # Check the container class rather than comparing with a default-constructed type,
        # so arrays and maps of any element/value type are accepted
        datatype = column_expression.sf_type.datatype
        if isinstance(datatype, T.ArrayType):
            return_type = datatype.element_type
        elif isinstance(datatype, T.MapType):
            return_type = datatype.valueType
        else:
            raise ValueError(f"Only arrays and maps are supported, not {datatype}")

        return ColumnEmulator(result, sf_type=ColumnType(return_type, True), dtype=object)


spark = Session.builder.config("local_testing", True).getOrCreate()


def test_get():
    """Test the mocked get function"""
    data = pd.DataFrame(
        [
            [1, ["a", "b", "c"], 1],
            [2, ["d", "e"], 1],
            [3, ["f", "g", "h"], 1],
        ],
        columns=["ID", "COLUMN", "POS"],
    )
    df_array = spark.createDataFrame(data).select(
        "ID", F.col("COLUMN").cast("array<string>").alias("COLUMN"), "POS"
    )

    df_array.select(F.get(F.col("COLUMN"), F.col("POS")).alias("VALUE")).print_schema()

    data = pd.DataFrame(
        [
            [1, {"a": 1, "b": 3, "c": 5}, "a"],
            [2, {"d": 8, "e": 2}, "d"],
            [3, {"f": 5, "g": 7, "h": 1}, "f"],
        ],
        columns=["ID", "COLUMN", "KEY"],
    )
    df_map = spark.createDataFrame(data).select(
        "ID", F.col("COLUMN").cast("map<string,int>").alias("COLUMN"), "KEY"
    )
    df_map.select(F.get(F.col("COLUMN"), F.col("KEY")).alias("VALUE")).print_schema()


test_get()

  4. What did you expect to see?

With FIX=False, I get:

root
 |-- "VALUE": ArrayType (nullable = True)
 |   |-- element: StringType()
root
 |-- "VALUE": MapType (nullable = True)
 |   |-- key: StringType()
 |   |-- value: StringType()

This is wrong. The get function extracts an element from an array or a value from a map, so the returned column should have the data type of the contained value, not of the container.
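To make the element-versus-container distinction concrete, here is a minimal pure-Python sketch of the per-row behavior the mock emulates: GET on an array returns the element at a zero-based index, GET on a map returns the value for a key, and anything missing yields NULL (None). The function name get_element is hypothetical, and treating negative indexes as NULL is an assumption about Snowflake's GET semantics rather than something stated in this issue.

```python
import numbers
from typing import Any, Optional


def get_element(container: Any, key: Any) -> Optional[Any]:
    """Hypothetical per-row emulation of GET(container, key).

    Returns the contained value (never the container itself), or None when
    the container is NULL, the key is NULL, or the entry does not exist.
    """
    if container is None or key is None:
        return None
    if isinstance(container, list):
        # Assumption: indexes outside [0, len) yield NULL, so negative indexes
        # are rejected instead of wrapping around to the end of the list.
        if isinstance(key, numbers.Integral) and 0 <= key < len(container):
            return container[key]
        return None
    if isinstance(container, dict):
        # dict.get never raises KeyError, so no try/except is needed.
        return container.get(key)
    return None


rows = [(["a", "b", "c"], 1), (["d", "e"], 1), ({"a": 1, "b": 3}, "a"), ({"d": 8}, "x")]
print([get_element(c, k) for c, k in rows])  # ['b', 'e', 1, None]
```

Each result is a single element or value, so the column holding these results should carry the element (or value) type.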

With FIX=True, I get the expected output:

root
 |-- "VALUE": StringType() (nullable = True)
root
 |-- "VALUE": StringType() (nullable = True)

I've fixed the mocked get function by returning the contained type:

datatype = column_expression.sf_type.datatype
if isinstance(datatype, T.ArrayType):
    return_type = datatype.element_type
elif isinstance(datatype, T.MapType):
    return_type = datatype.valueType
else:
    raise ValueError(f"Only arrays and maps are supported, not {datatype}")

return ColumnEmulator(result, sf_type=ColumnType(return_type, True), dtype=object)
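The type resolution can also be sketched without Snowpark to show why the dispatch should test the container class rather than compare against a default-constructed type: if types compare by their attributes (as the stand-ins below do, and as Snowpark's types are assumed to), a check such as datatype == ArrayType() only matches one specific element type. The classes StringT, LongT, ArrayT and MapT and the function get_result_type are hypothetical stand-ins that borrow the element_type / value_type attribute names; they are not Snowpark's own types.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class StringT:
    """Hypothetical stand-in for a scalar string type."""


@dataclass(frozen=True)
class LongT:
    """Hypothetical stand-in for a scalar integer type."""


@dataclass(frozen=True)
class ArrayT:
    """Hypothetical stand-in for an array type; defaults to string elements."""
    element_type: object = field(default_factory=StringT)


@dataclass(frozen=True)
class MapT:
    """Hypothetical stand-in for a map type; defaults to string keys and values."""
    key_type: object = field(default_factory=StringT)
    value_type: object = field(default_factory=StringT)


def get_result_type(container_type: object) -> object:
    """Return the type of the value GET extracts from a container type."""
    if isinstance(container_type, ArrayT):
        return container_type.element_type
    if isinstance(container_type, MapT):
        return container_type.value_type
    raise ValueError(f"Only arrays and maps are supported, not {container_type}")


int_array = ArrayT(LongT())
print(int_array == ArrayT())                      # False: equality also compares element_type
print(get_result_type(int_array))                 # LongT()
print(get_result_type(MapT(StringT(), LongT())))  # LongT()
```

In the mock itself the same shape applies with isinstance checks against T.ArrayType and T.MapType.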

I can open a PR to change the code if you think this is the right approach.

Metadata


Labels

bug (Something isn't working), local testing (Local Testing issues/PRs), status-triage_done (Initial triage done, will be further handled by the driver team)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
