@@ -4,55 +4,67 @@ import org.apache.spark.sql.catalyst.InternalRow
44import  org .apache .spark .sql .catalyst .analysis .TypeCheckResult 
55import  org .apache .spark .sql .catalyst .expressions .codegen .Block ._ 
66import  org .apache .spark .sql .catalyst .expressions .codegen .{CodeGenerator , CodegenContext , ExprCode , FalseLiteral }
7- import  org .apache .spark .sql .types .{StructField , StructType }
7+ import  org .apache .spark .sql .types .{StringType , StructField , StructType }
8+ import  org .apache .spark .unsafe .types .UTF8String 
89
910/** 
10-   * 
11-   * Adds/replaces fields in a struct. 
12-   * Returns null if struct is null. 
13-   * If there are multiple existing fields with the one of the fieldNames, they will all be replaced. 
14-   * 
15-   * @param  struct            : The struct to add fields to. 
16-   * @param  fieldNames        : The names of the fieldExpressions to add to given struct. 
17-   * @param  fieldExpressions  : The expressions to assign to each fieldName in fieldNames. 
18-   */  
11+  * Adds/replaces fields in a struct. 
12+  * Returns null if struct is null. 
13+  * If multiple fields already exist with the one of the given fieldNames, they will all be replaced. 
14+  */  
1915//  scalastyle:off line.size.limit
2016@ ExpressionDescription (
21-   usage =  " _FUNC_(struct, fieldName, field) - Adds/replaces field in given struct." 
22-   examples = 
23-     """ 
17+   usage =  " _FUNC_(struct, name1, val1, name2, val2, ...) - Adds/replaces fields in struct by name." 
18+   examples =  """ 
2419    Examples: 
25-       > SELECT _FUNC_({"a":1}, "b", 2); 
26-        {"a":1,"b":2} 
20+       > SELECT _FUNC_({"a":1}, "b", 2, "c", 3 ); 
21+        {"a":1,"b":2,"c":3 } 
2722  """  )
2823//  scalastyle:on line.size.limit
29- case  class  AddFields (struct :  Expression ,  fieldNames :  Seq [ String ],  fieldExpressions : Seq [Expression ]) extends  Expression  {
24+ case  class  AddFields (children : Seq [Expression ]) extends  Expression  {
3025
31-   override  def  children :  Seq [Expression ] =  struct +:  fieldExpressions
26+   private  lazy  val  struct :  Expression  =  children.head
27+   private  lazy  val  (nameExprs, valExprs) =  children.drop(1 ).grouped(2 ).map {
28+     case  Seq (name, value) =>  (name, value)
29+   }.toList.unzip
30+   private  lazy  val  fieldNames  =  nameExprs.map(_.eval().asInstanceOf [UTF8String ].toString)
31+   private  lazy  val  pairs  =  fieldNames.zip(valExprs)
32+ 
33+   override  def  nullable :  Boolean  =  struct.nullable
34+ 
35+   private  lazy  val  ogStructType :  StructType  = 
36+     struct.dataType.asInstanceOf [StructType ]
3237
3338  override  lazy  val  dataType :  StructType  =  {
3439    val  existingFields  =  ogStructType.fields.map { x =>  (x.name, x) }
35-     val  addOrReplaceFields  =  pairs.map { case  (fieldName, field) =>  (fieldName, StructField (fieldName, field.dataType, field.nullable)) }
40+     val  addOrReplaceFields  =  pairs.map { case  (fieldName, field) => 
41+       (fieldName, StructField (fieldName, field.dataType, field.nullable))
42+     }
3643    val  newFields  =  loop(existingFields, addOrReplaceFields).map(_._2)
3744    StructType (newFields)
3845  }
3946
40-   override  def  nullable :  Boolean  =  struct.nullable
41- 
4247  override  def  checkInputDataTypes ():  TypeCheckResult  =  {
48+     if  (children.size %  2  ==  0 ) {
49+       return  TypeCheckResult .TypeCheckFailure (s " $prettyName expects an odd number of arguments. " )
50+     }
51+ 
4352    val  typeName  =  struct.dataType.typeName
44-     if  (typeName !=  StructType (Nil ).typeName)
53+     val  expectedStructType  =  StructType (Nil ).typeName
54+     if  (typeName !=  expectedStructType) {
4555      return  TypeCheckResult .TypeCheckFailure (
46-         s " struct should be struct data type. struct is  $typeName" )
47- 
48-     if  (fieldNames.contains(null ))
49-       return  TypeCheckResult .TypeCheckFailure (" fieldNames cannot contain null" 
56+         s " Only  $expectedStructType is allowed to appear at first position, got:  $typeName. " )
57+     }
5058
51-     if  (fieldExpressions.contains(null ))
52-       return  TypeCheckResult .TypeCheckFailure (" fieldExpressions cannot contain null" 
59+     if  (nameExprs.contains(null ) ||  nameExprs.exists(e =>  ! (e.foldable &&  e.dataType ==  StringType ))) {
60+       return  TypeCheckResult .TypeCheckFailure (
61+         s " Only non-null foldable  ${StringType .catalogString} expressions are allowed to appear at even position. " )
62+     }
5363
54-     if  (fieldNames.size !=  fieldExpressions.size)
55-       return  TypeCheckResult .TypeCheckFailure (" fieldNames and fieldExpressions cannot have different lengths" 
64+     if  (valExprs.contains(null )) {
65+       return  TypeCheckResult .TypeCheckFailure (
66+         s " Only non-null expressions are allowed to appear at odd positions after first position. " )
67+     }
5668
5769    TypeCheckResult .TypeCheckSuccess 
5870  }
@@ -62,32 +74,36 @@ case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressio
6274    if  (structValue ==  null ) {
6375      null 
6476    } else  {
65-       val  existingValues :  Seq [(FieldName , Any )] =  ogStructType.fieldNames.zip(structValue.asInstanceOf [InternalRow ].toSeq(ogStructType))
66-       val  addOrReplaceValues :  Seq [(FieldName , Any )] =  pairs.map { case  (fieldName, expression) =>  (fieldName, expression.eval(input)) }
77+       val  existingValues :  Seq [(FieldName , Any )] = 
78+         ogStructType.fieldNames.zip(structValue.asInstanceOf [InternalRow ].toSeq(ogStructType))
79+       val  addOrReplaceValues :  Seq [(FieldName , Any )] = 
80+         pairs.map { case  (fieldName, expression) =>  (fieldName, expression.eval(input)) }
6781      val  newValues  =  loop(existingValues, addOrReplaceValues).map(_._2)
6882      InternalRow .fromSeq(newValues)
6983    }
7084  }
7185
7286  override  def  doGenCode (ctx : CodegenContext , ev : ExprCode ):  ExprCode  =  {
7387    val  structGen  =  struct.genCode(ctx)
74-     val  addOrReplaceFieldsGens  =  fieldExpressions .map(_.genCode(ctx))
88+     val  addOrReplaceFieldsGens  =  valExprs .map(_.genCode(ctx))
7589    val  resultCode :  String  =  {
7690      val  structVar  =  structGen.value
7791      type  NullCheck  =  String 
7892      type  NonNullValue  =  String 
79-       val  existingFieldsCode :  Seq [(FieldName , (NullCheck , NonNullValue ))] =  ogStructType.fields.zipWithIndex.map {
80-         case  (structField, i) => 
81-           val  nullCheck  =  s " $structVar.isNullAt( $i) " 
82-           val  nonNullValue  =  CodeGenerator .getValue(structVar, structField.dataType, i.toString)
83-           (structField.name, (nullCheck, nonNullValue))
84-       }
85-       val  addOrReplaceFieldsCode :  Seq [(FieldName , (NullCheck , NonNullValue ))] =  fieldNames.zip(addOrReplaceFieldsGens).map {
86-         case  (fieldName, fieldExprCode) => 
87-           val  nullCheck  =  fieldExprCode.isNull.code
88-           val  nonNullValue  =  fieldExprCode.value.code
89-           (fieldName, (nullCheck, nonNullValue))
90-       }
93+       val  existingFieldsCode :  Seq [(FieldName , (NullCheck , NonNullValue ))] = 
94+         ogStructType.fields.zipWithIndex.map {
95+           case  (structField, i) => 
96+             val  nullCheck  =  s " $structVar.isNullAt( $i) " 
97+             val  nonNullValue  =  CodeGenerator .getValue(structVar, structField.dataType, i.toString)
98+             (structField.name, (nullCheck, nonNullValue))
99+         }
100+       val  addOrReplaceFieldsCode :  Seq [(FieldName , (NullCheck , NonNullValue ))] = 
101+         fieldNames.zip(addOrReplaceFieldsGens).map {
102+           case  (fieldName, fieldExprCode) => 
103+             val  nullCheck  =  fieldExprCode.isNull.code
104+             val  nonNullValue  =  fieldExprCode.value.code
105+             (fieldName, (nullCheck, nonNullValue))
106+         }
91107      val  newFieldsCode  =  loop(existingFieldsCode, addOrReplaceFieldsCode)
92108      val  rowClass  =  classOf [GenericInternalRow ].getName
93109      val  rowValuesVar  =  ctx.freshName(" rowValues" 
@@ -138,17 +154,14 @@ case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressio
138154
139155  override  def  prettyName :  String  =  " add_fields" 
140156
141-   private  lazy  val  ogStructType :  StructType  = 
142-     struct.dataType.asInstanceOf [StructType ]
143- 
144-   private  val  pairs  =  fieldNames.zip(fieldExpressions)
145- 
146157  private  type  FieldName  =  String 
147158
148159  /**  
149-     * Recursively loops through addOrReplaceFields, adding or replacing fields by FieldName. 
150-     */  
151-   private  def  loop [V ](existingFields : Seq [(String , V )], addOrReplaceFields : Seq [(String , V )]):  Seq [(String , V )] =  {
160+    * Recursively loop through addOrReplaceFields, adding or replacing fields by FieldName. 
161+    */  
162+   @ scala.annotation.tailrec
163+   private  def  loop [V ](existingFields : Seq [(String , V )],
164+                       addOrReplaceFields : Seq [(String , V )]):  Seq [(String , V )] =  {
152165    if  (addOrReplaceFields.nonEmpty) {
153166      val  existingFieldNames  =  existingFields.map(_._1)
154167      val  newField @ (newFieldName, _) =  addOrReplaceFields.head
@@ -172,6 +185,13 @@ case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressio
172185}
173186
174187object  AddFields  {
188+   @ deprecated(" use AddFields(children: Seq[Expression]) constructor." " 0.2.4" 
175189  def  apply (struct : Expression , fieldName : String , fieldExpression : Expression ):  AddFields  = 
176-     AddFields (struct, Seq (fieldName), Seq (fieldExpression))
190+     AddFields (struct ::  Literal (fieldName) ::  fieldExpression ::  Nil )
191+ 
192+   @ deprecated(" use AddFields(children: Seq[Expression]) constructor." " 0.2.4" 
193+   def  apply (struct : Expression , fieldNames : Seq [String ], fieldExpressions : Seq [Expression ]):  AddFields  =  {
194+     val  exprs  =  fieldNames.zip(fieldExpressions).flatMap { case  (name, expr) =>  Seq (Literal (name), expr) }
195+     AddFields (struct +:  exprs)
196+   }
177197}
0 commit comments