|
16 | 16 |
|
17 | 17 | # Use correct version from here: |
18 | 18 | from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name |
| 19 | +import snowflake.snowpark.context as context |
19 | 20 |
|
20 | 21 | # TODO: connector installed_pandas is broken. If pyarrow is not installed but pandas is, this function returns the wrong answer. |
21 | 22 | # The core issue is that the connector mixes together the detection of pandas and pyarrow, which is wrong. |
@@ -341,6 +342,14 @@ def __init__( |
341 | 342 | def __repr__(self) -> str: |
342 | 343 | return f"ArrayType({repr(self.element_type) if self.element_type else ''})" |
343 | 344 |
|
| 345 | + def _as_nested(self) -> "ArrayType": |
| 346 | + if not context._should_use_structured_type_semantics: |
| 347 | + return self |
| 348 | + element_type = self.element_type |
| 349 | + if isinstance(element_type, (ArrayType, MapType, StructType)): |
| 350 | + element_type = element_type._as_nested() |
| 351 | + return ArrayType(element_type, self.structured) |
| 352 | + |
344 | 353 | def is_primitive(self): |
345 | 354 | return False |
346 | 355 |
|
@@ -391,6 +400,14 @@ def __repr__(self) -> str: |
391 | 400 | def is_primitive(self): |
392 | 401 | return False |
393 | 402 |
|
| 403 | + def _as_nested(self) -> "MapType": |
| 404 | + if not context._should_use_structured_type_semantics: |
| 405 | + return self |
| 406 | + value_type = self.value_type |
| 407 | + if isinstance(value_type, (ArrayType, MapType, StructType)): |
| 408 | + value_type = value_type._as_nested() |
| 409 | + return MapType(self.key_type, value_type, self.structured) |
| 410 | + |
394 | 411 | @classmethod |
395 | 412 | def from_json(cls, json_dict: Dict[str, Any]) -> "MapType": |
396 | 413 | return MapType( |
@@ -552,29 +569,46 @@ def __init__( |
552 | 569 | column_identifier: Union[ColumnIdentifier, str], |
553 | 570 | datatype: DataType, |
554 | 571 | nullable: bool = True, |
| 572 | + _is_column: bool = True, |
555 | 573 | ) -> None: |
556 | | - self.column_identifier = ( |
557 | | - ColumnIdentifier(column_identifier) |
558 | | - if isinstance(column_identifier, str) |
559 | | - else column_identifier |
560 | | - ) |
| 574 | + self.name = column_identifier |
| 575 | + self._is_column = _is_column |
561 | 576 | self.datatype = datatype |
562 | 577 | self.nullable = nullable |
563 | 578 |
|
564 | 579 | @property |
565 | 580 | def name(self) -> str: |
566 | | - """Returns the column name.""" |
567 | | - return self.column_identifier.name |
| 581 | + if self._is_column or not context._should_use_structured_type_semantics: |
| 582 | + return self.column_identifier.name |
| 583 | + else: |
| 584 | + return self._name |
568 | 585 |
|
569 | 586 | @name.setter |
570 | | - def name(self, n: str) -> None: |
571 | | - self.column_identifier = ColumnIdentifier(n) |
| 587 | + def name(self, n: Union[ColumnIdentifier, str]) -> None: |
| 588 | + if isinstance(n, ColumnIdentifier): |
| 589 | + self._name = n.name |
| 590 | + self.column_identifier = n |
| 591 | + else: |
| 592 | + self._name = n |
| 593 | + self.column_identifier = ColumnIdentifier(n) |
| 594 | + |
| 595 | + def _as_nested(self) -> "StructField": |
| 596 | + if not context._should_use_structured_type_semantics: |
| 597 | + return self |
| 598 | + datatype = self.datatype |
| 599 | + if isinstance(datatype, (ArrayType, MapType, StructType)): |
| 600 | + datatype = datatype._as_nested() |
| 601 | + # Nested StructFields do not follow column naming conventions |
| 602 | + return StructField(self._name, datatype, self.nullable, _is_column=False) |
572 | 603 |
|
573 | 604 | def __repr__(self) -> str: |
574 | 605 | return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})" |
575 | 606 |
|
576 | 607 | def __eq__(self, other): |
577 | | - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ |
| 608 | + return isinstance(other, self.__class__) and ( |
| 609 | + (self.name, self._is_column, self.datatype, self.nullable) |
| 610 | + == (other.name, other._is_column, other.datatype, other.nullable) |
| 611 | + ) |
578 | 612 |
|
579 | 613 | @classmethod |
580 | 614 | def from_json(cls, json_dict: Dict[str, Any]) -> "StructField": |
@@ -620,30 +654,41 @@ def __init__( |
620 | 654 | self, fields: Optional[List["StructField"]] = None, structured=False |
621 | 655 | ) -> None: |
622 | 656 | self.structured = structured |
623 | | - if fields is None: |
624 | | - fields = [] |
625 | | - self.fields = fields |
| 657 | + self.fields = [] |
| 658 | + for field in fields or []: |
| 659 | + self.add(field) |
626 | 660 |
|
627 | 661 | def add( |
628 | 662 | self, |
629 | 663 | field: Union[str, ColumnIdentifier, "StructField"], |
630 | 664 | datatype: Optional[DataType] = None, |
631 | 665 | nullable: Optional[bool] = True, |
632 | 666 | ) -> "StructType": |
633 | | - if isinstance(field, StructField): |
634 | | - self.fields.append(field) |
635 | | - elif isinstance(field, (str, ColumnIdentifier)): |
| 667 | + if isinstance(field, (str, ColumnIdentifier)): |
636 | 668 | if datatype is None: |
637 | 669 | raise ValueError( |
638 | 670 | "When field argument is str or ColumnIdentifier, datatype must not be None." |
639 | 671 | ) |
640 | | - self.fields.append(StructField(field, datatype, nullable)) |
641 | | - else: |
| 672 | + field = StructField(field, datatype, nullable) |
| 673 | + elif not isinstance(field, StructField): |
642 | 674 | raise ValueError( |
643 | 675 | f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'" |
644 | 676 | ) |
| 677 | + |
| 678 | + # Nested data does not follow the same schema conventions as top-level fields. |
| 679 | + if isinstance(field.datatype, (ArrayType, MapType, StructType)): |
| 680 | + field.datatype = field.datatype._as_nested() |
| 681 | + |
| 682 | + self.fields.append(field) |
645 | 683 | return self |
646 | 684 |
|
| 685 | + def _as_nested(self) -> "StructType": |
| 686 | + if not context._should_use_structured_type_semantics: |
| 687 | + return self |
| 688 | + return StructType( |
| 689 | + [field._as_nested() for field in self.fields], self.structured |
| 690 | + ) |
| 691 | + |
647 | 692 | @classmethod |
648 | 693 | def _from_attributes(cls, attributes: list) -> "StructType": |
649 | 694 | return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes]) |
|
0 commit comments