Skip to content

Commit f05c058

Browse files
committed
Refactor based on feedback
1 parent 0b79825 commit f05c058

File tree

4 files changed: +69 / -34 lines changed

4 files changed: +69 / -34 lines changed

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@
3434
from snowflake.connector.constants import FIELD_ID_TO_NAME
3535
from snowflake.connector.cursor import ResultMetadata
3636
from snowflake.connector.options import installed_pandas, pandas
37-
from snowflake.snowpark._internal.utils import quote_name
3837
from snowflake.snowpark.types import (
3938
LTZ,
4039
NTZ,
@@ -157,9 +156,10 @@ def convert_metadata_to_sp_type(
157156
return StructType(
158157
[
159158
StructField(
160-
quote_name(field.name, keep_case=True),
159+
field.name,
161160
convert_metadata_to_sp_type(field, max_string_size),
162161
nullable=field.is_nullable,
162+
is_column=False,
163163
)
164164
for field in metadata.fields
165165
],
@@ -292,7 +292,7 @@ def convert_sp_to_sf_type(datatype: DataType) -> str:
292292
if isinstance(datatype, StructType):
293293
if datatype.structured:
294294
fields = ", ".join(
295-
f"{field.raw_name} {convert_sp_to_sf_type(field.datatype)}"
295+
f"{field.name} {convert_sp_to_sf_type(field.datatype)}"
296296
for field in datatype.fields
297297
)
298298
return f"OBJECT({fields})"

src/snowflake/snowpark/column.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,9 @@
9090
StringType,
9191
TimestampTimeZone,
9292
TimestampType,
93+
ArrayType,
94+
MapType,
95+
StructType,
9396
)
9497
from snowflake.snowpark.window import Window, WindowSpec
9598

@@ -916,6 +919,9 @@ def _cast(
916919
if isinstance(to, str):
917920
to = type_string_to_type_object(to)
918921

922+
if isinstance(to, (ArrayType, MapType, StructType)):
923+
to = to._as_nested()
924+
919925
if self._ast is None:
920926
_emit_ast = False
921927

src/snowflake/snowpark/types.py

Lines changed: 52 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -341,6 +341,12 @@ def __init__(
341341
def __repr__(self) -> str:
342342
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"
343343

344+
def _as_nested(self) -> "ArrayType":
345+
element_type = self.element_type
346+
if isinstance(element_type, (ArrayType, MapType, StructType)):
347+
element_type = element_type._as_nested()
348+
return ArrayType(element_type, self.structured)
349+
344350
def is_primitive(self):
345351
return False
346352

@@ -391,6 +397,12 @@ def __repr__(self) -> str:
391397
def is_primitive(self):
392398
return False
393399

400+
def _as_nested(self) -> "MapType":
401+
value_type = self.value_type
402+
if isinstance(value_type, (ArrayType, MapType, StructType)):
403+
value_type = value_type._as_nested()
404+
return MapType(self.key_type, value_type, self.structured)
405+
394406
@classmethod
395407
def from_json(cls, json_dict: Dict[str, Any]) -> "MapType":
396408
return MapType(
@@ -482,7 +494,6 @@ class ColumnIdentifier:
482494
"""Represents a column identifier."""
483495

484496
def __init__(self, normalized_name: str) -> None:
485-
self.raw_name = normalized_name
486497
self.normalized_name = quote_name(normalized_name)
487498
self._original_name = normalized_name
488499

@@ -553,33 +564,41 @@ def __init__(
553564
column_identifier: Union[ColumnIdentifier, str],
554565
datatype: DataType,
555566
nullable: bool = True,
567+
is_column: bool = True,
556568
) -> None:
557-
self.column_identifier = (
558-
ColumnIdentifier(column_identifier)
559-
if isinstance(column_identifier, str)
560-
else column_identifier
561-
)
569+
self.name = column_identifier
570+
self.is_column = is_column
562571
self.datatype = datatype
563572
self.nullable = nullable
564573

565574
@property
566575
def name(self) -> str:
567-
"""Returns the column name."""
568-
return self.column_identifier.name
569-
570-
@property
571-
def raw_name(self) -> str:
572-
return self.column_identifier.raw_name
576+
return self.column_identifier.name if self.is_column else self._name
573577

574578
@name.setter
575-
def name(self, n: str) -> None:
576-
self.column_identifier = ColumnIdentifier(n)
579+
def name(self, n: Union[ColumnIdentifier, str]) -> None:
580+
if isinstance(n, ColumnIdentifier):
581+
self._name = n.name
582+
self.column_identifier = n
583+
else:
584+
self._name = n
585+
self.column_identifier = ColumnIdentifier(n)
586+
587+
def _as_nested(self) -> "StructField":
588+
datatype = self.datatype
589+
if isinstance(datatype, (ArrayType, MapType, StructType)):
590+
datatype = datatype._as_nested()
591+
# Nested StructFields do not follow column naming conventions
592+
return StructField(self._name, datatype, self.nullable, is_column=False)
577593

578594
def __repr__(self) -> str:
579595
return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})"
580596

581597
def __eq__(self, other):
582-
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
598+
return isinstance(other, self.__class__) and (
599+
(self.name, self.is_column, self.datatype, self.nullable)
600+
== (other.name, other.is_column, other.datatype, other.nullable)
601+
)
583602

584603
@classmethod
585604
def from_json(cls, json_dict: Dict[str, Any]) -> "StructField":
@@ -625,30 +644,40 @@ def __init__(
625644
self, fields: Optional[List["StructField"]] = None, structured=False
626645
) -> None:
627646
self.structured = structured
628-
if fields is None:
629-
fields = []
630-
self.fields = fields
647+
self.fields = []
648+
for field in fields:
649+
self.add(field)
631650

632651
def add(
633652
self,
634653
field: Union[str, ColumnIdentifier, "StructField"],
635654
datatype: Optional[DataType] = None,
636655
nullable: Optional[bool] = True,
637656
) -> "StructType":
638-
if isinstance(field, StructField):
639-
self.fields.append(field)
640-
elif isinstance(field, (str, ColumnIdentifier)):
657+
if isinstance(field, (str, ColumnIdentifier)):
641658
if datatype is None:
642659
raise ValueError(
643660
"When field argument is str or ColumnIdentifier, datatype must not be None."
644661
)
645-
self.fields.append(StructField(field, datatype, nullable))
646-
else:
662+
field = StructField(field, datatype, nullable)
663+
elif not isinstance(field, StructField):
664+
__import__("pdb").set_trace()
647665
raise ValueError(
648666
f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'"
649667
)
668+
669+
# Nested data does not follow the same schema conventions as top level fields.
670+
if isinstance(field.datatype, (ArrayType, MapType, StructType)):
671+
field.datatype = field.datatype._as_nested()
672+
673+
self.fields.append(field)
650674
return self
651675

676+
def _as_nested(self) -> "StructType":
677+
return StructType(
678+
[field._as_nested() for field in self.fields], self.structured
679+
)
680+
652681
@classmethod
653682
def _from_attributes(cls, attributes: list) -> "StructType":
654683
return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes])

tests/integ/scala/test_datatype_suite.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def _create_test_dataframe(s):
106106
StructType(
107107
[
108108
StructField("A", StringType(16777216), nullable=True),
109-
StructField('"b"', DoubleType(), nullable=True),
109+
StructField("b", DoubleType(), nullable=True),
110110
],
111111
structured=True,
112112
),
@@ -524,27 +524,27 @@ def test_iceberg_nested_fields(
524524
"NESTED_DATA",
525525
StructType(
526526
[
527-
StructField('"camelCase"', StringType(), nullable=True),
528-
StructField('"snake_case"', StringType(), nullable=True),
529-
StructField('"PascalCase"', StringType(), nullable=True),
527+
StructField("camelCase", StringType(), nullable=True),
528+
StructField("snake_case", StringType(), nullable=True),
529+
StructField("PascalCase", StringType(), nullable=True),
530530
StructField(
531-
'"nested_map"',
531+
"nested_map",
532532
MapType(
533533
StringType(),
534534
StructType(
535535
[
536536
StructField(
537-
'"inner_camelCase"',
537+
"inner_camelCase",
538538
StringType(),
539539
nullable=True,
540540
),
541541
StructField(
542-
'"inner_snake_case"',
542+
"inner_snake_case",
543543
StringType(),
544544
nullable=True,
545545
),
546546
StructField(
547-
'"inner_PascalCase"',
547+
"inner_PascalCase",
548548
StringType(),
549549
nullable=True,
550550
),

0 commit comments

Comments
 (0)