Skip to content

Commit 0362c46

Browse files
authored
SNOW-1803811: Allow mixed-case field names for struct type columns (#2640)
1 parent 665bef1 commit 0362c46

File tree

7 files changed

+179
-100
lines changed

7 files changed

+179
-100
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- Added support for applying Snowpark Python function `snowflake_cortex_sentiment`.
2222
- Added support for `DataFrame.map`.
2323
- Added support for `DataFrame.from_dict` and `DataFrame.from_records`.
24+
- Added support for mixed-case field names in struct type columns.
2425

2526
#### Improvements
2627
- Improved performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping NumPy functions to Snowpark functions where possible.

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
get_origin,
3131
)
3232

33+
import snowflake.snowpark.context as context
3334
import snowflake.snowpark.types # type: ignore
3435
from snowflake.connector.constants import FIELD_ID_TO_NAME
3536
from snowflake.connector.cursor import ResultMetadata
@@ -157,9 +158,12 @@ def convert_metadata_to_sp_type(
157158
return StructType(
158159
[
159160
StructField(
160-
quote_name(field.name, keep_case=True),
161+
field.name
162+
if context._should_use_structured_type_semantics
163+
else quote_name(field.name, keep_case=True),
161164
convert_metadata_to_sp_type(field, max_string_size),
162165
nullable=field.is_nullable,
166+
_is_column=False,
163167
)
164168
for field in metadata.fields
165169
],

src/snowflake/snowpark/column.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@
9191
StringType,
9292
TimestampTimeZone,
9393
TimestampType,
94+
ArrayType,
95+
MapType,
96+
StructType,
9497
)
9598
from snowflake.snowpark.window import Window, WindowSpec
9699

@@ -917,6 +920,9 @@ def _cast(
917920
if isinstance(to, str):
918921
to = type_string_to_type_object(to)
919922

923+
if isinstance(to, (ArrayType, MapType, StructType)):
924+
to = to._as_nested()
925+
920926
if self._ast is None:
921927
_emit_ast = False
922928

src/snowflake/snowpark/context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
_should_continue_registration: Optional[Callable[..., bool]] = None
2222

2323

24+
# Global flag that determines if structured type semantics should be used
25+
_should_use_structured_type_semantics = False
26+
27+
2428
def get_active_session() -> "snowflake.snowpark.Session":
2529
"""Returns the current active Snowpark session.
2630

src/snowflake/snowpark/types.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# Use correct version from here:
1818
from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name
19+
import snowflake.snowpark.context as context
1920

2021
# TODO: connector installed_pandas is broken. If pyarrow is not installed but pandas is, this function returns the wrong answer.
2122
# The core issue is that the connector mixes the detection of pandas and pyarrow, which is wrong.
@@ -341,6 +342,14 @@ def __init__(
341342
def __repr__(self) -> str:
342343
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"
343344

345+
def _as_nested(self) -> "ArrayType":
346+
if not context._should_use_structured_type_semantics:
347+
return self
348+
element_type = self.element_type
349+
if isinstance(element_type, (ArrayType, MapType, StructType)):
350+
element_type = element_type._as_nested()
351+
return ArrayType(element_type, self.structured)
352+
344353
def is_primitive(self):
345354
return False
346355

@@ -391,6 +400,14 @@ def __repr__(self) -> str:
391400
def is_primitive(self):
392401
return False
393402

403+
def _as_nested(self) -> "MapType":
404+
if not context._should_use_structured_type_semantics:
405+
return self
406+
value_type = self.value_type
407+
if isinstance(value_type, (ArrayType, MapType, StructType)):
408+
value_type = value_type._as_nested()
409+
return MapType(self.key_type, value_type, self.structured)
410+
394411
@classmethod
395412
def from_json(cls, json_dict: Dict[str, Any]) -> "MapType":
396413
return MapType(
@@ -552,29 +569,46 @@ def __init__(
552569
column_identifier: Union[ColumnIdentifier, str],
553570
datatype: DataType,
554571
nullable: bool = True,
572+
_is_column: bool = True,
555573
) -> None:
556-
self.column_identifier = (
557-
ColumnIdentifier(column_identifier)
558-
if isinstance(column_identifier, str)
559-
else column_identifier
560-
)
574+
self.name = column_identifier
575+
self._is_column = _is_column
561576
self.datatype = datatype
562577
self.nullable = nullable
563578

564579
@property
565580
def name(self) -> str:
566-
"""Returns the column name."""
567-
return self.column_identifier.name
581+
if self._is_column or not context._should_use_structured_type_semantics:
582+
return self.column_identifier.name
583+
else:
584+
return self._name
568585

569586
@name.setter
570-
def name(self, n: str) -> None:
571-
self.column_identifier = ColumnIdentifier(n)
587+
def name(self, n: Union[ColumnIdentifier, str]) -> None:
588+
if isinstance(n, ColumnIdentifier):
589+
self._name = n.name
590+
self.column_identifier = n
591+
else:
592+
self._name = n
593+
self.column_identifier = ColumnIdentifier(n)
594+
595+
def _as_nested(self) -> "StructField":
596+
if not context._should_use_structured_type_semantics:
597+
return self
598+
datatype = self.datatype
599+
if isinstance(datatype, (ArrayType, MapType, StructType)):
600+
datatype = datatype._as_nested()
601+
# Nested StructFields do not follow column naming conventions
602+
return StructField(self._name, datatype, self.nullable, _is_column=False)
572603

573604
def __repr__(self) -> str:
574605
return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})"
575606

576607
def __eq__(self, other):
577-
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
608+
return isinstance(other, self.__class__) and (
609+
(self.name, self._is_column, self.datatype, self.nullable)
610+
== (other.name, other._is_column, other.datatype, other.nullable)
611+
)
578612

579613
@classmethod
580614
def from_json(cls, json_dict: Dict[str, Any]) -> "StructField":
@@ -620,30 +654,41 @@ def __init__(
620654
self, fields: Optional[List["StructField"]] = None, structured=False
621655
) -> None:
622656
self.structured = structured
623-
if fields is None:
624-
fields = []
625-
self.fields = fields
657+
self.fields = []
658+
for field in fields or []:
659+
self.add(field)
626660

627661
def add(
628662
self,
629663
field: Union[str, ColumnIdentifier, "StructField"],
630664
datatype: Optional[DataType] = None,
631665
nullable: Optional[bool] = True,
632666
) -> "StructType":
633-
if isinstance(field, StructField):
634-
self.fields.append(field)
635-
elif isinstance(field, (str, ColumnIdentifier)):
667+
if isinstance(field, (str, ColumnIdentifier)):
636668
if datatype is None:
637669
raise ValueError(
638670
"When field argument is str or ColumnIdentifier, datatype must not be None."
639671
)
640-
self.fields.append(StructField(field, datatype, nullable))
641-
else:
672+
field = StructField(field, datatype, nullable)
673+
elif not isinstance(field, StructField):
642674
raise ValueError(
643675
f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'"
644676
)
677+
678+
# Nested data does not follow the same schema conventions as top-level fields.
679+
if isinstance(field.datatype, (ArrayType, MapType, StructType)):
680+
field.datatype = field.datatype._as_nested()
681+
682+
self.fields.append(field)
645683
return self
646684

685+
def _as_nested(self) -> "StructType":
686+
if not context._should_use_structured_type_semantics:
687+
return self
688+
return StructType(
689+
[field._as_nested() for field in self.fields], self.structured
690+
)
691+
647692
@classmethod
648693
def _from_attributes(cls, attributes: list) -> "StructType":
649694
return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes])

0 commit comments

Comments
 (0)