Skip to content

Commit 4ebc45a

Browse files
authored
make catalyst expressions sql compatible (#16)
1 parent 368b3f9 commit 4ebc45a

22 files changed

+587
-369
lines changed

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ As such, you can pull in the current stable release by simply adding a library d
3131
For example, for an SBT project, simply add the following line to your `build.sbt`:
3232

3333
```
34-
libraryDependencies += "com.github.fqaiser94" %% "mse" % "0.2.2"
34+
libraryDependencies += "com.github.fqaiser94" %% "mse" % "0.2.4"
3535
```
3636

3737
For other types of projects (e.g. Maven, Gradle), see the installation instructions at this [link](https://search.maven.org/artifact/com.github.fqaiser94/mse_2.11).
@@ -43,7 +43,7 @@ You will also need to provide your PySpark application/s with the path to the MS
4343

4444
```bash
4545
pip install mse
46-
curl https://repo1.maven.org/maven2/com/github/fqaiser94/mse_2.11/0.2.2/mse_2.11-0.2.2.jar --output mse.jar
46+
curl https://repo1.maven.org/maven2/com/github/fqaiser94/mse_2.11/0.2.4/mse_2.11-0.2.4.jar --output mse.jar
4747
pyspark --jars mse.jar
4848
```
4949

@@ -345,6 +345,12 @@ arrayOfStructs.withColumn("array", transform($"array", elem => elem.dropFields("
345345
// +----------------+
346346
```
347347

348+
# SQL installation and usage
349+
350+
The underlying Catalyst Expressions are SQL compatible.
351+
Unfortunately, Spark only added public APIs for plugging custom Catalyst Expressions into the FunctionRegistry in Spark 3.0.0
352+
(which, at the time of writing, is still in preview). You can find a project with an example of how to do this [here](https://github.com/fqaiser94/mse-sql-example).
353+
348354
# Catalyst Optimization Rules
349355

350356
We also provide some Catalyst optimization rules that can be plugged into a Spark session to get even better performance. This is as simple as including the following two lines of code at the start of your Scala Spark program:
@@ -386,7 +392,7 @@ As you can see, the successive `add_fields` method calls have been collapsed int
386392

387393
Theoretically, this should improve performance but for the most part, you won't notice much difference unless you're doing some particularly intense struct manipulation and/or working with a particularly large dataset.
388394

389-
Unfortunately, to the best of our knowledge, there is currently no way to plug in custom Catalyst optimization rules using the Python APIs.
395+
Unfortunately, to the best of our knowledge, there is currently no way to plug in custom Catalyst optimization rules directly using the Python APIs.
390396

391397
# Questions/Thoughts/Concerns?
392398

python/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ For example:
2121

2222
```bash
2323
pip install mse
24-
curl https://repo1.maven.org/maven2/com/github/fqaiser94/mse_2.11/0.2.2/mse_2.11-0.2.2.jar --output mse.jar
24+
curl https://repo1.maven.org/maven2/com/github/fqaiser94/mse_2.11/0.2.4/mse_2.11-0.2.4.jar --output mse.jar
2525
pyspark --jars mse.jar
2626
```
2727

python/mse/tests/methods_tests.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ def test_withField(self):
6969
with self.subTest("throw error if withField is called on a column that is not struct dataType"):
7070
self.assertRaisesRegex(
7171
AnalysisException,
72-
"struct should be struct data type. struct is integer",
72+
"Only struct is allowed to appear at first position, got: integer",
7373
lambda: non_struct_df.withColumn("a", col("a").withField("a", lit(2))).collect())
7474

7575
with self.subTest("throw error if null fieldName supplied"):
7676
self.assertRaisesRegex(
7777
AnalysisException,
78-
"fieldNames cannot contain null",
78+
"Only non-null foldable string expressions are allowed to appear at even position",
7979
lambda: struct_level1_df.withColumn("a", col("a").withField(None, lit(2))).collect())
8080

8181
with self.subTest("add new field to struct"):
@@ -194,19 +194,19 @@ def test_withFieldRenamed(self):
194194
with self.subTest("throw error if withFieldRenamed is called on a column that is not struct dataType"):
195195
self.assertRaisesRegex(
196196
AnalysisException,
197-
"struct should be struct data type. struct is integer",
197+
"Only struct is allowed to appear at first position, got: integer",
198198
lambda: non_struct_df.withColumn("a", col("a").withFieldRenamed("a", "z")))
199199

200200
with self.subTest("throw error if null existingFieldName supplied"):
201201
self.assertRaisesRegex(
202202
AnalysisException,
203-
"existingFieldName cannot be null",
203+
"Only non-null foldable string expressions are allowed to appear after first position.",
204204
lambda: struct_df.withColumn("a", col("a").withFieldRenamed(None, "z")))
205205

206206
with self.subTest("throw error if null newFieldName supplied"):
207207
self.assertRaisesRegex(
208208
AnalysisException,
209-
"newFieldName cannot be null",
209+
"Only non-null foldable string expressions are allowed to appear after first position.",
210210
lambda: struct_df.withColumn("a", col("a").withFieldRenamed("a", None)))
211211

212212
with self.subTest("rename field in struct"):
@@ -286,13 +286,13 @@ def test_dropFields(self):
286286
with self.subTest("throw error if withField is called on a column that is not struct dataType"):
287287
self.assertRaisesRegex(
288288
AnalysisException,
289-
"struct should be struct data type. struct is integer",
289+
"Only struct is allowed to appear at first position, got: integer",
290290
lambda: non_struct_df.withColumn("a", col("a").dropFields("a")))
291291

292292
with self.subTest("throw error if null fieldName supplied"):
293293
self.assertRaisesRegex(
294294
AnalysisException,
295-
"fieldNames cannot contain null",
295+
"Only non-null foldable string expressions are allowed after first position.",
296296
lambda: struct_df.withColumn("a", col("a").dropFields(None)))
297297

298298
with self.subTest("drop field in struct"):

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
setuptools.setup(
99
name='mse',
1010
packages=['mse'],
11-
version='0.1.3',
11+
version='0.1.4',
1212
license='Apache license 2.0',
1313
description='Make Structs Easy (MSE)',
1414
long_description=long_description,

python/testing/ReusedPySparkTestCase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def conf(cls):
2828
Override this in subclasses to supply a more specific conf
2929
"""
3030

31-
return SparkConf().set("spark.driver.extraClassPath", "./mse.jar")
31+
return SparkConf().set("spark.jars", "./mse.jar")
3232

3333
@classmethod
3434
def setUpClass(cls):

src/main/scala/com/github/fqaiser94/mse/methods.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.github.fqaiser94.mse
22

3-
import org.apache.spark.sql.{Column, ColumnName}
43
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
5-
import org.apache.spark.sql.catalyst.expressions.{AddFields, DropFields, Expression, RenameFields}
4+
import org.apache.spark.sql.catalyst.expressions.{AddFields, DropFields, Expression, Literal, RenameFields}
65
import org.apache.spark.sql.functions._
6+
import org.apache.spark.sql.{Column, ColumnName}
77

88
object methods {
99

@@ -21,7 +21,7 @@ object methods {
2121
* @since 2.4.4
2222
*/
2323
def withField(fieldName: String, fieldValue: Column): Column = withExpr {
24-
AddFields(expr, fieldName, fieldValue.expr)
24+
AddFields(expr :: Literal(fieldName) :: fieldValue.expr :: Nil)
2525
}
2626

2727
/**
@@ -33,7 +33,7 @@ object methods {
3333
* @since 2.4.4
3434
*/
3535
def dropFields(fieldNames: String*): Column = withExpr {
36-
DropFields(expr, fieldNames: _*)
36+
DropFields(expr +: fieldNames.toList.map(Literal(_)))
3737
}
3838

3939
/**
@@ -45,7 +45,7 @@ object methods {
4545
* @since 2.4.4
4646
*/
4747
def withFieldRenamed(existingFieldName: String, newFieldName: String): Column = withExpr {
48-
RenameFields(expr, existingFieldName, newFieldName)
48+
RenameFields(expr :: Literal(existingFieldName) :: Literal(newFieldName) :: Nil)
4949
}
5050

5151
private def withExpr(newExpr: Expression): Column = new Column(newExpr)

src/main/scala/org/apache/spark/sql/catalyst/expressions/AddFields.scala

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,55 +4,67 @@ import org.apache.spark.sql.catalyst.InternalRow
44
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
55
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
66
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode, FalseLiteral}
7-
import org.apache.spark.sql.types.{StructField, StructType}
7+
import org.apache.spark.sql.types.{StringType, StructField, StructType}
8+
import org.apache.spark.unsafe.types.UTF8String
89

910
/**
10-
*
11-
* Adds/replaces fields in a struct.
12-
* Returns null if struct is null.
13-
* If there are multiple existing fields with the one of the fieldNames, they will all be replaced.
14-
*
15-
* @param struct : The struct to add fields to.
16-
* @param fieldNames : The names of the fieldExpressions to add to given struct.
17-
* @param fieldExpressions : The expressions to assign to each fieldName in fieldNames.
18-
*/
11+
* Adds/replaces fields in a struct.
12+
* Returns null if struct is null.
13+
* If multiple fields already exist with one of the given fieldNames, they will all be replaced.
14+
*/
1915
// scalastyle:off line.size.limit
2016
@ExpressionDescription(
21-
usage = "_FUNC_(struct, fieldName, field) - Adds/replaces field in given struct.",
22-
examples =
23-
"""
17+
usage = "_FUNC_(struct, name1, val1, name2, val2, ...) - Adds/replaces fields in struct by name.",
18+
examples = """
2419
Examples:
25-
> SELECT _FUNC_({"a":1}, "b", 2);
26-
{"a":1,"b":2}
20+
> SELECT _FUNC_({"a":1}, "b", 2, "c", 3);
21+
{"a":1,"b":2,"c":3}
2722
""")
2823
// scalastyle:on line.size.limit
29-
case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressions: Seq[Expression]) extends Expression {
24+
case class AddFields(children: Seq[Expression]) extends Expression {
3025

31-
override def children: Seq[Expression] = struct +: fieldExpressions
26+
private lazy val struct: Expression = children.head
27+
private lazy val (nameExprs, valExprs) = children.drop(1).grouped(2).map {
28+
case Seq(name, value) => (name, value)
29+
}.toList.unzip
30+
private lazy val fieldNames = nameExprs.map(_.eval().asInstanceOf[UTF8String].toString)
31+
private lazy val pairs = fieldNames.zip(valExprs)
32+
33+
override def nullable: Boolean = struct.nullable
34+
35+
private lazy val ogStructType: StructType =
36+
struct.dataType.asInstanceOf[StructType]
3237

3338
override lazy val dataType: StructType = {
3439
val existingFields = ogStructType.fields.map { x => (x.name, x) }
35-
val addOrReplaceFields = pairs.map { case (fieldName, field) => (fieldName, StructField(fieldName, field.dataType, field.nullable)) }
40+
val addOrReplaceFields = pairs.map { case (fieldName, field) =>
41+
(fieldName, StructField(fieldName, field.dataType, field.nullable))
42+
}
3643
val newFields = loop(existingFields, addOrReplaceFields).map(_._2)
3744
StructType(newFields)
3845
}
3946

40-
override def nullable: Boolean = struct.nullable
41-
4247
override def checkInputDataTypes(): TypeCheckResult = {
48+
if (children.size % 2 == 0) {
49+
return TypeCheckResult.TypeCheckFailure(s"$prettyName expects an odd number of arguments.")
50+
}
51+
4352
val typeName = struct.dataType.typeName
44-
if (typeName != StructType(Nil).typeName)
53+
val expectedStructType = StructType(Nil).typeName
54+
if (typeName != expectedStructType) {
4555
return TypeCheckResult.TypeCheckFailure(
46-
s"struct should be struct data type. struct is $typeName")
47-
48-
if (fieldNames.contains(null))
49-
return TypeCheckResult.TypeCheckFailure("fieldNames cannot contain null")
56+
s"Only $expectedStructType is allowed to appear at first position, got: $typeName.")
57+
}
5058

51-
if (fieldExpressions.contains(null))
52-
return TypeCheckResult.TypeCheckFailure("fieldExpressions cannot contain null")
59+
if (nameExprs.contains(null) || nameExprs.exists(e => !(e.foldable && e.dataType == StringType))) {
60+
return TypeCheckResult.TypeCheckFailure(
61+
s"Only non-null foldable ${StringType.catalogString} expressions are allowed to appear at even position.")
62+
}
5363

54-
if (fieldNames.size != fieldExpressions.size)
55-
return TypeCheckResult.TypeCheckFailure("fieldNames and fieldExpressions cannot have different lengths")
64+
if (valExprs.contains(null)) {
65+
return TypeCheckResult.TypeCheckFailure(
66+
s"Only non-null expressions are allowed to appear at odd positions after first position.")
67+
}
5668

5769
TypeCheckResult.TypeCheckSuccess
5870
}
@@ -62,32 +74,36 @@ case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressio
6274
if (structValue == null) {
6375
null
6476
} else {
65-
val existingValues: Seq[(FieldName, Any)] = ogStructType.fieldNames.zip(structValue.asInstanceOf[InternalRow].toSeq(ogStructType))
66-
val addOrReplaceValues: Seq[(FieldName, Any)] = pairs.map { case (fieldName, expression) => (fieldName, expression.eval(input)) }
77+
val existingValues: Seq[(FieldName, Any)] =
78+
ogStructType.fieldNames.zip(structValue.asInstanceOf[InternalRow].toSeq(ogStructType))
79+
val addOrReplaceValues: Seq[(FieldName, Any)] =
80+
pairs.map { case (fieldName, expression) => (fieldName, expression.eval(input)) }
6781
val newValues = loop(existingValues, addOrReplaceValues).map(_._2)
6882
InternalRow.fromSeq(newValues)
6983
}
7084
}
7185

7286
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
7387
val structGen = struct.genCode(ctx)
74-
val addOrReplaceFieldsGens = fieldExpressions.map(_.genCode(ctx))
88+
val addOrReplaceFieldsGens = valExprs.map(_.genCode(ctx))
7589
val resultCode: String = {
7690
val structVar = structGen.value
7791
type NullCheck = String
7892
type NonNullValue = String
79-
val existingFieldsCode: Seq[(FieldName, (NullCheck, NonNullValue))] = ogStructType.fields.zipWithIndex.map {
80-
case (structField, i) =>
81-
val nullCheck = s"$structVar.isNullAt($i)"
82-
val nonNullValue = CodeGenerator.getValue(structVar, structField.dataType, i.toString)
83-
(structField.name, (nullCheck, nonNullValue))
84-
}
85-
val addOrReplaceFieldsCode: Seq[(FieldName, (NullCheck, NonNullValue))] = fieldNames.zip(addOrReplaceFieldsGens).map {
86-
case (fieldName, fieldExprCode) =>
87-
val nullCheck = fieldExprCode.isNull.code
88-
val nonNullValue = fieldExprCode.value.code
89-
(fieldName, (nullCheck, nonNullValue))
90-
}
93+
val existingFieldsCode: Seq[(FieldName, (NullCheck, NonNullValue))] =
94+
ogStructType.fields.zipWithIndex.map {
95+
case (structField, i) =>
96+
val nullCheck = s"$structVar.isNullAt($i)"
97+
val nonNullValue = CodeGenerator.getValue(structVar, structField.dataType, i.toString)
98+
(structField.name, (nullCheck, nonNullValue))
99+
}
100+
val addOrReplaceFieldsCode: Seq[(FieldName, (NullCheck, NonNullValue))] =
101+
fieldNames.zip(addOrReplaceFieldsGens).map {
102+
case (fieldName, fieldExprCode) =>
103+
val nullCheck = fieldExprCode.isNull.code
104+
val nonNullValue = fieldExprCode.value.code
105+
(fieldName, (nullCheck, nonNullValue))
106+
}
91107
val newFieldsCode = loop(existingFieldsCode, addOrReplaceFieldsCode)
92108
val rowClass = classOf[GenericInternalRow].getName
93109
val rowValuesVar = ctx.freshName("rowValues")
@@ -138,17 +154,14 @@ case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressio
138154

139155
override def prettyName: String = "add_fields"
140156

141-
private lazy val ogStructType: StructType =
142-
struct.dataType.asInstanceOf[StructType]
143-
144-
private val pairs = fieldNames.zip(fieldExpressions)
145-
146157
private type FieldName = String
147158

148159
/**
149-
* Recursively loops through addOrReplaceFields, adding or replacing fields by FieldName.
150-
*/
151-
private def loop[V](existingFields: Seq[(String, V)], addOrReplaceFields: Seq[(String, V)]): Seq[(String, V)] = {
160+
* Recursively loop through addOrReplaceFields, adding or replacing fields by FieldName.
161+
*/
162+
@scala.annotation.tailrec
163+
private def loop[V](existingFields: Seq[(String, V)],
164+
addOrReplaceFields: Seq[(String, V)]): Seq[(String, V)] = {
152165
if (addOrReplaceFields.nonEmpty) {
153166
val existingFieldNames = existingFields.map(_._1)
154167
val newField@(newFieldName, _) = addOrReplaceFields.head
@@ -172,6 +185,13 @@ case class AddFields(struct: Expression, fieldNames: Seq[String], fieldExpressio
172185
}
173186

174187
object AddFields {
188+
@deprecated("use AddFields(children: Seq[Expression]) constructor.", "0.2.4")
175189
def apply(struct: Expression, fieldName: String, fieldExpression: Expression): AddFields =
176-
AddFields(struct, Seq(fieldName), Seq(fieldExpression))
190+
AddFields(struct :: Literal(fieldName) :: fieldExpression :: Nil)
191+
192+
@deprecated("use AddFields(children: Seq[Expression]) constructor.", "0.2.4")
193+
def apply(struct: Expression, fieldNames: Seq[String], fieldExpressions: Seq[Expression]): AddFields = {
194+
val exprs = fieldNames.zip(fieldExpressions).flatMap { case (name, expr) => Seq(Literal(name), expr) }
195+
AddFields(struct +: exprs)
196+
}
177197
}

0 commit comments

Comments
 (0)