@@ -341,6 +341,12 @@ def __init__(
341341 def __repr__ (self ) -> str :
342342 return f"ArrayType({ repr (self .element_type ) if self .element_type else '' } )"
343343
344+ def _as_nested (self ) -> "ArrayType" :
345+ element_type = self .element_type
346+ if isinstance (element_type , (ArrayType , MapType , StructType )):
347+ element_type = element_type ._as_nested ()
348+ return ArrayType (element_type , self .structured )
349+
344350 def is_primitive (self ):
345351 return False
346352
@@ -391,6 +397,12 @@ def __repr__(self) -> str:
391397 def is_primitive (self ):
392398 return False
393399
400+ def _as_nested (self ) -> "MapType" :
401+ value_type = self .value_type
402+ if isinstance (value_type , (ArrayType , MapType , StructType )):
403+ value_type = value_type ._as_nested ()
404+ return MapType (self .key_type , value_type , self .structured )
405+
394406 @classmethod
395407 def from_json (cls , json_dict : Dict [str , Any ]) -> "MapType" :
396408 return MapType (
@@ -482,7 +494,6 @@ class ColumnIdentifier:
482494 """Represents a column identifier."""
483495
484496 def __init__ (self , normalized_name : str ) -> None :
485- self .raw_name = normalized_name
486497 self .normalized_name = quote_name (normalized_name )
487498 self ._original_name = normalized_name
488499
@@ -553,33 +564,41 @@ def __init__(
553564 column_identifier : Union [ColumnIdentifier , str ],
554565 datatype : DataType ,
555566 nullable : bool = True ,
567+ is_column : bool = True ,
556568 ) -> None :
557- self .column_identifier = (
558- ColumnIdentifier (column_identifier )
559- if isinstance (column_identifier , str )
560- else column_identifier
561- )
569+ self .name = column_identifier
570+ self .is_column = is_column
562571 self .datatype = datatype
563572 self .nullable = nullable
564573
565574 @property
566575 def name (self ) -> str :
567- """Returns the column name."""
568- return self .column_identifier .name
569-
570- @property
571- def raw_name (self ) -> str :
572- return self .column_identifier .raw_name
576+ return self .column_identifier .name if self .is_column else self ._name
573577
574578 @name .setter
575- def name (self , n : str ) -> None :
576- self .column_identifier = ColumnIdentifier (n )
579+ def name (self , n : Union [ColumnIdentifier , str ]) -> None :
580+ if isinstance (n , ColumnIdentifier ):
581+ self ._name = n .name
582+ self .column_identifier = n
583+ else :
584+ self ._name = n
585+ self .column_identifier = ColumnIdentifier (n )
586+
587+ def _as_nested (self ) -> "StructField" :
588+ datatype = self .datatype
589+ if isinstance (datatype , (ArrayType , MapType , StructType )):
590+ datatype = datatype ._as_nested ()
591+ # Nested StructFields do not follow column naming conventions
592+ return StructField (self ._name , datatype , self .nullable , is_column = False )
577593
578594 def __repr__ (self ) -> str :
579595 return f"StructField({ self .name !r} , { repr (self .datatype )} , nullable={ self .nullable } )"
580596
581597 def __eq__ (self , other ):
582- return isinstance (other , self .__class__ ) and self .__dict__ == other .__dict__
598+ return isinstance (other , self .__class__ ) and (
599+ (self .name , self .is_column , self .datatype , self .nullable )
600+ == (other .name , other .is_column , other .datatype , other .nullable )
601+ )
583602
584603 @classmethod
585604 def from_json (cls , json_dict : Dict [str , Any ]) -> "StructField" :
@@ -625,30 +644,40 @@ def __init__(
625644 self , fields : Optional [List ["StructField" ]] = None , structured = False
626645 ) -> None :
627646 self .structured = structured
628- if fields is None :
629- fields = []
630- self . fields = fields
647+ self . fields = []
648+ for field in fields :
649+ self . add ( field )
631650
632651 def add (
633652 self ,
634653 field : Union [str , ColumnIdentifier , "StructField" ],
635654 datatype : Optional [DataType ] = None ,
636655 nullable : Optional [bool ] = True ,
637656 ) -> "StructType" :
638- if isinstance (field , StructField ):
639- self .fields .append (field )
640- elif isinstance (field , (str , ColumnIdentifier )):
657+ if isinstance (field , (str , ColumnIdentifier )):
641658 if datatype is None :
642659 raise ValueError (
643660 "When field argument is str or ColumnIdentifier, datatype must not be None."
644661 )
645- self .fields .append (StructField (field , datatype , nullable ))
646- else :
662+ field = StructField (field , datatype , nullable )
663+ elif not isinstance (field , StructField ):
664+ __import__ ("pdb" ).set_trace ()
647665 raise ValueError (
648666 f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{ type (field )} '"
649667 )
668+
669+ # Nested data does not follow the same schema conventions as top level fields.
670+ if isinstance (field .datatype , (ArrayType , MapType , StructType )):
671+ field .datatype = field .datatype ._as_nested ()
672+
673+ self .fields .append (field )
650674 return self
651675
676+ def _as_nested (self ) -> "StructType" :
677+ return StructType (
678+ [field ._as_nested () for field in self .fields ], self .structured
679+ )
680+
652681 @classmethod
653682 def _from_attributes (cls , attributes : list ) -> "StructType" :
654683 return cls ([StructField (a .name , a .datatype , a .nullable ) for a in attributes ])
0 commit comments