Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
- Added `save` method to `DataFrameWriter` to work in conjunction with `format`.
- Added support to read keyword arguments to `options` method for `DataFrameReader` and `DataFrameWriter`.
- Relaxed the cloudpickle dependency for Python 3.11 to simplify build requirements. However, for Python 3.11, `cloudpickle==2.2.1` remains the only supported version.
- Added support for mixed case field names in struct type columns.

#### Bug Fixes

Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata
from snowflake.connector.options import installed_pandas, pandas
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark.types import (
LTZ,
NTZ,
Expand Down Expand Up @@ -157,9 +156,10 @@ def convert_metadata_to_sp_type(
return StructType(
[
StructField(
quote_name(field.name, keep_case=True),
field.name,
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
is_column=False,
)
for field in metadata.fields
],
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
StringType,
TimestampTimeZone,
TimestampType,
ArrayType,
MapType,
StructType,
)
from snowflake.snowpark.window import Window, WindowSpec

Expand Down Expand Up @@ -916,6 +919,9 @@ def _cast(
if isinstance(to, str):
to = type_string_to_type_object(to)

if isinstance(to, (ArrayType, MapType, StructType)):
to = to._as_nested()

if self._ast is None:
_emit_ast = False

Expand Down
69 changes: 51 additions & 18 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ def __init__(
def __repr__(self) -> str:
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"

def _as_nested(self) -> "ArrayType":
    """Return a copy of this ArrayType marked as nested, recursing into the element type."""
    nested_element = (
        self.element_type._as_nested()
        if isinstance(self.element_type, (ArrayType, MapType, StructType))
        else self.element_type
    )
    return ArrayType(nested_element, self.structured)

def is_primitive(self):
return False

Expand Down Expand Up @@ -391,6 +397,12 @@ def __repr__(self) -> str:
def is_primitive(self):
return False

def _as_nested(self) -> "MapType":
    """Return a copy of this MapType marked as nested, recursing into the value type."""
    nested_value = (
        self.value_type._as_nested()
        if isinstance(self.value_type, (ArrayType, MapType, StructType))
        else self.value_type
    )
    return MapType(self.key_type, nested_value, self.structured)

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "MapType":
return MapType(
Expand Down Expand Up @@ -552,29 +564,41 @@ def __init__(
column_identifier: Union[ColumnIdentifier, str],
datatype: DataType,
nullable: bool = True,
is_column: bool = True,
) -> None:
self.column_identifier = (
ColumnIdentifier(column_identifier)
if isinstance(column_identifier, str)
else column_identifier
)
self.name = column_identifier
self.is_column = is_column
self.datatype = datatype
self.nullable = nullable

@property
def name(self) -> str:
"""Returns the column name."""
return self.column_identifier.name
return self.column_identifier.name if self.is_column else self._name

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does `is_column=True` enforce that `self.column_identifier` is of type `ColumnIdentifier`? Can we add an assert for that in `__init__`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`__init__` uses the setter for `self.name`, which ensures that `self.column_identifier` is always a `ColumnIdentifier`.

@name.setter
def name(self, n: str) -> None:
self.column_identifier = ColumnIdentifier(n)
def name(self, n: Union[ColumnIdentifier, str]) -> None:
if isinstance(n, ColumnIdentifier):
self._name = n.name
self.column_identifier = n
else:
self._name = n
self.column_identifier = ColumnIdentifier(n)

def _as_nested(self) -> "StructField":
    """Return a nested copy of this field; nested StructFields do not follow column naming conventions."""
    nested_datatype = (
        self.datatype._as_nested()
        if isinstance(self.datatype, (ArrayType, MapType, StructType))
        else self.datatype
    )
    return StructField(self._name, nested_datatype, self.nullable, is_column=False)

def __repr__(self) -> str:
    # Mirror the constructor signature: name, datatype, nullable flag.
    return (
        f"StructField({self.name!r}, {self.datatype!r}, nullable={self.nullable})"
    )

def __eq__(self, other):
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
return isinstance(other, self.__class__) and (
(self.name, self.is_column, self.datatype, self.nullable)
== (other.name, other.is_column, other.datatype, other.nullable)
)

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "StructField":
Expand Down Expand Up @@ -620,30 +644,39 @@ def __init__(
self, fields: Optional[List["StructField"]] = None, structured=False
) -> None:
self.structured = structured
if fields is None:
fields = []
self.fields = fields
self.fields = []
for field in fields or []:
self.add(field)

def add(
self,
field: Union[str, ColumnIdentifier, "StructField"],
datatype: Optional[DataType] = None,
nullable: Optional[bool] = True,
) -> "StructType":
if isinstance(field, StructField):
self.fields.append(field)
elif isinstance(field, (str, ColumnIdentifier)):
if isinstance(field, (str, ColumnIdentifier)):
if datatype is None:
raise ValueError(
"When field argument is str or ColumnIdentifier, datatype must not be None."
)
self.fields.append(StructField(field, datatype, nullable))
else:
field = StructField(field, datatype, nullable)
elif not isinstance(field, StructField):
raise ValueError(
f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'"
)

# Nested data does not follow the same schema conventions as top level fields.
if isinstance(field.datatype, (ArrayType, MapType, StructType)):
field.datatype = field.datatype._as_nested()

self.fields.append(field)
return self

def _as_nested(self) -> "StructType":
    """Return a copy of this StructType with every field converted to its nested form."""
    nested_fields = [field._as_nested() for field in self.fields]
    return StructType(nested_fields, self.structured)

@classmethod
def _from_attributes(cls, attributes: list) -> "StructType":
return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes])
Expand Down
40 changes: 20 additions & 20 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
_STRUCTURE_DATAFRAME_QUERY = """
select
object_construct('k1', 1) :: map(varchar, int) as map,
object_construct('A', 'foo', 'B', 0.05) :: object(A varchar, B float) as obj,
object_construct('A', 'foo', 'b', 0.05) :: object(A varchar, b float) as obj,
[1.0, 3.1, 4.5] :: array(float) as arr
"""

Expand All @@ -71,10 +71,10 @@ def _create_test_dataframe(s):
object_construct(lit("k1"), lit(1))
.cast(MapType(StringType(), IntegerType(), structured=True))
.alias("map"),
object_construct(lit("A"), lit("foo"), lit("B"), lit(0.05))
object_construct(lit("A"), lit("foo"), lit("b"), lit(0.05))
.cast(
StructType(
[StructField("A", StringType()), StructField("B", DoubleType())],
[StructField("A", StringType()), StructField("b", DoubleType())],
structured=True,
)
)
Expand Down Expand Up @@ -106,7 +106,7 @@ def _create_test_dataframe(s):
StructType(
[
StructField("A", StringType(16777216), nullable=True),
StructField("B", DoubleType(), nullable=True),
StructField("b", DoubleType(), nullable=True),
],
structured=True,
),
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_structured_dtypes_select(structured_type_session, examples):
flattened_df = df.select(
df.map["k1"].alias("value1"),
df.obj["A"].alias("a"),
col("obj")["B"].alias("b"),
col("obj")["b"].alias("b"),
df.arr[0].alias("value2"),
df.arr[1].alias("value3"),
col("arr")[2].alias("value4"),
Expand All @@ -395,7 +395,7 @@ def test_structured_dtypes_select(structured_type_session, examples):
[
StructField("VALUE1", LongType(), nullable=True),
StructField("A", StringType(16777216), nullable=True),
StructField("B", DoubleType(), nullable=True),
StructField("b", DoubleType(), nullable=True),
StructField("VALUE2", DoubleType(), nullable=True),
StructField("VALUE3", DoubleType(), nullable=True),
StructField("VALUE4", DoubleType(), nullable=True),
Expand Down Expand Up @@ -424,12 +424,12 @@ def test_structured_dtypes_pandas(structured_type_session, structured_type_suppo
if structured_type_support:
assert (
pdf.to_json()
== '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","B":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}'
== '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","b":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}'
)
else:
assert (
pdf.to_json()
== '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"B\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}'
== '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"b\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}'
)


Expand Down Expand Up @@ -467,7 +467,7 @@ def test_structured_dtypes_iceberg(
)
assert save_ddl[0][0] == (
f"create or replace ICEBERG TABLE {table_name.upper()} (\n\t"
"MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, B DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n "
"MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, b DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n "
"EXTERNAL_VOLUME = 'PYTHON_CONNECTOR_ICEBERG_EXVOL'\n CATALOG = 'SNOWFLAKE'\n "
"BASE_LOCATION = 'python_connector_merge_gate/';"
)
Expand Down Expand Up @@ -524,27 +524,27 @@ def test_iceberg_nested_fields(
"NESTED_DATA",
StructType(
[
StructField('"camelCase"', StringType(), nullable=True),
StructField('"snake_case"', StringType(), nullable=True),
StructField('"PascalCase"', StringType(), nullable=True),
StructField("camelCase", StringType(), nullable=True),
StructField("snake_case", StringType(), nullable=True),
StructField("PascalCase", StringType(), nullable=True),
StructField(
'"nested_map"',
"nested_map",
MapType(
StringType(),
StructType(
[
StructField(
'"inner_camelCase"',
"inner_camelCase",
StringType(),
nullable=True,
),
StructField(
'"inner_snake_case"',
"inner_snake_case",
StringType(),
nullable=True,
),
StructField(
'"inner_PascalCase"',
"inner_PascalCase",
StringType(),
nullable=True,
),
Expand Down Expand Up @@ -733,8 +733,8 @@ def test_structured_dtypes_iceberg_create_from_values(
_, __, expected_schema = STRUCTURED_TYPES_EXAMPLES[True]
table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
data = [
({"x": 1}, {"A": "a", "B": 1}, [1, 1, 1]),
({"x": 2}, {"A": "b", "B": 2}, [2, 2, 2]),
({"x": 1}, {"A": "a", "b": 1}, [1, 1, 1]),
({"x": 2}, {"A": "b", "b": 2}, [2, 2, 2]),
]
try:
create_df = structured_type_session.create_dataframe(
Expand Down Expand Up @@ -945,8 +945,8 @@ def test_structured_type_print_schema(
" | |-- key: StringType()\n"
" | |-- value: ArrayType\n"
" | | |-- element: StructType\n"
' | | | |-- "FIELD1": StringType() (nullable = True)\n'
' | | | |-- "FIELD2": LongType() (nullable = True)\n'
' | | | |-- "Field1": StringType() (nullable = True)\n'
' | | | |-- "Field2": LongType() (nullable = True)\n'
)

# Test that depth works as expected
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def test_stored_procedure_with_structured_returns(
"OBJ",
StructType(
[
StructField('"a"', StringType(16777216), nullable=True),
StructField('"b"', DoubleType(), nullable=True),
StructField("a", StringType(16777216), nullable=True),
StructField("b", DoubleType(), nullable=True),
],
structured=True,
),
Expand Down
Loading